simonhermansson committed
Commit 6fa3157 · 1 Parent(s): 3eeb31d

Added image retrieval

.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 files/brand_bank.index filter=lfs diff=lfs merge=lfs -text
 files/caption_bank.index filter=lfs diff=lfs merge=lfs -text
 files/finetuned.pth filter=lfs diff=lfs merge=lfs -text
+files/combined.tar filter=lfs diff=lfs merge=lfs -text
+files/index/image.index filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,20 +1,20 @@
+import os
+
 import clip
 import faiss
 import torch
-import numpy as np
+import tarfile
 import gradio as gr
 import pandas as pd
+from PIL import Image
+from braceexpand import braceexpand
 
 
 # Load model
-#checkpoint_path = "files/finetuned_from_dopamine.pth"
 checkpoint_path = "ViT-B/16"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model, preprocess = clip.load(checkpoint_path, device=device, jit=False)
 
-bb_one = None
-bb_two = None
-
 
 def generate_caption(img):
     # Load caption bank
@@ -70,36 +70,56 @@ def estimate_price_and_usage(img):
     return "Estimated price: 50-100 SEK - Usage: Reuse - Saved C02: 4 kg"
 
 
-def select_handler(img, evt: gr.SelectData):
-    global bb_one, bb_two
-    line_width = 20
-    mask = np.zeros(img.shape[:2], dtype=np.uint8)
-
-    # Reset if creating a new bbox
-    if bb_one is not None and bb_two is not None:
-        bb_one = None
-        bb_two = None
+def retrieve(query):
+    index_folder = "files/index"
+    num_results = 3
 
-    if bb_one is not None:
-        bb_two = evt.index
+    # Read image metadata
+    metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
+    key_list = metadata_df["key"].tolist()
 
-        # Make sure the bbox is in the right order
-        if bb_one[0] > bb_two[0]:
-            bb_one[0], bb_two[0] = bb_two[0], bb_one[0]
-        if bb_one[1] > bb_two[1]:
-            bb_one[1], bb_two[1] = bb_two[1], bb_one[1]
+    # Load the index
+    index = faiss.read_index(os.path.join(index_folder, "image.index"))
 
-        # Fill in a square, then hollow it out to get a bbox
-        mask[bb_one[1]:bb_two[1], bb_one[0]:bb_two[0]] = 1
-        mask[bb_one[1]+line_width:bb_two[1]-line_width,
-             bb_one[0]+line_width:bb_two[0]-line_width] = 0
-        return (img, [(mask, "bbox")])
+    # Encode the query
+    if isinstance(query, str):
+        print("Query is a string")
+        text = clip.tokenize([query]).to(device)
+        query_features = model.encode_text(text)
     else:
-        bb_one = evt.index
-        # Make a small dot
-        mask[bb_one[1]-line_width:bb_one[1]+line_width,
-             bb_one[0]-line_width:bb_one[0]+line_width] = 1
-        return (img, [(mask, "bbox")])
+        print("Query is an image")
+        query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))
+    query_features = query_features / query_features.norm(dim=-1, keepdim=True)
+    query_features = query_features.cpu().detach().numpy().astype("float32")
+
+    d, i = index.search(query_features, num_results)
+    print(f"Found {num_results} items with query '{query}'")
+    indices = i[0]
+    similarities = d[0]
+
+    min_d = min(similarities)
+    max_d = max(similarities)
+    print(f"The minimum similarity is {min_d:.2f} and the maximum is {max_d:.2f}")
+
+    # Uncomment to generate combined.tar, combine the image_tars into a single tarfile
+    """
+    dataset_dir = "/fs/sefs1/circularfashion/sellpy/front_balanced"
+    image_tars = [os.path.join(dataset_dir, file) for file in sorted(braceexpand("{00000..00010}.tar"))]
+    with tarfile.open("files/combined.tar", "w") as combined_tar:
+        for tar in image_tars:
+            with tarfile.open(tar, "r") as tar_file:
+                for member in tar_file.getmembers():
+                    combined_tar.addfile(member, tar_file.extractfile(member))
+    """
+
+    images = []
+    for idx in indices:
+        image_name = key_list[idx]
+        with tarfile.open("files/combined.tar", "r") as tar_file:
+            image = tar_file.extractfile(f"{image_name}.jpg")
+            image = Image.open(image).copy()
+            images.append(image)
+    return images
 
 
 theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
@@ -108,25 +128,39 @@ with gr.Blocks(
     theme=theme,
     css="footer {visibility: hidden}",
 ) as demo:
-    with gr.Row(variant="compact"):
-        input_img = gr.Image(type="pil", show_label=False)
-        with gr.Column(min_width="80"):
-            btn_generate_caption = gr.Button("Create Description").style(size="sm")
-            generated_caption = gr.Textbox(label="Description", show_label=False)
-    with gr.Row(variant="compact"):
-        brand_img = gr.Image(type="pil", show_label=False)
-        with gr.Column(min_width="80"):
-            btn_predict_brand = gr.Button("Predict Brand").style(size="sm")
-            predicted_brand = gr.Textbox(label="Brand", show_label=False)
-
-    with gr.Column(variant="compact"):
-        btn_estimate = gr.Button("Estimate Price, Reuse, and Saved C02").style(size="sm")
-        text_box = gr.Textbox(label="Estimates:", show_label=False)
+    with gr.Tab("Captioning and Prediction"):
+        with gr.Row(variant="compact"):
+            input_img = gr.Image(type="pil", show_label=False)
+            with gr.Column(min_width="80"):
+                btn_generate_caption = gr.Button("Create Description").style(size="sm")
+                generated_caption = gr.Textbox(label="Description", show_label=False)
+        with gr.Row(variant="compact"):
+            brand_img = gr.Image(type="pil", show_label=False)
+            with gr.Column(min_width="80"):
+                btn_predict_brand = gr.Button("Predict Brand").style(size="sm")
+                predicted_brand = gr.Textbox(label="Brand", show_label=False)
+
+        with gr.Column(variant="compact"):
+            btn_estimate = gr.Button("Estimate Price, Reuse, and Saved C02").style(size="sm")
+            text_box = gr.Textbox(label="Estimates:", show_label=False)
+    with gr.Tab("Image Retrieval"):
+        with gr.Row(variant="compact"):
+            with gr.Column():
+                query_img = gr.Image(type="pil", label="Image Query")
+                btn_image_query = gr.Button("Retrieve Garments").style(size="sm")
+            img_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)
+        with gr.Row(variant="compact"):
+            with gr.Column():
+                query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
+                btn_text_query = gr.Button("Retrieve Garments").style(size="sm")
+            text_query_gallery = gr.Gallery(show_label=False).style(rows=1, columns=3)
 
     # Listeners
     btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
     btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
     btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
+    btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
+    btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)
 
 
 if __name__ == "__main__":
files/combined.tar ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12201a61246802fa690e35ba79ffaa310adf9a1a99ca28a227779e34ea4a7f5f
+size 723496960
files/index/image.index ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:deeac00cfb73de054828290e83c06b6f0874d3f9d11fa74d260c3ac7b73f6511
+size 40781869
files/index/metadata.parquet ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b7c5cdbe9c3db67c76d6457f5138215552f668e455ac5f6093d6f6b8da751851
+size 312316
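For context, the query path implemented by the new retrieve() function in app.py can be exercised on its own roughly as follows. This is a condensed sketch rather than part of the commit: it assumes the LFS artifacts added here (files/index/image.index, files/index/metadata.parquet, files/combined.tar) are pulled locally, that the parquet's "key" column maps FAISS row ids to image names, and that images are stored in combined.tar as <key>.jpg, as the code above does; the text query string is made up.

import os
import tarfile

import clip
import faiss
import pandas as pd
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device, jit=False)

# metadata.parquet maps FAISS row ids to image keys; image.index holds the image embeddings.
index_folder = "files/index"
key_list = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))["key"].tolist()
index = faiss.read_index(os.path.join(index_folder, "image.index"))

# Encode a text query; an image query would use model.encode_image(preprocess(img).unsqueeze(0)) instead.
with torch.no_grad():
    query_features = model.encode_text(clip.tokenize(["red summer dress"]).to(device))  # example query
query_features = query_features / query_features.norm(dim=-1, keepdim=True)
query_features = query_features.cpu().numpy().astype("float32")

# Top-3 nearest neighbours, then pull the matching .jpg members out of the combined tar shard.
distances, indices = index.search(query_features, 3)
with tarfile.open("files/combined.tar", "r") as tar_file:
    images = [Image.open(tar_file.extractfile(f"{key_list[i]}.jpg")).copy() for i in indices[0]]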