File size: 4,238 Bytes
57c38eb
 
 
315a218
 
 
 
 
 
5ba62fb
ea689f4
315a218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aedf570
315a218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
current_dir = os.getcwd()
# Point the Hugging Face cache at the current working directory. This is done
# before the transformers / sentence_transformers imports that follow so the
# libraries pick it up when they resolve their cache location.
os.environ['HF_HOME'] = current_dir
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from PIL import Image
from serpapi import GoogleSearch
from keybert import KeyBERT
from typing import Dict, Any, List
import base64

# Vision-language model used to describe the uploaded product image.
# Pinning `revision` fixes both the weights and the remote code that
# trust_remote_code=True will execute.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)
model.to('cuda')  # device is hard-coded: a CUDA GPU must be available at import
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Multilingual sentence encoder; ProductSearcher uses it to compare the user's
# message with its predefined (Vietnamese) intent phrases.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device='cuda')

class ProductSearcher:
    """Describe a product image with the vision model and look it up on Google Shopping.

    Pipeline (see ``run``):

    1. choose a vision-model prompt by matching ``user_input`` against a small
       set of predefined intents (embedding cosine similarity);
    2. ask the vision model that prompt about the image -> ``self.description``;
    3. extract keyphrases from the description with KeyBERT -> ``self.keyphrases``;
    4. send the keyphrases as a SerpApi Google Shopping query.

    Relies on the module-level ``model``, ``tokenizer`` and ``sentence_model``.
    """

    # User intents, matched by embedding similarity. English meanings:
    # "I want to buy this product", "I want information about the product",
    # "I want to know the price of this".
    PREDEFINED_QUESTIONS = (
        "tôi muốn mua sản phẩm này",
        "tôi muốn thông tin về sản phẩm",
        "tôi muốn biết giá cái này",
    )
    # Vision-model prompt for each intent: PROMPTS[i] answers PREDEFINED_QUESTIONS[i].
    # Every entry must end with a comma -- adjacent string literals are silently
    # concatenated, which would shift the mapping and drop a prompt.
    PROMPTS = (
        "Descibe product in image with it color. Only answer in one sentence",
        "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image",
        "Estimate the price of the product and provide a detailed description of the product",
    )

    # KeyBERT() builds an embedding model when constructed; share one instance
    # across searchers instead of rebuilding it for every request.
    _shared_kw_model = None

    def __init__(self, user_input, image_path):
        """Store one search request.

        Args:
            user_input: Free-text user message, used only to choose the prompt.
            image_path: Anything ``PIL.Image.open`` accepts (a path or a binary
                file-like object).
        """
        self.user_input = user_input
        self.image_path = image_path
        self.predefined_questions = list(self.PREDEFINED_QUESTIONS)
        self.prompts = list(self.PROMPTS)
        self.description = ''
        self.keyphrases = []
        self.kw_model = self._get_kw_model()

    @classmethod
    def _get_kw_model(cls):
        """Return the shared KeyBERT instance, creating it on first use."""
        if cls._shared_kw_model is None:
            cls._shared_kw_model = KeyBERT()
        return cls._shared_kw_model

    def get_most_similar_sentence(self):
        """Return the prompt whose predefined question best matches ``user_input``.

        ``self.prompts`` is indexed in parallel with ``self.predefined_questions``;
        the index with the highest cosine similarity wins.
        """
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        """Ask the vision model about the image and store its answer in ``self.description``."""
        prompt = self.get_most_similar_sentence()
        # convert() forces the pixel data to load (so the file can be closed)
        # and normalises palette / RGBA / greyscale uploads to 3-channel RGB.
        with Image.open(self.image_path) as img:
            image = img.convert("RGB")

        enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)

    def extract_keyphrases(self):
        """Extract ``(keyphrase, score)`` pairs from ``self.description`` with KeyBERT."""
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def _build_query(self):
        """Join the extracted keyphrases into one search query, skipping ``'image'``."""
        # 'image' is dropped presumably because descriptions tend to say
        # "the image shows ..." and it is not a product term.
        return " ".join(
            keyword[0] for keyword in self.keyphrases if keyword[0] != 'image'
        )

    def search_products(self, k=3):
        """Search Google Shopping (via SerpApi) for the extracted keyphrases.

        Args:
            k: Maximum number of results to return.

        Returns:
            Up to ``k`` entries of the response's ``shopping_results`` list; an
            empty list when there is no usable keyphrase or no shopping results.

        Raises:
            KeyError: if the ``API_KEY`` environment variable is not set.
        """
        question = self._build_query()
        if not question:
            # Nothing meaningful to search for; skip the network call.
            return []
        search = GoogleSearch({
            "engine": "google",
            "q": question,
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        # Keep only the top k products from the search results.
        return results.get('shopping_results', [])[:k]

    def run(self, k=3):
        """Run the full describe -> keyphrase -> search pipeline; return up to ``k`` results."""
        self.generate_description()
        self.extract_keyphrases()
        return self.search_products(k)



class EndpointHandler:
    """Request entry point that runs a ``ProductSearcher`` for each call."""

    def __init__(self, path=""):
        # ``path`` is accepted for the handler's constructor signature but is
        # unused; the models are loaded at module level in this file.
        pass

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run a product search for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is a dict with
                ``message`` (str): the user message, and
                ``image`` (str): the base64-encoded image content.

        Returns:
            A list of dictionaries containing the product search results
            (at most 3).

        Raises:
            ValueError: if ``message`` or ``image`` is missing, or ``image`` is
                not decodable base64 (``binascii.Error`` subclasses ValueError).
        """
        import tempfile

        inputs = data.get("inputs") or {}
        message = inputs.get("message")
        image_content = inputs.get("image")
        if message is None:
            raise ValueError("inputs.message is required")
        if not image_content:
            raise ValueError("inputs.image (base64-encoded image) is required")

        # Decode the base64-encoded image content.
        image_bytes = base64.b64decode(image_content)

        # Use a unique temporary file instead of one shared fixed path, so
        # concurrent requests cannot overwrite each other's image, and remove
        # it once the search is done.
        fd, image_path = tempfile.mkstemp(suffix=".jpg")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(image_bytes)

            searcher = ProductSearcher(message, image_path)
            return searcher.run(k=3)
        finally:
            try:
                os.remove(image_path)
            except OSError:
                # Best-effort cleanup: never mask the request's result or error.
                pass