import logging

import torch
from openai import OpenAI
from transformers import AutoProcessor, CLIPVisionModel

from utils.utils import pil_image_to_data_url

class ImageSearcher:
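    """Similarity search over car images.

    Embeds a query image with CLIP, classifies the car with a vision LLM, and
    runs a MongoDB Atlas vector search to return matching inventory documents.
    """
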
    def __init__(self, image_collection, cars_collection, search_index):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
        self.processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.client = OpenAI()
        self.cars_collection = cars_collection
        self.image_collection = image_collection
        self.search_index = search_index
        self.car_list = self.cars_collection.distinct("makeModel")
        self.logger = logging.getLogger(__name__)

    def compute_image_embeddings(self, image):
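        """Return the CLIP pooled image embedding as a flat list of floats,
        or None if the image cannot be processed."""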
        try:
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            pooled_output = outputs.pooler_output
            # Flatten to a plain Python list so the embedding can be stored in MongoDB
            pooled_output_np = pooled_output.cpu().numpy().reshape(-1).tolist()
            return pooled_output_np
        except Exception as e:
            self.logger.error(f"Error processing image: {str(e)}")
            return None

    def classify_car(self, image):
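        """Ask the vision model to label the car with one of the known make/model
        values from the inventory; returns 'unknown' on failure or when no label matches."""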

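        # Encode the image as a data URL so it can be passed inline to the vision model.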
        image_url = pil_image_to_data_url(image)

        prompt = f"Classify the car into one of the following categories: {self.car_list}. If the car is not in the list, return 'unknown'. Return the label only."
        try:
            response = self.client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ],
                max_tokens=300,
            )
            self.logger.info(f"Classified car: {response.choices[0].message.content}")
            return response.choices[0].message.content
        except Exception as e:
            self.logger.error(f"Error classifying car: {str(e)}")
            return "unknown"

    def __call__(self, image, k=3):
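        """Find the k inventory cars most visually similar to the given PIL image.

        Returns a dict with the detected make/model and the matching inventory
        documents, or an error entry if embedding or search fails."""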

        image = image.convert("RGB")  # normalize to 3-channel RGB (handles RGBA/grayscale inputs)

        img_embedding = self.compute_image_embeddings(image)
        if img_embedding is None:
            self.logger.warning("Could not compute an embedding for the provided image")
            return {"error": "Invalid image provided"}

        detected_car = self.classify_car(image)

        # Build the ANN vector search stage; only apply a make/model filter when
        # classification produced a known label.
        vector_search_stage = {
            'index': self.search_index,
            'path': 'embedding',
            'queryVector': img_embedding,
            'limit': k,
            'numCandidates': 100,
        }
        if detected_car != 'unknown':
            vector_search_stage['filter'] = {'makeModel': detected_car}

        pipeline = [
            {'$vectorSearch': vector_search_stage},
            {
                # Join each matched image document back to its car inventory record.
                '$lookup': {
                    'from': self.cars_collection.name,
                    'localField': 'id',
                    'foreignField': '_id',
                    'as': 'inventoryDetails'
                }
            },
            {'$unwind': '$inventoryDetails'},
            {'$replaceRoot': {'newRoot': '$inventoryDetails'}},
            {
                # Drop bulky fields before returning the results.
                '$project': {
                    'description': 0,
                    'variants': 0,
                    'review_embedding': 0,
                    'feature_embedding': 0,
                }
            }
        ]

        try:
            result = list(self.image_collection.aggregate(pipeline))
            return {"detected_car": detected_car, "results": result}
        except Exception as e:
            self.logger.error(f"Error searching for similar cars: {str(e)}")
            return {"detected_car": detected_car, "error": "An error occurred during the search"}
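
# Minimal usage sketch (assumes a MongoDB deployment, collection names, and a
# vector search index name that are not defined in this module):
#
#   from pymongo import MongoClient
#   from PIL import Image
#
#   db = MongoClient()["dealership"]
#   searcher = ImageSearcher(db["images"], db["cars"], "image_vector_index")
#   print(searcher(Image.open("query_car.jpg"), k=3))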