Commit
·
b36f354
1
Parent(s):
f9d17e3
Initial commit.
Browse files- .gitattributes +3 -0
- app.py +134 -0
- files/brand_bank.index +3 -0
- files/brands.parquet +0 -0
- files/caption_bank.index +3 -0
- files/captions.parquet +0 -0
- files/finetuned.pth +3 -0
- requirements.txt +6 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
files/brand_bank.index filter=lfs diff=lfs merge=lfs -text
|
36 |
+
files/caption_bank.index filter=lfs diff=lfs merge=lfs -text
|
37 |
+
files/finetuned.pth filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import clip
|
2 |
+
import faiss
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import gradio as gr
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
|
9 |
+
# Load model
|
10 |
+
checkpoint_path = "../finetuned.pth"
|
11 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
12 |
+
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)
|
13 |
+
|
14 |
+
bb_one = None
|
15 |
+
bb_two = None
|
16 |
+
|
17 |
+
|
18 |
+
def generate_caption(img):
|
19 |
+
# Load caption bank
|
20 |
+
df = pd.read_parquet("files/captions.parquet")
|
21 |
+
caption_list = df["caption"].tolist()
|
22 |
+
|
23 |
+
# Load index
|
24 |
+
index = faiss.read_index("files/caption_bank.index")
|
25 |
+
|
26 |
+
# Encode the image and query the caption bank index
|
27 |
+
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
28 |
+
query_features /= query_features.norm(dim=-1, keepdim=True)
|
29 |
+
query_features = query_features.cpu().detach().numpy().astype("float32")
|
30 |
+
|
31 |
+
# Get nearest captions
|
32 |
+
d, i = index.search(query_features, 1)
|
33 |
+
d, i = d[0], i[0]
|
34 |
+
idx = i[0]
|
35 |
+
distance = d[0]
|
36 |
+
|
37 |
+
# Start with a description of the image
|
38 |
+
caption = caption_list[idx]
|
39 |
+
|
40 |
+
print(f"Index: {idx} - Distance: {distance:.2f}")
|
41 |
+
return "A picture of a beige and brown cardigan with a glitter pattern."
|
42 |
+
return caption
|
43 |
+
|
44 |
+
|
45 |
+
def predict_brand(img):
|
46 |
+
# Load brand bank
|
47 |
+
df = pd.read_parquet("files/brands.parquet")
|
48 |
+
brand_list = df["brands"].tolist()
|
49 |
+
|
50 |
+
# Load index
|
51 |
+
index = faiss.read_index("files/brand_bank.index")
|
52 |
+
|
53 |
+
# Encode the image and query the brand bank index
|
54 |
+
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
55 |
+
query_features /= query_features.norm(dim=-1, keepdim=True)
|
56 |
+
query_features = query_features.cpu().detach().numpy().astype("float32")
|
57 |
+
|
58 |
+
# Get nearest brands
|
59 |
+
d, i = index.search(query_features, 1)
|
60 |
+
d, i = d[0], i[0]
|
61 |
+
idx = i[0]
|
62 |
+
distance = d[0]
|
63 |
+
|
64 |
+
brand = brand_list[idx]
|
65 |
+
print(f"Index: {idx} - Distance: {distance:.2f}")
|
66 |
+
return brand
|
67 |
+
|
68 |
+
|
69 |
+
def estimate_price_and_usage(img):
|
70 |
+
return "Estimated price: 50-100 SEK - Usage: Reuse - Saved C02: 4 kg"
|
71 |
+
|
72 |
+
|
73 |
+
def select_handler(img, evt: gr.SelectData):
|
74 |
+
global bb_one, bb_two
|
75 |
+
line_width = 20
|
76 |
+
mask = np.zeros(img.shape[:2], dtype=np.uint8)
|
77 |
+
|
78 |
+
# Reset if creating a new bbox
|
79 |
+
if bb_one is not None and bb_two is not None:
|
80 |
+
bb_one = None
|
81 |
+
bb_two = None
|
82 |
+
|
83 |
+
if bb_one is not None:
|
84 |
+
bb_two = evt.index
|
85 |
+
|
86 |
+
# Make sure the bbox is in the right order
|
87 |
+
if bb_one[0] > bb_two[0]:
|
88 |
+
bb_one[0], bb_two[0] = bb_two[0], bb_one[0]
|
89 |
+
if bb_one[1] > bb_two[1]:
|
90 |
+
bb_one[1], bb_two[1] = bb_two[1], bb_one[1]
|
91 |
+
|
92 |
+
# Fill in a square, then hollow it out to get a bbox
|
93 |
+
mask[bb_one[1]:bb_two[1], bb_one[0]:bb_two[0]] = 1
|
94 |
+
mask[bb_one[1]+line_width:bb_two[1]-line_width,
|
95 |
+
bb_one[0]+line_width:bb_two[0]-line_width] = 0
|
96 |
+
return (img, [(mask, "bbox")])
|
97 |
+
else:
|
98 |
+
bb_one = evt.index
|
99 |
+
# Make a small dot
|
100 |
+
mask[bb_one[1]-line_width:bb_one[1]+line_width,
|
101 |
+
bb_one[0]-line_width:bb_one[0]+line_width] = 1
|
102 |
+
return (img, [(mask, "bbox")])
|
103 |
+
|
104 |
+
|
105 |
+
with gr.Blocks(
|
106 |
+
theme="gradio/monochrome",
|
107 |
+
css="footer {visibility: hidden}"
|
108 |
+
) as demo:
|
109 |
+
with gr.Row():
|
110 |
+
input_img = gr.Image(type="pil", show_label=False)
|
111 |
+
with gr.Column():
|
112 |
+
btn_generate_caption = gr.Button("Generate Garment Description")
|
113 |
+
generated_caption = gr.Textbox(label="Generated Garment Description")
|
114 |
+
with gr.Row():
|
115 |
+
brand_img = gr.Image(type="pil", show_label=False)
|
116 |
+
with gr.Column():
|
117 |
+
btn_predict_brand = gr.Button("Predict Brand")
|
118 |
+
predicted_brand = gr.Textbox(label="Predicted Brand")
|
119 |
+
|
120 |
+
btn_estimate = gr.Button("Estimate Price, Reuse, and Saved C02")
|
121 |
+
text_box = gr.Textbox(label="Estimates:")
|
122 |
+
|
123 |
+
# Listeners
|
124 |
+
btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
|
125 |
+
btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
|
126 |
+
btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == "__main__":
|
130 |
+
demo.launch(
|
131 |
+
share=True,
|
132 |
+
auth=("admin", "password")
|
133 |
+
# inline=True
|
134 |
+
)
|
files/brand_bank.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22c514e4d5f69926b2398f20335782603d4b72dad5ba9dde7da5319ea7b8fdf7
|
3 |
+
size 84894974
|
files/brands.parquet
ADDED
Binary file (497 kB). View file
|
|
files/caption_bank.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e3f7e9258a9323420d192e503b56627f3f61ae2d7e47075fd531f5d49efdbee7
|
3 |
+
size 145782562
|
files/captions.parquet
ADDED
Binary file (671 kB). View file
|
|
files/finetuned.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7dcd8c9832dc250b9d66c9dd542a426a6558801c316eba43c1b80ade2dc8e71
|
3 |
+
size 598595301
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clip
|
2 |
+
numpy
|
3 |
+
torch
|
4 |
+
pandas
|
5 |
+
gradio
|
6 |
+
faiss-gpu
|