# image_recommender/handler.py
import os

# Point the Hugging Face cache at the current working directory before the
# model libraries below are imported and any weights are downloaded.
current_dir = os.getcwd()
os.environ["HF_HOME"] = current_dir
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
from serpapi import GoogleSearch
from keybert import KeyBERT
from typing import Dict, Any, List
import base64
# Vision-language model used to describe the product image.
model_id = "vikhyatk/moondream2"
revision = "2024-08-26"
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, revision=revision
)
model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)

# Multilingual sentence embedder used to match the user message against the
# predefined Vietnamese intents in ProductSearcher.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
sentence_model = SentenceTransformer(model_name, device="cuda")

class ProductSearcher:
    def __init__(self, user_input, image_path):
        self.user_input = user_input
        self.image_path = image_path
        # Predefined user intents (Vietnamese):
        # "I want to buy this product", "I want information about the product",
        # "I want to know the price of this".
        self.predefined_questions = [
            "tôi muốn mua sản phẩm này",
            "tôi muốn thông tin về sản phẩm",
            "tôi muốn biết giá cái này",
        ]
        # One prompt per intent, passed to the vision-language model.
        self.prompts = [
            "Describe the product in the image and its color. Answer in only one sentence.",
            "Describe the product in detail and provide information about the product. If you don't know the product, you can describe the image.",
            "Estimate the price of the product and provide a detailed description of the product.",
        ]
        self.description = ""
        self.keyphrases = []
        self.kw_model = KeyBERT()

    def get_most_similar_sentence(self):
        # Embed the user message and the predefined intents, then return the
        # prompt whose intent is most similar to the message.
        user_input_embedding = sentence_model.encode(self.user_input)
        predefined_embeddings = sentence_model.encode(self.predefined_questions)
        similarity_scores = util.pytorch_cos_sim(user_input_embedding, predefined_embeddings)
        most_similar_index = similarity_scores.argmax().item()
        return self.prompts[most_similar_index]

    def generate_description(self):
        # Ask the vision-language model to describe the image, using the
        # prompt that matches the user's intent.
        prompt = self.get_most_similar_sentence()
        image = Image.open(self.image_path)
        enc_image = model.encode_image(image)
        self.description = model.answer_question(enc_image, prompt, tokenizer)

    def extract_keyphrases(self):
        # Pull keyphrases out of the generated description with KeyBERT.
        self.keyphrases = self.kw_model.extract_keywords(self.description)

    def search_products(self, k=3):
        # Concatenate the keyphrases (dropping the generic word "image") to
        # form a Google Shopping query.
        q = [keyword[0] for keyword in self.keyphrases if keyword[0] != "image"]
        question = " ".join(q)
        search = GoogleSearch({
            "engine": "google",
            "q": question,
            "tbm": "shop",
            "api_key": os.environ["API_KEY"],
        })
        results = search.get_dict()
        # Extract the top k products from the search results.
        products = results.get("shopping_results", [])[:k]
        return products

    def run(self, k=3):
        self.generate_description()
        self.extract_keyphrases()
        results = self.search_products(k)
        return results

class EndpointHandler:
    def __init__(self, path=""):
        pass

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`dict`): A dictionary containing the inputs.
                message (:obj:`str`): The user message.
                image (:obj:`str`): The base64-encoded image content.
        Return:
            A list of dictionaries containing the product search results.
        (See the __main__ example at the end of this file for a sample payload.)
        """
        inputs = data.get("inputs", {})
        message = inputs.get("message")
        image_content = inputs.get("image")
        # Decode the base64-encoded image content.
        image_bytes = base64.b64decode(image_content)
        # Save the image to a temporary file.
        image_path = "input/temp_image.jpg"
        os.makedirs("input", exist_ok=True)
        with open(image_path, "wb") as f:
            f.write(image_bytes)
        # Initialize ProductSearcher with the message and image path.
        searcher = ProductSearcher(message, image_path)
        # Run the search and return the results.
        results = searcher.run(k=3)
        return results
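

# ---------------------------------------------------------------------------
# Local smoke test (sketch, not part of the endpoint contract).
# It only illustrates the payload shape described in the __call__ docstring
# above. It assumes a local image at "sample.jpg" (hypothetical path) and that
# the SerpApi key is available in the API_KEY environment variable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("sample.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    payload = {
        "inputs": {
            "message": "tôi muốn biết giá cái này",  # "I want to know the price of this"
            "image": encoded_image,
        }
    }
    for product in handler(payload):
        # Each item is a raw entry from SerpApi's shopping_results.
        print(product.get("title"), product.get("price"), product.get("link"))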