simonhermansson committed on
Commit
b36f354
·
1 Parent(s): f9d17e3

Initial commit.

Browse files
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ files/brand_bank.index filter=lfs diff=lfs merge=lfs -text
36
+ files/caption_bank.index filter=lfs diff=lfs merge=lfs -text
37
+ files/finetuned.pth filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import clip
2
+ import faiss
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import pandas as pd
7
+
8
+
9
# Load the fine-tuned CLIP checkpoint once, at import time.
# The weights are committed under files/ next to the FAISS indexes and the
# parquet banks, so they are resolved from there instead of the parent
# directory (which is outside the repository checkout).
checkpoint_path = "files/finetuned.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)

# Corners of the bounding box tracked by select_handler. Both None means no
# selection is in progress; bb_one set and bb_two None means the first corner
# has been clicked and the handler is waiting for the second one.
bb_one = None
bb_two = None
16
+
17
+
18
def generate_caption(img):
    """Return the caption-bank entry nearest to ``img`` in CLIP embedding space.

    Args:
        img: Image accepted by the module-level ``preprocess`` transform (the
            UI component feeding this function is declared with
            ``type="pil"``). ``None`` is tolerated so that a click without an
            uploaded image does not raise.

    Returns:
        The caption stored at the position of the nearest index entry, or an
        empty string when ``img`` is ``None``.
    """
    if img is None:
        return ""

    # Load caption bank
    # NOTE(review): the parquet file and the index are re-read on every call;
    # consider loading them once at module level if latency matters.
    df = pd.read_parquet("files/captions.parquet")
    caption_list = df["caption"].tolist()

    # Load index
    index = faiss.read_index("files/caption_bank.index")

    # Encode the image and query the caption bank index. Inference only, so
    # gradient tracking is switched off.
    with torch.no_grad():
        query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().numpy().astype("float32")

    # Get nearest caption (k=1, single query row)
    distances, indices = index.search(query_features, 1)
    idx = indices[0][0]
    distance = distances[0][0]

    caption = caption_list[idx]

    print(f"Index: {idx} - Distance: {distance:.2f}")
    # The original returned a fixed sentence here, which made the retrieved
    # caption unreachable.
    return caption
43
+
44
+
45
def predict_brand(img):
    """Return the brand-bank entry nearest to ``img`` in CLIP embedding space.

    Args:
        img: Image accepted by the module-level ``preprocess`` transform (the
            UI component feeding this function is declared with
            ``type="pil"``). ``None`` is tolerated so that a click without an
            uploaded image does not raise.

    Returns:
        The value of the ``brands`` column at the position of the nearest
        index entry, or an empty string when ``img`` is ``None``.
    """
    if img is None:
        return ""

    # Load brand bank
    # NOTE(review): the parquet file and the index are re-read on every call;
    # consider loading them once at module level if latency matters.
    df = pd.read_parquet("files/brands.parquet")
    brand_list = df["brands"].tolist()

    # Load index
    index = faiss.read_index("files/brand_bank.index")

    # Encode the image and query the brand bank index. Inference only, so
    # gradient tracking is switched off.
    with torch.no_grad():
        query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
    query_features = query_features.cpu().numpy().astype("float32")

    # Get nearest brand (k=1, single query row)
    distances, indices = index.search(query_features, 1)
    idx = indices[0][0]
    distance = distances[0][0]

    brand = brand_list[idx]
    print(f"Index: {idx} - Distance: {distance:.2f}")
    return brand
67
+
68
+
69
def estimate_price_and_usage(img):
    """Return the estimate text displayed in the "Estimates:" textbox.

    The result is a fixed placeholder string; ``img`` is accepted to match the
    click-handler signature but is not read.
    """
    placeholder_estimate = "Estimated price: 50-100 SEK - Usage: Reuse - Saved C02: 4 kg"
    return placeholder_estimate
71
+
72
+
73
def select_handler(img, evt: gr.SelectData):
    """Track two clicks on an image and return a mask outlining the selection.

    The first click stores one corner in the module-level ``bb_one`` and draws
    a small square dot there. The second click stores the opposite corner in
    ``bb_two`` and draws a hollow rectangle between the two. A click after a
    completed box starts a new selection.

    Args:
        img: Object exposing ``.shape`` whose first two entries are used as
            the mask size (assumed to be a NumPy image array - confirm against
            the component that will be wired to this handler).
        evt: Select event; ``evt.index`` is read as an ``(x, y)`` pair.

    Returns:
        ``(img, [(mask, "bbox")])`` where ``mask`` is a ``uint8`` array of the
        image's height and width with 1 on the drawn pixels.
    """
    global bb_one, bb_two
    line_width = 20
    mask = np.zeros(img.shape[:2], dtype=np.uint8)

    # Reset if creating a new bbox
    if bb_one is not None and bb_two is not None:
        bb_one = None
        bb_two = None

    if bb_one is None:
        # Copy the coordinates so later swaps never write into the event.
        bb_one = list(evt.index)
        x, y = bb_one
        # Make a small dot. The lower bounds are clamped at 0: a negative
        # slice start would be interpreted relative to the end of the axis
        # and leave the dot empty for clicks near the top or left edge.
        mask[max(y - line_width, 0):y + line_width,
             max(x - line_width, 0):x + line_width] = 1
        return (img, [(mask, "bbox")])

    # list() makes the corner mutable even if the event delivers a tuple.
    bb_two = list(evt.index)

    # Make sure the bbox is in the right order (bb_one = top-left corner)
    if bb_one[0] > bb_two[0]:
        bb_one[0], bb_two[0] = bb_two[0], bb_one[0]
    if bb_one[1] > bb_two[1]:
        bb_one[1], bb_two[1] = bb_two[1], bb_one[1]

    # Fill in a square, then hollow it out to get a bbox
    mask[bb_one[1]:bb_two[1], bb_one[0]:bb_two[0]] = 1
    mask[bb_one[1] + line_width:bb_two[1] - line_width,
         bb_one[0] + line_width:bb_two[0] - line_width] = 0
    return (img, [(mask, "bbox")])
103
+
104
+
105
# Build the UI: one row for caption generation, one row for brand prediction,
# then a button/textbox pair for the price and reuse estimate.
# NOTE(review): indentation was lost in the diff view this was recovered from;
# the Row > Column nesting below is the presumed layout - confirm.
with gr.Blocks(
    theme="gradio/monochrome",
    css="footer {visibility: hidden}"
) as demo:
    # Caption row: image input on the left, button and result on the right.
    with gr.Row():
        input_img = gr.Image(type="pil", show_label=False)
        with gr.Column():
            btn_generate_caption = gr.Button("Generate Garment Description")
            generated_caption = gr.Textbox(label="Generated Garment Description")
    # Brand row: a separate image input with its own button and result.
    with gr.Row():
        brand_img = gr.Image(type="pil", show_label=False)
        with gr.Column():
            btn_predict_brand = gr.Button("Predict Brand")
            predicted_brand = gr.Textbox(label="Predicted Brand")

    btn_estimate = gr.Button("Estimate Price, Reuse, and Saved C02")
    text_box = gr.Textbox(label="Estimates:")

    # Listeners
    # The estimate button reads input_img (the caption row's image), not brand_img.
    btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
    btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
    btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
128
+
129
if __name__ == "__main__":
    import os

    # share=True publishes the app on a public link, so the login should not
    # be a credential pair committed to the repository. The values are read
    # from the environment; the original pair remains the fallback so existing
    # deployments keep working.
    # NOTE(review): set APP_AUTH_USER / APP_AUTH_PASSWORD before sharing.
    auth_user = os.environ.get("APP_AUTH_USER", "admin")
    auth_password = os.environ.get("APP_AUTH_PASSWORD", "password")

    demo.launch(
        share=True,
        auth=(auth_user, auth_password),
    )
files/brand_bank.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22c514e4d5f69926b2398f20335782603d4b72dad5ba9dde7da5319ea7b8fdf7
3
+ size 84894974
files/brands.parquet ADDED
Binary file (497 kB). View file
 
files/caption_bank.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3f7e9258a9323420d192e503b56627f3f61ae2d7e47075fd531f5d49efdbee7
3
+ size 145782562
files/captions.parquet ADDED
Binary file (671 kB). View file
 
files/finetuned.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7dcd8c9832dc250b9d66c9dd542a426a6558801c316eba43c1b80ade2dc8e71
3
+ size 598595301
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# app.py uses OpenAI CLIP (clip.load / encode_image). The PyPI distribution
# named "clip" is an unrelated package, so install CLIP from its repository.
git+https://github.com/openai/CLIP.git
numpy
torch
pandas
# pandas.read_parquet needs a parquet engine.
pyarrow
gradio
faiss-gpu