# otoz-smart-search/src/image_search.py
import logging
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModel
import numpy as np
import torch
from openai import OpenAI
from functools import lru_cache
from utils.utils import pil_image_to_data_url
class ImageSearcher:
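    """Image-based car search.

    Combines CLIP visual embeddings (searched with a MongoDB Atlas vector
    index) with a GPT-4o-mini make/model classification that is used to
    pre-filter the vector search when a known model is detected.
    """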
def __init__(self, image_collection, cars_collection, search_index):
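        """
        Args:
            image_collection: MongoDB collection of image documents carrying an
                'embedding' field, searched via the Atlas vector index.
            cars_collection: MongoDB collection of car inventory documents,
                joined to the matched images in the search pipeline.
            search_index: Name of the Atlas vector search index on the image
                collection.
        """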
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
self.client = OpenAI()
self.cars_collection = cars_collection
self.image_collection = image_collection
self.search_index = search_index
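        # Distinct make/model labels, used to constrain the LLM classification prompt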
self.car_list = self.cars_collection.distinct("makeModel")
self.logger = logging.getLogger(__name__)
def compute_image_embeddings(self, image):
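        """Return the CLIP pooled output for `image` as a flat list of floats,
        or None if the image cannot be processed."""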
try:
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
                pooled_output = outputs.pooler_output
            # Flatten to a plain Python list usable as a Mongo query vector
            return pooled_output.cpu().numpy().reshape(-1).tolist()
except Exception as e:
self.logger.error(f"Error processing image: {str(e)}")
return None
def classify_car(self, image):
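        """Ask gpt-4o-mini to label the car as one of the known make/model
        values; returns 'unknown' when no match is found or the call fails."""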
image_url = pil_image_to_data_url(image)
prompt = f"Classify the car into one of the following categories: {self.car_list}. If the car is not in the list, return 'unknown'. Return the label only."
try:
response = self.client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": image_url}},
],
}
],
max_tokens=300,
)
            label = response.choices[0].message.content.strip()
            self.logger.info(f"Classified car: {label}")
            return label
except Exception as e:
self.logger.error(f"Error classifying car {image_url}: {str(e)}")
return "unknown"
def __call__(self, image, k=3):
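        """Embed the query image, classify its make/model, and return up to `k`
        visually similar cars from the inventory."""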
image = image.convert("RGB")
img_embedding = self.compute_image_embeddings(image)
if img_embedding is None:
            self.logger.warning("Could not compute an embedding for the provided image")
return {"error": "Invalid image provided"}
detected_car = self.classify_car(image)
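        # Vector search over stored image embeddings (pre-filtered by the
        # detected make/model when known), followed by a join back to the
        # car inventory documents.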
pipeline = [
{
'$vectorSearch': {
'index': self.search_index,
'filter': {'makeModel': detected_car} if detected_car != 'unknown' else {},
'path': 'embedding',
'queryVector': img_embedding,
'limit': k,
'numCandidates': 100,
}
},
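            # Attach the matching inventory record for each hit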
{
'$lookup': {
'from': self.cars_collection.name,
'localField': 'id',
'foreignField': '_id',
'as': 'inventoryDetails'
}
},
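            # Flatten the joined array and return the car document itself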
{'$unwind': '$inventoryDetails'},
{'$replaceRoot': {'newRoot': '$inventoryDetails'}},
{
'$project': {
                    # Exclude bulky fields from the returned documents
                    'description': 0,
                    'variants': 0,
                    'review_embedding': 0,
                    'feature_embedding': 0,
}
}
]
try:
result = list(self.image_collection.aggregate(pipeline))
return {"detected_car": detected_car, "results": result}
except Exception as e:
self.logger.error(f"Error searching for similar cars: {str(e)}")
return {"detected_car": detected_car, "error": "An error occurred during the search"}