Create handler.py
Browse files- handler.py +115 -0
handler.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import base64
import os
import tempfile
from typing import Dict, Any, List

import torch
from keybert import KeyBERT
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from serpapi import GoogleSearch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
9 |
+
# Vision-language model that answers questions about the uploaded image.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"  # pinned revision of the model repository

# Prefer the GPU, but fall back to the CPU so importing this module does not
# fail on hosts without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, revision=revision
)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Sentence encoder used to match the user's message against the predefined
# questions in ProductSearcher.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device=device)
|
20 |
+
|
21 |
+
class ProductSearcher:
    """Turn a user message plus a product image into shopping search results.

    Pipeline (see ``run``): pick the prompt that best matches the user's
    message, ask the vision-language model to describe the image with that
    prompt, extract keywords from the description, and query the shopping
    search engine with those keywords.
    """

    # Canonical user intents (Vietnamese): buy the product, get information
    # about it, ask for its price. PROMPTS is a parallel sequence: entry i is
    # the image prompt used when the user's message is closest to question i,
    # so both sequences must always have the same length.
    PREDEFINED_QUESTIONS = (
        "tôi muốn mua sản phẩm này",
        "tôi muốn thông tin về sản phẩm",
        "tôi muốn biết giá cái này",
    )
    # NOTE(review): prompt text is kept verbatim (including the "Descibe"
    # typo) so model inputs do not change; only the missing separator between
    # the first two prompts was restored.
    PROMPTS = (
        "Descibe product in image with it color. Only answer in one sentence",
        "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image",
        "Estimate the price of the product and provide a detailed description of the product",
    )

    # Keyword extractor shared by all instances; created on first use so the
    # underlying model is not reloaded for every request.
    _shared_kw_model = None

    def __init__(self, user_input, image_path):
        """Store the request data and prepare the keyword extractor.

        Args:
            user_input: The user's message, compared against the predefined
                questions to select a prompt.
            image_path: Path of the product image to describe.
        """
        self.user_input = user_input
        self.image_path = image_path
        self.predefined_questions = list(self.PREDEFINED_QUESTIONS)
        self.prompts = list(self.PROMPTS)
        self.description = ''
        self.keyphrases = []
        self.kw_model = self._get_kw_model()

    @classmethod
    def _get_kw_model(cls):
        """Return the shared KeyBERT instance, creating it on first use."""
        if cls._shared_kw_model is None:
            cls._shared_kw_model = KeyBERT()
        return cls._shared_kw_model

    def get_most_similar_sentence(self):
        """Return the prompt paired with the question closest to the user input.

        Closeness is the cosine similarity between sentence embeddings of the
        user input and of each predefined question.
        """
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        """Describe the image with the selected prompt; store it in ``description``."""
        prompt = self.get_most_similar_sentence()
        # The context manager releases the file handle once the image has
        # been encoded.
        with Image.open(self.image_path) as image:
            enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)

    def extract_keyphrases(self):
        """Extract keywords from ``description`` and store them in ``keyphrases``."""
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def _build_query(self):
        """Join the extracted keywords into one search query string.

        Each keyphrase entry is indexed with ``[0]`` to obtain the keyword
        text (presumably ``(keyword, score)`` pairs -- confirm against the
        extractor). The generic word 'image' is dropped.
        """
        return " ".join(
            keyword[0] for keyword in self.keyphrases if keyword[0] != 'image'
        )

    def search_products(self, k=3):
        """Search the shopping engine with the extracted keywords.

        Args:
            k: Maximum number of products to return.

        Returns:
            Up to ``k`` entries from the response's 'shopping_results' list,
            or an empty list when there are no usable keywords or the
            response has no such key.

        Raises:
            KeyError: If the API_KEY environment variable is not set.
        """
        question = self._build_query()
        if not question:
            # Nothing to search for: skip the API call instead of sending an
            # empty query.
            return []
        search = GoogleSearch({
            "engine": "google",
            "q": question,
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        return results.get('shopping_results', [])[:k]

    def run(self, k=3):
        """Run the whole pipeline and return up to ``k`` products."""
        self.generate_description()
        self.extract_keyphrases()
        return self.search_products(k)
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
class EndpointHandler:
    """Request handler that runs a product search for a message and an image."""

    def __init__(self, path: str = ""):
        """Create the handler.

        Args:
            path: Model directory supplied by the serving runtime when it
                instantiates the handler. It is accepted for compatibility
                and otherwise unused here.
        """
        pass

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: dict): A dictionary containing the inputs.
                message (:obj: str): The user message.
                image (:obj: str): The base64-encoded image content.
        Return:
            A list of dictionaries containing the product search results.
        Raises:
            ValueError: If "inputs" is not a dictionary, "message" is not a
                string, or "image" is missing, empty, or not valid base64.
        """
        inputs = data.get("inputs", {})
        if not isinstance(inputs, dict):
            raise ValueError('"inputs" must be a dictionary with "message" and "image" keys')

        message = inputs.get("message")
        image_content = inputs.get("image")
        if not isinstance(message, str):
            raise ValueError('"inputs.message" must be a string')
        if not image_content:
            raise ValueError('"inputs.image" must be a non-empty base64-encoded image')

        # Decode the base64-encoded image content
        image_bytes = base64.b64decode(image_content)

        # Write the image into a per-request temporary directory so that
        # concurrent requests cannot overwrite each other's file; the
        # directory is removed when the block exits.
        with tempfile.TemporaryDirectory() as workdir:
            image_path = os.path.join(workdir, "temp_image.jpg")
            with open(image_path, "wb") as f:
                f.write(image_bytes)

            # The search must finish inside the block, while the file exists.
            searcher = ProductSearcher(message, image_path)
            return searcher.run(k=3)
|
114 |
+
|
115 |
+
|