import logging
from functools import lru_cache

import numpy as np
import requests
import torch
from openai import OpenAI
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel

from utils.utils import pil_image_to_data_url


class ImageSearcher:
    """Find inventory cars visually similar to a query image.

    The query image is embedded with the CLIP ViT-L/14 vision tower. It is then
    optionally narrowed to a single ``makeModel`` via a GPT-4o-mini
    classification. Finally it is matched against ``image_collection`` with a
    ``$vectorSearch`` aggregation stage. Matching image documents are joined to
    ``cars_collection`` and the joined car documents are returned.
    """

    # Hugging Face checkpoint used for both the processor and the model.
    CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
    # Label returned when the car cannot be mapped onto ``car_list``.
    UNKNOWN_LABEL = "unknown"

    def __init__(self, image_collection, cars_collection, search_index):
        """Load the CLIP model and fetch the set of known car models.

        Args:
            image_collection: collection of image documents carrying an
                ``embedding`` vector and an ``id`` that references a car
                document (presumably a pymongo ``Collection`` - confirm).
            cars_collection: collection of car documents. Its distinct
                ``makeModel`` values are the classification labels.
            search_index: name of the vector search index used on
                ``image_collection``.
        """
        self.logger = logging.getLogger(__name__)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPVisionModel.from_pretrained(self.CLIP_MODEL_NAME).to(self.device)
        self.processor = AutoProcessor.from_pretrained(self.CLIP_MODEL_NAME)
        self.client = OpenAI()
        self.cars_collection = cars_collection
        self.image_collection = image_collection
        self.search_index = search_index
        self.car_list = self.cars_collection.distinct("makeModel")

    def compute_image_embeddings(self, image):
        """Return the CLIP ``pooler_output`` of *image* as a flat list of floats.

        Returns ``None`` (after logging the traceback) if preprocessing, the
        forward pass or the conversion fails.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Flatten to a plain Python list so it can be sent as queryVector.
            return outputs.pooler_output.cpu().numpy().reshape(-1).tolist()
        except Exception:
            self.logger.exception("Error processing image")
            return None

    def classify_car(self, image):
        """Ask GPT-4o-mini which entry of ``self.car_list`` *image* shows.

        Returns an exact element of ``self.car_list``, or ``"unknown"`` in
        these cases:
        - the model declines;
        - the model replies with something that does not match the list;
        - the API call fails.

        The reply is validated because the caller uses it verbatim as an
        equality filter on ``makeModel``. An unvalidated near-miss such as
        ``"Honda Civic."`` would silently filter out every result.
        """
        image_url = pil_image_to_data_url(image)
        prompt = (
            f"Classify the car into one of the following categories: {self.car_list}. "
            "If the car is not in the list, return 'unknown'. Return the label only."
        )
        try:
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ],
                max_tokens=300,
            )
            raw_label = response.choices[0].message.content
        except Exception:
            # image_url is not logged: it is a data URL built from the whole
            # image and would flood the logs.
            self.logger.exception("Error classifying car")
            return self.UNKNOWN_LABEL

        label = self._match_car_label(raw_label)
        self.logger.info("Classified car: %r -> %s", raw_label, label)
        return label

    def _match_car_label(self, raw_label):
        """Map a free-text model reply onto an exact entry of ``self.car_list``."""
        if not raw_label:  # content can be None or empty
            return self.UNKNOWN_LABEL
        stripped = raw_label.strip()
        # Chat models often wrap the label in quotes or end it with a period.
        cleaned = stripped.strip("\"'`").rstrip(".").strip()
        by_casefold = {
            str(name).casefold(): name for name in self.car_list if name is not None
        }
        for candidate in (stripped, cleaned):
            if candidate in self.car_list:
                return candidate
            match = by_casefold.get(candidate.casefold())
            if match is not None:
                # Return the stored spelling so the equality filter matches.
                return match
        return self.UNKNOWN_LABEL

    def _build_pipeline(self, query_vector, detected_car, k, num_candidates):
        """Build the aggregation: vector search, join to cars, drop bulky fields."""
        vector_search = {
            "index": self.search_index,
            "path": "embedding",
            "queryVector": query_vector,
            "limit": k,
            # $vectorSearch rejects numCandidates smaller than limit.
            "numCandidates": max(num_candidates, k),
        }
        # Pre-filter only when classification produced a real label. Omitting
        # the key, rather than sending {}, leaves the search unfiltered.
        if detected_car != self.UNKNOWN_LABEL:
            vector_search["filter"] = {"makeModel": detected_car}
        return [
            {"$vectorSearch": vector_search},
            {
                "$lookup": {
                    "from": self.cars_collection.name,
                    "localField": "id",
                    "foreignField": "_id",
                    "as": "inventoryDetails",
                }
            },
            # Image hits with no matching car document are dropped here.
            {"$unwind": "$inventoryDetails"},
            {"$replaceRoot": {"newRoot": "$inventoryDetails"}},
            {
                "$project": {
                    "description": 0,
                    "variants": 0,
                    "review_embedding": 0,
                    "feature_embedding": 0,
                }
            },
        ]

    def __call__(self, image, k=3, *, num_candidates=100):
        """Return the *k* inventory cars whose images best match *image*.

        Args:
            image: PIL image. It is converted to RGB before embedding.
            k: maximum number of results (the ``$vectorSearch`` ``limit``).
            num_candidates: approximate-search candidate pool size. It is
                raised to ``k`` when smaller.

        Returns:
            One of three dicts:
            - ``{"detected_car": ..., "results": [...]}`` on success;
            - ``{"detected_car": ..., "error": ...}`` if the query fails;
            - ``{"error": "Invalid image provided"}`` if embedding fails.
        """
        image = image.convert("RGB")
        img_embedding = self.compute_image_embeddings(image)
        if img_embedding is None:
            self.logger.warning("Invalid image")
            return {"error": "Invalid image provided"}

        detected_car = self.classify_car(image)
        pipeline = self._build_pipeline(img_embedding, detected_car, k, num_candidates)
        try:
            result = list(self.image_collection.aggregate(pipeline))
        except Exception:
            self.logger.exception("Error searching for similar cars")
            return {"detected_car": detected_car, "error": "An error occurred during the search"}
        return {"detected_car": detected_car, "results": result}