# Spaces:
# Runtime error
# Runtime error
import logging
from functools import lru_cache

import numpy as np
import requests
import torch
from openai import OpenAI
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel

from utils.utils import pil_image_to_data_url
class ImageSearcher:
    """Find inventory cars that visually resemble a query image.

    A search embeds the query image with a CLIP vision model, asks an
    OpenAI chat model to name the car's make/model, and then runs a
    ``$vectorSearch`` aggregation over ``image_collection`` whose hits are
    joined (``$lookup``) against ``cars_collection``.

    Both collections are presumably pymongo collections -- the code only
    relies on ``distinct``, ``aggregate`` and ``name``.
    """

    # Checkpoint used for both the vision model and its processor.
    MODEL_NAME = "openai/clip-vit-large-patch14"
    # Label used when the make/model of the car could not be determined.
    UNKNOWN_LABEL = "unknown"
    # Lower bound for the number of candidates considered by $vectorSearch.
    MIN_NUM_CANDIDATES = 100

    def __init__(self, image_collection, cars_collection, search_index):
        """Load the CLIP model and remember the collections to search.

        Args:
            image_collection: Collection the vector search runs against.
            cars_collection: Collection joined in for the inventory details;
                its distinct ``makeModel`` values form the label list.
            search_index: Name of the vector index passed to ``$vectorSearch``.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = CLIPVisionModel.from_pretrained(self.MODEL_NAME).to(self.device)
        self.processor = AutoProcessor.from_pretrained(self.MODEL_NAME)
        self.client = OpenAI()
        self.cars_collection = cars_collection
        self.image_collection = image_collection
        self.search_index = search_index
        # Labels the classifier is allowed to answer with.
        self.car_list = self.cars_collection.distinct("makeModel")
        self.logger = logging.getLogger(__name__)

    def compute_image_embeddings(self, image):
        """Return the CLIP pooled output for ``image`` as a flat list.

        Returns:
            A list of floats, or ``None`` when the image could not be
            processed (the error is logged, not raised).
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            return outputs.pooler_output.cpu().numpy().reshape(-1).tolist()
        except Exception as e:
            self.logger.error("Error processing image: %s", e)
            return None

    def _normalize_label(self, raw_label):
        """Map a raw model reply onto an entry of ``car_list`` or ``'unknown'``."""
        # The API may return no text at all; treat that as "unknown".
        if not isinstance(raw_label, str):
            return self.UNKNOWN_LABEL
        stripped = raw_label.strip()
        if stripped in self.car_list:
            return stripped
        # Tolerate quoting, a trailing period and case differences in the reply.
        cleaned = stripped.strip("'\"`").rstrip(".").strip().casefold()
        known_labels = {
            label.casefold(): label
            for label in self.car_list
            if isinstance(label, str)
        }
        # Anything that is not a known make/model would filter out every
        # document, so it is reported as "unknown" instead.
        return known_labels.get(cleaned, self.UNKNOWN_LABEL)

    def classify_car(self, image):
        """Ask the chat model which entry of ``car_list`` the image shows.

        Classification is best-effort: any failure while encoding the image
        or calling the API is logged and reported as ``'unknown'``.

        Returns:
            An entry of ``car_list``, or ``'unknown'``.
        """
        prompt = (
            f"Classify the car into one of the following categories: {self.car_list}. "
            "If the car is not in the list, return 'unknown'. Return the label only."
        )
        try:
            image_url = pil_image_to_data_url(image)
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ],
                max_tokens=300,
            )
            raw_label = response.choices[0].message.content
        except Exception as e:
            # The image URL is deliberately left out of the message: it is
            # presumably a data URL carrying the whole encoded image.
            self.logger.error("Error classifying car: %s", e)
            return self.UNKNOWN_LABEL
        label = self._normalize_label(raw_label)
        self.logger.info("Classified car: %s", label)
        return label

    def _build_pipeline(self, img_embedding, detected_car, k):
        """Build the aggregation pipeline for one similarity search."""
        vector_search = {
            'index': self.search_index,
            'path': 'embedding',
            'queryVector': img_embedding,
            'limit': k,
            # The candidate pool must never be smaller than the limit.
            'numCandidates': max(self.MIN_NUM_CANDIDATES, k),
        }
        # Only restrict by make/model when one was recognised; otherwise the
        # filter is omitted rather than sent as an empty document.
        if detected_car and detected_car != self.UNKNOWN_LABEL:
            vector_search['filter'] = {'makeModel': detected_car}
        return [
            {'$vectorSearch': vector_search},
            {
                '$lookup': {
                    'from': self.cars_collection.name,
                    'localField': 'id',
                    'foreignField': '_id',
                    'as': 'inventoryDetails',
                }
            },
            {'$unwind': '$inventoryDetails'},
            {'$replaceRoot': {'newRoot': '$inventoryDetails'}},
            {
                # Drop the bulky fields from the returned inventory documents.
                '$project': {
                    'description': 0,
                    'variants': 0,
                    'review_embedding': 0,
                    'feature_embedding': 0,
                }
            },
        ]

    def __call__(self, image, k=3):
        """Search for the ``k`` inventory cars most similar to ``image``.

        Args:
            image: Query image; it must provide ``convert("RGB")``
                (presumably a ``PIL.Image.Image``).
            k: Maximum number of results to return.

        Returns:
            ``{"detected_car": ..., "results": [...]}`` on success,
            ``{"error": "Invalid image provided"}`` when the image cannot be
            converted or embedded, or ``{"detected_car": ..., "error": ...}``
            when the search itself fails.
        """
        try:
            image = image.convert("RGB")
        except (AttributeError, OSError, ValueError) as e:
            # Covers None / non-image input and images that fail to decode.
            self.logger.warning("Invalid image: %s", e)
            return {"error": "Invalid image provided"}

        img_embedding = self.compute_image_embeddings(image)
        if img_embedding is None:
            self.logger.warning("Invalid image")
            return {"error": "Invalid image provided"}

        detected_car = self.classify_car(image)
        try:
            pipeline = self._build_pipeline(img_embedding, detected_car, k)
            result = list(self.image_collection.aggregate(pipeline))
        except Exception as e:
            self.logger.error("Error searching for similar cars: %s", e)
            return {"detected_car": detected_car, "error": "An error occurred during the search"}
        return {"detected_car": detected_car, "results": result}