samnji commited on
Commit
fbbda08
·
1 Parent(s): 9426ca5
Files changed (1) hide show
  1. app.py +28 -6
app.py CHANGED
@@ -24,12 +24,6 @@ def preprocess_image(image):
24
  ])
25
  return preprocess(pil_image).unsqueeze(0)
26
 
27
- def generate_output_html(top_3_products):
28
- output_html = ""
29
- for i, product in enumerate(top_3_products):
30
- output_html += f'<div class="product"><img src="{product["image_url"]}" alt="{product["name"]}" /><br>{product["name"]}<br>Price: {product["price"]}</div>'
31
- return output_html
32
-
33
  def encode_text(text):
34
  inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
35
  return inputs
@@ -38,6 +32,34 @@ def encode_image(image):
38
  inputs = processor(images=[image], return_tensors="pt", padding=True, truncation=True)
39
  return inputs
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def send_message(txt, btn):
42
  if btn is not None:
43
  image = preprocess_image(btn)
 
24
  ])
25
  return preprocess(pil_image).unsqueeze(0)
26
 
 
 
 
 
 
 
27
  def encode_text(text):
28
  inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
29
  return inputs
 
32
  inputs = processor(images=[image], return_tensors="pt", padding=True, truncation=True)
33
  return inputs
34
 
35
+ def calculate_similarities(query_image, query_text):
36
+ query_image_features = model.get_image_features(query_image)
37
+ query_text_features = model.get_text_features(query_text)
38
+
39
+ similarities = []
40
+ for product in deepfashion_database:
41
+ product_image_features = torch.Tensor(product["image_features"])
42
+ product_text_features = torch.Tensor(product["text_features"])
43
+
44
+ image_similarity = torch.nn.CosineSimilarity(dim=-1)(query_image_features, product_image_features)
45
+ text_similarity = torch.nn.CosineSimilarity(dim=-1)(query_text_features, product_text_features)
46
+
47
+ similarity_score = image_similarity * text_similarity
48
+ similarities.append(similarity_score)
49
+
50
+ return similarities
51
+
52
+ def initial_query(image, text):
53
+ query_image = encode_image(image)
54
+ query_text = encode_text(text)
55
+
56
+ similarities = calculate_similarities(query_image, query_text)
57
+ sorted_indices = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)
58
+ top_3_indices = sorted_indices[:3]
59
+
60
+ top_3_products = [deepfashion_database[i] for i in top_3_indices]
61
+ return top_3_products
62
+
63
  def send_message(txt, btn):
64
  if btn is not None:
65
  image = preprocess_image(btn)