fnauman committed
Commit 726ea6c · verified · 1 Parent(s): c00c80f

removed .style() calls - .style() does not work with gradio >= 4 (styling options are now constructor arguments)

Files changed (1): app.py (+250, -250)
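Gradio 4 removed the component-level .style() method; its arguments became constructor keyword arguments. A minimal sketch of the migration pattern this commit applies (assuming gradio >= 4; the component names here are illustrative, mirroring the hunks below):

    import gradio as gr

    # gradio 3.x (no longer works in gradio >= 4):
    #   btn = gr.Button("Run").style(size="sm")
    #   gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)

    # gradio 4.x equivalent: the same options move into the constructor
    btn = gr.Button("Run", size="sm")
    gallery = gr.Gallery(show_label=False, rows=1, columns=3)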
app.py CHANGED
@@ -1,250 +1,250 @@

Only the seven .style() call sites changed; the rest of the file is identical line for line. The full updated file follows the hunks below.

-                btn_generate_caption = gr.Button("Create Description").style(size="sm")
+                btn_generate_caption = gr.Button("Create Description", size="sm")

-                btn_predict_brand = gr.Button("Predict Brand").style(size="sm")
+                btn_predict_brand = gr.Button("Predict Brand", size="sm")

-            btn_estimate = gr.Button("Estimate Price and Reuse").style(size="sm")
+            btn_estimate = gr.Button("Estimate Price and Reuse", size="sm")

-                btn_image_query = gr.Button("Retrieve Garments").style(size="sm")
+                btn_image_query = gr.Button("Retrieve Garments", size="sm")

-                img_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)
+                img_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)

-                btn_text_query = gr.Button("Retrieve Garments").style(size="sm")
+                btn_text_query = gr.Button("Retrieve Garments", size="sm")

-                text_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)
+                text_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
The full app.py after this commit:

import os

import clip
import faiss
import torch
import tarfile
import gradio as gr
import pandas as pd
from PIL import Image
from braceexpand import braceexpand
from torchvision import transforms


# Load model
checkpoint_path = "ViT-B/16"
device = "cpu"
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)


def generate_caption(img):
    # Load caption bank
    df = pd.read_parquet("files/captions.parquet")
    caption_list = df["caption"].tolist()

    # Load index
    index = faiss.read_index("files/caption_bank.index")

    # Encode the image and query the caption bank index
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    # Get nearest captions
    d, i = index.search(query_features, 1)
    d, i = d[0], i[0]
    idx = i[0]
    distance = d[0]

    # Start with a description of the image
    caption = caption_list[idx]

    print(f"Index: {idx} - Distance: {distance:.2f}")
    return caption


def predict_brand(img):
    # Load brand bank
    df = pd.read_parquet("files/brands.parquet")
    brand_list = df["brands"].tolist()

    # Load index
    index = faiss.read_index("files/brand_bank.index")

    # Encode the image and query the brand bank index
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    # Get nearest brands
    d, i = index.search(query_features, 1)
    d, i = d[0], i[0]
    idx = i[0]
    distance = d[0]

    brand = brand_list[idx]
    print(f"Index: {idx} - Distance: {distance:.2f}")
    return brand


def estimate_price_and_usage(img):
    query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))

    # Estimate usage
    num_classes = 2
    probe = torch.nn.Linear(
        query_features.shape[-1],
        num_classes,
        dtype=torch.float16,
        bias=False
    )
    # Load weights for the linear layer as a tensor
    linear_data = torch.load("files/reuse_linear.pth", map_location="cpu")
    probe.weight.data = linear_data["weight"]

    # Do inference
    with torch.autocast("cpu"):
        probe.eval()
        probe = probe.to(device)
        output = probe(query_features)
        output = torch.softmax(output, dim=-1)
        # output = output.cpu().detach().numpy().astype("float32")
        reuse = output.argmax(axis=-1)[0]
    reuse_classes = ["Reuse", "Export"]

    # Estimate price
    num_classes = 4
    probe = torch.nn.Linear(
        query_features.shape[-1],
        num_classes,
        dtype=torch.float16,
        bias=False
    )
    # Load weights for the linear layer as a tensor
    linear_data = torch.load("files/price_linear.pth", map_location="cpu")
    probe.weight.data = linear_data["weight"]

    # Do inference
    with torch.autocast("cpu"):
        probe.eval()
        probe = probe.to(device)
        output = probe(query_features)
        output = torch.softmax(output, dim=-1)
        # output = output.cpu().detach().numpy().astype("float32")
        price = output.argmax(axis=-1)[0]
    price_classes = ["<50", "50-100", "100-150", ">150"]

    return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"


def retrieve(query):
    index_folder = "files/index"
    num_results = 3

    # Read image metadata
    metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
    key_list = metadata_df["key"].tolist()

    # Load the index
    index = faiss.read_index(os.path.join(index_folder, "image.index"))

    # Encode the query
    if isinstance(query, str):
        print("Query is a string")
        text = clip.tokenize([query]).to(device)
        query_features = model.encode_text(text)
    else:
        print("Query is an image")
        query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))
    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().detach().numpy().astype("float32")

    d, i = index.search(query_features, num_results)
    print(f"Found {num_results} items with query '{query}'")
    indices = i[0]
    similarities = d[0]

    min_d = min(similarities)
    max_d = max(similarities)
    print(f"The minimum similarity is {min_d:.2f} and the maximum is {max_d:.2f}")

    # Uncomment to generate combined.tar, combining the image_tars into a single tarfile
    """
    dataset_dir = "/fs/sefs1/circularfashion/wargon_webdataset/front_only"
    image_tars = [os.path.join(dataset_dir, file) for file in sorted(braceexpand("{0000..0028}.tar"))]
    with tarfile.open("files/combined.tar", "w") as combined_tar:
        for tar in image_tars:
            with tarfile.open(tar, "r") as tar_file:
                for member in tar_file.getmembers():
                    combined_tar.addfile(member, tar_file.extractfile(member))
    """

    images = []
    for idx in indices:
        image_name = key_list[idx]
        with tarfile.open("files/combined.tar", "r") as tar_file:
            image = tar_file.extractfile(f"{image_name}.jpg")
            image = Image.open(image).copy()
        # Center crop the image
        width, height = image.size
        new_size = min(width, height)
        image = transforms.CenterCrop(new_size)(image)
        # Resize the image
        image = transforms.Resize((600, 600))(image)
        images.append(image)
    return images


theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")

with gr.Blocks(
    theme=theme,
    css="footer {visibility: hidden}",
) as demo:
    with gr.Tab("Captioning and Prediction"):
        with gr.Row(variant="compact"):
            input_img = gr.Image(type="pil", show_label=False)
            with gr.Column(min_width="80"):
                btn_generate_caption = gr.Button("Create Description", size="sm")
                generated_caption = gr.Textbox(label="Description", show_label=False)
        gr.Examples(
            examples=["files/examples/example_1.jpg", "files/examples/example_2.jpg"],
            fn=generate_caption,
            inputs=input_img,
            outputs=generated_caption
        )

        with gr.Row(variant="compact"):
            brand_img = gr.Image(type="pil", show_label=False)
            with gr.Column(min_width="80"):
                btn_predict_brand = gr.Button("Predict Brand", size="sm")
                predicted_brand = gr.Textbox(label="Brand", show_label=False)
        gr.Examples(
            examples=["files/examples/example_brand_1.jpg", "files/examples/example_brand_2.jpg"],
            fn=predict_brand,
            inputs=brand_img,
            outputs=predicted_brand
        )

        with gr.Column(variant="compact"):
            btn_estimate = gr.Button("Estimate Price and Reuse", size="sm")
            text_box = gr.Textbox(label="Estimates:", show_label=False)

    with gr.Tab("Image Retrieval"):
        with gr.Row(variant="compact"):
            with gr.Column():
                query_img = gr.Image(type="pil", label="Image Query")
                btn_image_query = gr.Button("Retrieve Garments", size="sm")
                img_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
        gr.Examples(
            examples=["files/examples/example_retrieval_1.jpg", "files/examples/example_retrieval_2.jpg"],
            fn=retrieve,
            inputs=query_img,
            outputs=img_query_gallery
        )

        with gr.Row(variant="compact"):
            with gr.Column():
                query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
                btn_text_query = gr.Button("Retrieve Garments", size="sm")
                text_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
        gr.Examples(
            examples=["A purple sweater", "A dress with a floral pattern"],
            fn=retrieve,
            inputs=query_text,
            outputs=text_query_gallery
        )

    # Listeners
    btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
    btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
    btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
    btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
    btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)


if __name__ == "__main__":
    demo.launch(
        # inline=True
    )
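For context, the functions above assume prebuilt artifacts such as files/captions.parquet and files/caption_bank.index. A minimal, hypothetical sketch of how such a caption bank could be produced with the same CLIP encoder (the captions and file paths here are illustrative; this script is not part of the commit):

    import clip
    import faiss
    import torch
    import pandas as pd

    device = "cpu"
    model, _ = clip.load("ViT-B/16", device=device, jit=False)

    captions = ["A purple sweater", "A dress with a floral pattern"]  # illustrative only
    with torch.no_grad():
        features = model.encode_text(clip.tokenize(captions).to(device))
    features = features / features.norm(dim=-1, keepdim=True)
    features = features.cpu().numpy().astype("float32")

    # Inner product on unit-normalized vectors equals cosine similarity,
    # matching the normalized queries used in generate_caption/retrieve
    index = faiss.IndexFlatIP(features.shape[-1])
    index.add(features)
    faiss.write_index(index, "files/caption_bank.index")
    pd.DataFrame({"caption": captions}).to_parquet("files/captions.parquet")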