samnji committed on
Commit
47d7fa1
·
1 Parent(s): 86891c2
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
4
+ from transformers import CLIPProcessor, CLIPModel
5
+ import torch
6
+
7
+
8
+ #Load the pre-trained CLIP model and its tokenizer:
9
# Hugging Face Hub identifier of the CLIP checkpoint; the model weights and
# the matching processor are both loaded from this same identifier.
model_name = "openai/clip-vit-base-patch32"

# The two loads are independent of each other; both objects are module-level
# globals used by the query-processing code further down.
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)
12
+
13
+
14
+ #Define the preprocessing function for images:
15
# Built once at import time: the pipeline holds no per-image state, so there is
# no reason to reconstruct it for every image that gets embedded.
_CLIP_TRANSFORM = Compose([
    # NOTE(review): the reference CLIP preprocessing resizes to 224 before the
    # centre crop; 256 is kept from the original code -- confirm it is intended.
    Resize(256, interpolation=Image.BICUBIC),
    CenterCrop(224),
    ToTensor(),
    # Per-channel mean / std applied to the RGB tensor (values as published
    # with the OpenAI CLIP checkpoints).
    Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    ),
])


def preprocess_image(image):
    """Turn one image into a normalised ``(1, 3, 224, 224)`` float tensor.

    Args:
        image: A ``PIL.Image.Image``, or an array-like accepted by
            ``PIL.Image.fromarray`` (for example the NumPy array that a
            Gradio image input produces by default).

    Returns:
        A tensor with a leading batch dimension of size 1, ready to be passed
        to the CLIP model as ``pixel_values``.
    """
    # torchvision's Resize only accepts PIL images or tensors, so array input
    # has to be wrapped in a PIL image first.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalize is configured for exactly three channels; RGBA, palette and
    # grayscale images would otherwise fail with a channel mismatch.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _CLIP_TRANSFORM(image).unsqueeze(0)
23
+
24
+
25
+ #Define a function to process the image and text inputs:
26
def _embed_pair(image, text):
    """Return unit-length CLIP embeddings ``(image_embed, text_embed)`` for one image/text pair."""
    pixel_values = preprocess_image(image)
    # Keyword argument on purpose: the positional order of ``text`` and
    # ``images`` is not the same across transformers releases.
    tokens = processor(
        text=text or "",
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        outputs = model(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            pixel_values=pixel_values,
        )
    # Re-normalising is a no-op if the embeddings are already unit length and
    # guarantees that the dot products below are cosine similarities.
    image_embed = torch.nn.functional.normalize(outputs.image_embeds[0], dim=-1)
    text_embed = torch.nn.functional.normalize(outputs.text_embeds[0], dim=-1)
    return image_embed, text_embed


def process_query(image, text):
    """Rank the products in ``deepfashion_database`` against an image + text query.

    The original implementation scored every product against *its own*
    description and applied a softmax over a single logit, so every score was
    exactly 1.0 and the query was ignored. Here each product is scored against
    the query instead: the score is the mean of the image-to-image and the
    text-to-text cosine similarity of the CLIP embeddings.

    Args:
        image: Query image (anything ``preprocess_image`` accepts).
        text: Query text; ``None`` is treated as an empty string.

    Returns:
        Up to three lines, best match first, each formatted as
        ``"<description> (score: <score>)"``. Empty string if the database
        is empty.
    """
    query_image_embed, query_text_embed = _embed_pair(image, text)

    # NOTE(review): ``deepfashion_database`` is not defined anywhere in the
    # visible file. It must exist at module level as an iterable of dicts with
    # "image_path" and "description" keys, otherwise this raises NameError.
    product_scores = []
    for product in deepfashion_database:
        # The context manager closes the underlying file; ``convert`` returns
        # an independent in-memory image that stays valid afterwards.
        with Image.open(product["image_path"]) as raw_image:
            product_image = raw_image.convert("RGB")

        product_image_embed, product_text_embed = _embed_pair(
            product_image, product["description"]
        )
        image_similarity = torch.dot(query_image_embed, product_image_embed).item()
        text_similarity = torch.dot(query_text_embed, product_text_embed).item()
        product_scores.append((product, (image_similarity + text_similarity) / 2))

    # Sort products by similarity and return the top 3 matches
    top_3_products = sorted(product_scores, key=lambda x: x[1], reverse=True)[:3]
    return '\n'.join(
        f"{product['description']} (score: {score:.2f})"
        for product, score in top_3_products
    )
56
+
57
# The original ``@gr.app`` decorator was removed: gradio does not expose an
# ``app`` decorator, so evaluating it failed at import time before the
# interface below could be built.
def product_search_app(image, text):
    """Forward an image/text query to ``process_query`` and return its result."""
    return process_query(image, text)
60
+
61
+
62
# Top-level components replace the ``gr.inputs`` / ``gr.outputs`` namespaces,
# which are deprecated in Gradio 3 and absent from Gradio 4.
iface = gr.Interface(
    fn=process_query,
    inputs=[
        # type="pil" hands the callback a PIL image, which is what the
        # torchvision preprocessing expects (the default is a NumPy array).
        gr.Image(type="pil"),
        gr.Textbox(lines=3, label="Text Query"),
    ],
    outputs=gr.Textbox(label="Top 3 Matches"),
    title="Product Search",
    description="Find the best matching products based on image and text queries.",
)


iface.launch()