Spaces:
Runtime error
Runtime error
step 6
Browse files
app.py
CHANGED
@@ -24,12 +24,6 @@ def preprocess_image(image):
|
|
24 |
])
|
25 |
return preprocess(pil_image).unsqueeze(0)
|
26 |
|
27 |
-
def generate_output_html(top_3_products):
|
28 |
-
output_html = ""
|
29 |
-
for i, product in enumerate(top_3_products):
|
30 |
-
output_html += f'<div class="product"><img src="{product["image_url"]}" alt="{product["name"]}" /><br>{product["name"]}<br>Price: {product["price"]}</div>'
|
31 |
-
return output_html
|
32 |
-
|
33 |
def encode_text(text):
|
34 |
inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
|
35 |
return inputs
|
@@ -38,6 +32,34 @@ def encode_image(image):
|
|
38 |
inputs = processor(images=[image], return_tensors="pt", padding=True, truncation=True)
|
39 |
return inputs
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def send_message(txt, btn):
|
42 |
if btn is not None:
|
43 |
image = preprocess_image(btn)
|
|
|
24 |
])
|
25 |
return preprocess(pil_image).unsqueeze(0)
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def encode_text(text):
|
28 |
inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
|
29 |
return inputs
|
|
|
32 |
inputs = processor(images=[image], return_tensors="pt", padding=True, truncation=True)
|
33 |
return inputs
|
34 |
|
35 |
+
def calculate_similarities(query_image, query_text):
|
36 |
+
query_image_features = model.get_image_features(query_image)
|
37 |
+
query_text_features = model.get_text_features(query_text)
|
38 |
+
|
39 |
+
similarities = []
|
40 |
+
for product in deepfashion_database:
|
41 |
+
product_image_features = torch.Tensor(product["image_features"])
|
42 |
+
product_text_features = torch.Tensor(product["text_features"])
|
43 |
+
|
44 |
+
image_similarity = torch.nn.CosineSimilarity(dim=-1)(query_image_features, product_image_features)
|
45 |
+
text_similarity = torch.nn.CosineSimilarity(dim=-1)(query_text_features, product_text_features)
|
46 |
+
|
47 |
+
similarity_score = image_similarity * text_similarity
|
48 |
+
similarities.append(similarity_score)
|
49 |
+
|
50 |
+
return similarities
|
51 |
+
|
52 |
+
def initial_query(image, text):
|
53 |
+
query_image = encode_image(image)
|
54 |
+
query_text = encode_text(text)
|
55 |
+
|
56 |
+
similarities = calculate_similarities(query_image, query_text)
|
57 |
+
sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
|
58 |
+
top_3_indices = sorted_indices[:3]
|
59 |
+
|
60 |
+
top_3_products = [deepfashion_database[i] for i in top_3_indices]
|
61 |
+
return top_3_products
|
62 |
+
|
63 |
def send_message(txt, btn):
|
64 |
if btn is not None:
|
65 |
image = preprocess_image(btn)
|