samnji committed on
Commit
4be0554
·
1 Parent(s): ace92f7
Files changed (2) hide show
  1. app copy.py +127 -0
  2. app.py +37 -25
app copy.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
4
+ from transformers import CLIPProcessor, CLIPModel
5
+ from datasets import load_dataset
6
+ import torch
7
+
8
+ # Load the pre-trained CLIP model and its tokenizer
9
+ model_name = "openai/clip-vit-base-patch32"
10
+ processor = CLIPProcessor.from_pretrained(model_name)
11
+ model = CLIPModel.from_pretrained(model_name)
12
+
13
+ # Load the fashion product images dataset from Hugging Face
14
+ dataset = load_dataset("ashraq/fashion-product-images-small")
15
+ deepfashion_database = dataset["train"]
16
+
17
+ # Define the preprocessing function for images
18
def preprocess_image(image):
    """Turn an input image into a normalized, batched tensor for the model.

    Args:
        image: Either a ``PIL.Image.Image`` (what ``process_query`` gets from
            ``Image.open``) or an array-like object exposing ``astype``
            (presumably the NumPy array delivered by Gradio's image input --
            confirm against the installed Gradio version).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the transform output
        with a leading batch dimension added by ``unsqueeze(0)``.
    """
    # torchvision's Resize/CenterCrop work on PIL images or tensors, so a raw
    # array has to be wrapped in a PIL image first.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype("uint8"))
    # Force three channels so the 3-element mean/std in Normalize always match
    # (grayscale or RGBA inputs would otherwise fail there).
    image = image.convert("RGB")

    preprocess = Compose([
        Resize(256, interpolation=Image.BICUBIC),
        CenterCrop(224),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    return preprocess(image).unsqueeze(0)
26
+
27
+ # Define a function to process the image and text inputs
28
def initial_query(image, text):
    """Run a first-round search of the module-level product database.

    Args:
        image: Query image, forwarded unchanged to ``process_query``.
        text: Query text, forwarded unchanged to ``process_query``.

    Returns:
        Whatever ``process_query`` returns for ``deepfashion_database``.
    """
    catalogue = deepfashion_database
    return process_query(image, text, catalogue)
30
+
31
def process_query(image, text, database):
    """Return the three products in ``database`` closest to an image + text query.

    The query and every product are embedded with the CLIP model as the
    re-normalized sum of the model's image and text embeddings; products are
    ranked by cosine similarity between their embedding and the query's.

    The earlier implementation scored each product with a softmax over a
    single image/text logit (a 1x1 matrix), which is always 1.0, and never
    used the query scores it computed, so the result was simply the first
    three products of ``database``.

    Args:
        image: Query image, in any form accepted by ``preprocess_image``.
        text: Query text.
        database: Iterable of mappings providing ``"image_path"`` and
            ``"description"`` keys.
            NOTE(review): confirm the loaded dataset actually exposes these
            two keys; this function only assumes them.

    Returns:
        A list of at most three ``(product, similarity)`` tuples, most
        similar first. ``similarity`` is a Python float.
    """

    def _embed(pixel_values, caption):
        # Joint image+text embedding, L2-normalized so a dot product between
        # two of them is a cosine similarity.
        tokens = processor(caption, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(
                input_ids=tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                pixel_values=pixel_values,
            )
        joint = outputs.image_embeds + outputs.text_embeds
        return joint / joint.norm(dim=-1, keepdim=True)

    query_embedding = _embed(preprocess_image(image), text)

    product_scores = []
    for product in database:
        # Context manager closes the file handle; convert() returns a fully
        # loaded copy that stays valid after the file is closed.
        with Image.open(product["image_path"]) as source:
            product_image = source.convert("RGB")
        product_embedding = _embed(preprocess_image(product_image), product["description"])
        similarity = float((query_embedding * product_embedding).sum())
        product_scores.append((product, similarity))

    top_3_products = sorted(product_scores, key=lambda scored: scored[1], reverse=True)[:3]
    return top_3_products
58
+
59
def refine_query(selected_product_index, additional_text, initial_results):
    """Re-run the search using one earlier result plus extra text as the query.

    The chosen product's description is extended with ``additional_text``;
    the chosen product's image and that extended description become the new
    query. The search space is the other earlier results plus a copy of the
    chosen product carrying the extended description.

    Args:
        selected_product_index: Position of the chosen product in
            ``initial_results``. Accepts an int or a numeric string (the
            Radio choices in this file are ``"0"``, ``"1"``, ``"2"``).
        additional_text: Text appended to the chosen product's description.
        initial_results: Earlier results, either as the ``(product, score)``
            tuples returned by ``process_query`` or as bare product mappings.

    Returns:
        Whatever ``process_query`` returns for the refined query.

    Raises:
        ValueError: If ``selected_product_index`` is not numeric.
        IndexError: If the index is outside ``initial_results``.
    """
    index = int(selected_product_index)
    # process_query returns (product, score) tuples; unwrap them so products
    # can be read as mappings and fed back into process_query.
    products = [
        entry[0] if isinstance(entry, tuple) else entry
        for entry in initial_results
    ]

    selected_product = products[index]
    modified_description = selected_product["description"] + " " + additional_text
    refined_product = {
        "description": modified_description,
        "image_path": selected_product["image_path"],
    }

    refined_database = [
        product for position, product in enumerate(products) if position != index
    ]
    refined_database.append(refined_product)

    # Close the file handle promptly; convert() yields a loaded copy.
    with Image.open(selected_product["image_path"]) as source:
        query_image = source.convert("RGB")
    return process_query(query_image, modified_description, refined_database)
66
+
67
+
68
+
69
+
70
+
71
def generate_output_html(products):
    """Render ranked products as an HTML ordered list.

    Args:
        products: Iterable of sequences whose first element is a mapping with
            ``"image_path"`` and ``"description"`` keys (for example the
            ``(product, score)`` tuples returned by ``process_query``).

    Returns:
        An ``<ol>`` string with one ``<li>`` per product holding a 100x100
        ``<img>`` and the description. An empty input yields ``"<ol></ol>"``.
    """
    import html  # local import keeps this block self-contained

    # Descriptions can contain user-typed refinement text, so both values are
    # escaped before being placed into markup / an attribute value.
    items = [
        '<li><img src="{}" width="100" height="100"><br>{}</li>'.format(
            html.escape(str(entry[0]["image_path"]), quote=True),
            html.escape(str(entry[0]["description"]), quote=True),
        )
        for entry in products
    ]
    return "<ol>" + "".join(items) + "</ol>"
77
+
78
def initial_query_wrapper(image, text):
    """Run the initial search and render its results as HTML.

    Args:
        image: Query image, forwarded to ``initial_query``.
        text: Query text, forwarded to ``initial_query``.

    Returns:
        A one-element tuple containing the HTML produced by
        ``generate_output_html``.
    """
    rendered = generate_output_html(initial_query(image, text))
    return (rendered,)
81
+
82
def refine_query_wrapper(selected_product_index, additional_text, initial_results):
    """Run a refinement search and render its results as HTML.

    Args:
        selected_product_index: Forwarded to ``refine_query``.
        additional_text: Forwarded to ``refine_query``.
        initial_results: Forwarded to ``refine_query``.

    Returns:
        A one-element tuple containing the HTML produced by
        ``generate_output_html``.
    """
    refined = refine_query(selected_product_index, additional_text, initial_results)
    rendered = generate_output_html(refined)
    return (rendered,)
85
+
86
def product_search_wrapper(image=None, text=None, selected_product_index=None, additional_text=None):
    """Dispatch a UI request to an initial search or a refinement search.

    An initial search runs when both ``image`` and ``text`` are given.
    Otherwise, when both ``selected_product_index`` and ``additional_text``
    are given, the products of the most recent search are refined.

    ``refine_query`` needs the earlier results, which this signature does not
    carry (the previous code called it with that argument missing and raised
    ``TypeError``). The products of the last search are therefore remembered
    on the function object itself.
    NOTE(review): that memory is shared by every caller of this process, so
    it only suits a single-user demo.

    Args:
        image: Query image, or ``None``.
        text: Query text, or ``None``.
        selected_product_index: Index of an earlier result (int or numeric
            string), or ``None``.
        additional_text: Text used to refine the selected result, or ``None``.

    Returns:
        A one-element tuple with the results HTML, or ``("",)`` when there is
        nothing to search or no earlier results to refine.
    """
    if image is not None and text is not None:
        top_3_products = initial_query(image, text)
        # Keep bare product mappings (not (product, score) tuples) so they can
        # be handed straight to refine_query later.
        product_search_wrapper.last_results = [entry[0] for entry in top_3_products]
        return generate_output_html(top_3_products),

    if selected_product_index is not None and additional_text is not None:
        previous = getattr(product_search_wrapper, "last_results", None)
        if not previous:
            # Nothing to refine yet: no initial search has been run.
            return "",
        top_3_products = refine_query(int(selected_product_index), additional_text, previous)
        product_search_wrapper.last_results = [entry[0] for entry in top_3_products]
        return generate_output_html(top_3_products),

    return "",
95
+
96
+ iface = gr.Interface(
97
+ fn=product_search_wrapper,
98
+ inputs=[
99
+ gr.inputs.Image(optional=True),
100
+ gr.inputs.Textbox(lines=3, label="Initial Text Query", optional=True),
101
+ gr.inputs.Radio(["0", "1", "2"], label="Select Product Index", optional=True),
102
+ gr.inputs.Textbox(lines=3, label="Additional Text Query", optional=True)
103
+ ],
104
+ outputs=[
105
+ gr.outputs.HTML(label="Results")
106
+ ],
107
+ title="Product Search",
108
+ description="Find the best matching products using images and text queries.",
109
+ layout="vertical"
110
+ )
111
+
112
+ # iface = gr.Interface(
113
+ # fn=[initial_query_wrapper, refine_query_wrapper],
114
+ # inputs=[
115
+ # [gr.inputs.Image(), gr.inputs.Textbox(lines=3, label="Initial Text Query")],
116
+ # [gr.inputs.Radio(["0", "1", "2"], label="Select Product Index"), gr.inputs.Textbox(lines=3, label="Additional Text Query"), gr.inputs.Hidden(initial_results="initial_query")]
117
+ # ],
118
+ # outputs=[
119
+ # gr.outputs.HTML(label="Top 3 Matches"),
120
+ # gr.outputs.HTML(label="Refined Top 3 Matches")
121
+ # ],
122
+ # title="Product Search",
123
+ # description="Find the best matching products using images and text queries.",
124
+ # layout="vertical"
125
+ # )
126
+
127
+ iface.launch()
app.py CHANGED
@@ -93,35 +93,47 @@ def product_search_wrapper(image=None, text=None, selected_product_index=None, a
93
  else:
94
  return "",
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  iface = gr.Interface(
97
  fn=product_search_wrapper,
98
  inputs=[
99
- gr.inputs.Image(optional=True),
100
- gr.inputs.Textbox(lines=3, label="Initial Text Query", optional=True),
101
- gr.inputs.Radio(["0", "1", "2"], label="Select Product Index", optional=True),
102
- gr.inputs.Textbox(lines=3, label="Additional Text Query", optional=True)
103
  ],
104
- outputs=[
105
- gr.outputs.HTML(label="Results")
106
- ],
107
- title="Product Search",
108
  description="Find the best matching products using images and text queries.",
109
- layout="vertical"
 
110
  )
111
-
112
- # iface = gr.Interface(
113
- # fn=[initial_query_wrapper, refine_query_wrapper],
114
- # inputs=[
115
- # [gr.inputs.Image(), gr.inputs.Textbox(lines=3, label="Initial Text Query")],
116
- # [gr.inputs.Radio(["0", "1", "2"], label="Select Product Index"), gr.inputs.Textbox(lines=3, label="Additional Text Query"), gr.inputs.Hidden(initial_results="initial_query")]
117
- # ],
118
- # outputs=[
119
- # gr.outputs.HTML(label="Top 3 Matches"),
120
- # gr.outputs.HTML(label="Refined Top 3 Matches")
121
- # ],
122
- # title="Product Search",
123
- # description="Find the best matching products using images and text queries.",
124
- # layout="vertical"
125
- # )
126
-
127
  iface.launch()
 
 
93
  else:
94
  return "",
95
 
96
+ from PIL import Image
97
+
98
def preprocess_image(image):
    """Turn an input image into a normalized, batched tensor for the model.

    Args:
        image: Either a ``PIL.Image.Image`` or an array-like object exposing
            ``astype`` (presumably the NumPy array delivered by Gradio's
            image input -- confirm against the installed Gradio version).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the transform output
        with a leading batch dimension added by ``unsqueeze(0)``.
    """
    # Imported here because no `transforms` import is visible in this file.
    from torchvision import transforms

    # Accept PIL images as well as arrays: calling .astype on a PIL image
    # raises AttributeError.
    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # CLIP's normalization statistics, replacing the ImageNet ones.
        # NOTE(review): assumes the model here is the CLIP checkpoint that
        # `app copy.py` loads -- confirm.
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ])
    return preprocess(pil_image).unsqueeze(0)
107
+
108
def append_message_to_history(history, message, role):
    """Append one chat message, wrapped in a role-tagged ``<div>``, to the history.

    Args:
        history: HTML accumulated so far.
        message: Content placed inside the new ``<div>``; inserted as-is.
        role: Value used for the ``class`` attribute of the new ``<div>``.

    Returns:
        ``history`` followed by the new ``<div>`` element.
    """
    entry = '<div class="{}">{}</div>'.format(role, message)
    return history + entry
110
+
111
def product_search_wrapper(history, message, role):
    """Handle one chat turn and return the updated HTML chat history.

    For a ``"user"`` turn the message is appended to the history, parsed into
    an image and a text query, and, when both are present, the top matches
    are appended as a ``"bot"`` message. Any other role leaves the history
    unchanged.

    NOTE(review): the ``gr.Interface`` in this file lists its inputs as
    message, history, role, while this signature is history, message, role --
    confirm the intended order.

    Args:
        history: HTML chat history accumulated so far.
        message: The new message.
        role: ``"user"`` to process the message; anything else is a no-op.

    Returns:
        The (possibly extended) HTML chat history.
    """
    if role != "user":
        return history

    history = append_message_to_history(history, message, "user")

    # NOTE(review): parse_user_message is not defined in the visible source;
    # it is expected to return an (image, text) pair -- confirm.
    image, text = parse_user_message(message)
    if image is None or text is None:
        # The earlier code fell through to a refinement branch that read
        # `selected_product_index` and `additional_text`, neither of which
        # exists in this scope, so it raised NameError. Refinement cannot be
        # driven from this signature, so only the user message is recorded.
        return history

    top_3_products = initial_query(image, text)
    output_html = generate_output_html(top_3_products)
    return append_message_to_history(history, output_html, "bot")
124
+
125
  iface = gr.Interface(
126
  fn=product_search_wrapper,
127
  inputs=[
128
+ gr.inputs.Textbox(lines=3, label="Type your message"),
129
+ gr.inputs.Hidden(label="Chat History", default=""),
130
+ gr.inputs.Radio(["user", "bot"], label="Role", default="user"),
 
131
  ],
132
+ outputs=gr.outputs.HTML(label="Chat History"),
133
+ title="Product Search Chatbot",
 
 
134
  description="Find the best matching products using images and text queries.",
135
+ layout="vertical",
136
+ allow_flagging=False
137
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  iface.launch()
139
+