clfegg commited on
Commit
315a218
·
verified ·
1 Parent(s): 1643107

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +115 -0
handler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from sentence_transformers import SentenceTransformer, util
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
+ from PIL import Image
5
+ from serpapi import GoogleSearch
6
+ from keybert import KeyBERT
7
+ from typing import Dict, Any, List
8
+ import os
9
+ model_id = "vikhyatk/moondream2"
10
+ revision = "2024-08-26"
11
+ model = AutoModelForCausalLM.from_pretrained(
12
+ model_id, trust_remote_code=True, revision=revision
13
+ )
14
+
15
+ model.to('cuda')
16
+ tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
17
+
18
+ model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
19
+ sentence_model = SentenceTransformer(model_name, device='cuda')
20
+
21
+ class ProductSearcher:
22
+ def __init__(self, user_input, image_path):
23
+ self.user_input = user_input
24
+ self.image_path = image_path
25
+ self.predefined_questions = [
26
+ "tôi muốn mua sản phẩm này",
27
+ "tôi muốn thông tin về sản phẩm",
28
+ "tôi muốn biết giá cái này"
29
+ ]
30
+ self.prompts = [
31
+ "Descibe product in image with it color. Only answer in one sentence"
32
+ "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image",
33
+ "Estimate the price of the product and provide a detailed description of the product"
34
+ ]
35
+ self.description = ''
36
+ self.keyphrases = []
37
+ self.kw_model= KeyBERT()
38
+
39
+
40
+ def get_most_similar_sentence(self):
41
+ user_input_embedding = sentence_model.encode(self.user_input)
42
+ predefined_embeddings = sentence_model.encode(self.predefined_questions)
43
+ similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
44
+ most_similar_index = similarity_scores.argmax().item()
45
+ return self.prompts[most_similar_index]
46
+
47
+ def generate_description(self):
48
+ prompt = self.get_most_similar_sentence()
49
+ image = Image.open(self.image_path)
50
+
51
+ enc_image = model.encode_image(image)
52
+ self.description = model.answer_question(enc_image, prompt, tokenizer)
53
+
54
+ def extract_keyphrases(self):
55
+ self.keyphrases = self.kw_model.extract_keywords(self.description)
56
+ def search_products(self, k=3):
57
+ # Concatenate keyphrases to form a question
58
+ q = [keyword[0] for keyword in self.keyphrases if keyword[0] != 'image']
59
+ question = " ".join(q)
60
+ search = GoogleSearch({
61
+ "engine": "google",
62
+ # "q": self.keyphrases[0]['word'],
63
+ "q":question,
64
+ "tbm": "shop",
65
+ "api_key": os.environ["API_KEY"]
66
+ })
67
+ results = search.get_dict()
68
+ # Extract top k products from the search results
69
+ products = results.get('shopping_results', [])[:k]
70
+ return products
71
+
72
+ def run(self, k=3):
73
+ self.generate_description()
74
+ self.extract_keyphrases()
75
+ results = self.search_products(k)
76
+ return results
77
+
78
+
79
+
80
+ class EndpointHandler:
81
+ def __init__(self):
82
+ pass
83
+
84
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
85
+ """
86
+ data args:
87
+ inputs (:obj: dict): A dictionary containing the inputs.
88
+ message (:obj: str): The user message.
89
+ image (:obj: str): The base64-encoded image content.
90
+ Return:
91
+ A list of dictionaries containing the product search results.
92
+ """
93
+ inputs = data.get("inputs", {})
94
+ message = inputs.get("message")
95
+ image_content = inputs.get("image")
96
+
97
+ # Decode the base64-encoded image content
98
+ image_bytes = base64.b64decode(image_content)
99
+
100
+ # Save the image to a temporary file
101
+ image_path = "input/temp_image.jpg"
102
+ os.makedirs("input", exist_ok=True)
103
+ with open(image_path, "wb") as f:
104
+ f.write(image_bytes)
105
+
106
+ # Initialize ProductSearcher with the message and image path
107
+ searcher = ProductSearcher(message, image_path)
108
+
109
+ # Run the search and get results
110
+ results = searcher.run(k=3)
111
+
112
+ # Return the search results
113
+ return results
114
+
115
+