Spaces:
Runtime error
Runtime error
File size: 4,237 Bytes
7baafc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import logging
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModel
import numpy as np
import torch
from openai import OpenAI
from functools import lru_cache
from utils.utils import pil_image_to_data_url
class ImageSearcher:
    """Find inventory cars visually similar to a query image.

    Combines two signals:
      1. A CLIP vision encoder produces an embedding used for Atlas
         ``$vectorSearch`` over the image collection.
      2. An OpenAI vision model classifies the car's make/model so the
         vector search can be pre-filtered to matching inventory.
    """

    def __init__(self, image_collection, cars_collection, search_index):
        """
        Args:
            image_collection: MongoDB collection of image documents with an
                ``embedding`` field, searched via ``$vectorSearch``.
            cars_collection: MongoDB collection of inventory documents;
                joined on image ``id`` -> car ``_id``.
            search_index: Name of the Atlas vector search index on
                ``image_collection``.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.client = OpenAI()
        self.cars_collection = cars_collection
        self.image_collection = image_collection
        self.search_index = search_index
        # Closed label set the classifier must choose from (or 'unknown').
        self.car_list = self.cars_collection.distinct("makeModel")
        self.logger = logging.getLogger(__name__)

    def compute_image_embeddings(self, image):
        """Return the CLIP pooled embedding of ``image`` as a flat list of
        floats, or ``None`` if the image cannot be processed.

        The flat Python list form is required so it can be passed directly
        as a JSON-serializable ``queryVector`` to ``$vectorSearch``.
        """
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            pooled_output = outputs.pooler_output
            return pooled_output.cpu().numpy().reshape(-1).tolist()
        except Exception as e:
            self.logger.error("Error processing image: %s", e)
            return None

    def classify_car(self, image):
        """Classify the car in ``image`` into one of ``self.car_list``.

        Returns the bare label string, or ``"unknown"`` when the model
        cannot match the car or the API call fails.
        """
        image_url = pil_image_to_data_url(image)
        prompt = f"Classify the car into one of the following categories: {self.car_list}. If the car is not in the list, return 'unknown'. Return the label only."
        try:
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ],
                max_tokens=300,
            )
            # Strip stray whitespace/newlines the model may append around the
            # bare label; otherwise "unknown\n" would defeat the 'unknown'
            # check and the makeModel filter downstream.
            label = response.choices[0].message.content.strip()
            self.logger.info("Classified car: %s", label)
            return label
        except Exception as e:
            # Deliberately do NOT log image_url: it is a base64 data URL and
            # would dump the entire encoded image into the logs.
            self.logger.error("Error classifying car: %s", e)
            return "unknown"

    def __call__(self, image, k=3):
        """Search for the ``k`` inventory cars most similar to ``image``.

        Args:
            image: PIL image of the query car (converted to RGB here).
            k: Maximum number of results to return.

        Returns:
            ``{"detected_car": <label>, "results": [<inventory doc>, ...]}``
            on success, or a dict with an ``"error"`` key on failure.
        """
        image = image.convert("RGB")
        img_embedding = self.compute_image_embeddings(image)
        if img_embedding is None:
            self.logger.warning("Invalid image")
            return {"error": "Invalid image provided"}
        detected_car = self.classify_car(image)
        vector_stage = {
            'index': self.search_index,
            'path': 'embedding',
            'queryVector': img_embedding,
            'limit': k,
            'numCandidates': 100,
        }
        # Only attach the pre-filter when classification succeeded; some
        # Atlas versions reject an empty 'filter' document outright.
        if detected_car != 'unknown':
            vector_stage['filter'] = {'makeModel': detected_car}
        pipeline = [
            {'$vectorSearch': vector_stage},
            {
                # Join each matched image back to its inventory record.
                '$lookup': {
                    'from': self.cars_collection.name,
                    'localField': 'id',
                    'foreignField': '_id',
                    'as': 'inventoryDetails'
                }
            },
            {'$unwind': '$inventoryDetails'},
            # Surface the inventory document itself as the result row.
            {'$replaceRoot': {'newRoot': '$inventoryDetails'}},
            {
                # Drop bulky fields the caller does not need.
                '$project': {
                    'description': 0,
                    'variants': 0,
                    'review_embedding': 0,
                    'feature_embedding': 0
                }
            }
        ]
        try:
            result = list(self.image_collection.aggregate(pipeline))
            return {"detected_car": detected_car, "results": result}
        except Exception as e:
            self.logger.error("Error searching for similar cars: %s", e)
            return {"detected_car": detected_car, "error": "An error occurred during the search"}
|