removed style - does not work with gr > 4
Browse files
app.py
CHANGED
@@ -1,250 +1,250 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
import clip
|
4 |
-
import faiss
|
5 |
-
import torch
|
6 |
-
import tarfile
|
7 |
-
import gradio as gr
|
8 |
-
import pandas as pd
|
9 |
-
from PIL import Image
|
10 |
-
from braceexpand import braceexpand
|
11 |
-
from torchvision import transforms
|
12 |
-
|
13 |
-
|
14 |
-
# Load model
|
15 |
-
checkpoint_path = "ViT-B/16"
|
16 |
-
device = "cpu"
|
17 |
-
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)
|
18 |
-
|
19 |
-
|
20 |
-
def generate_caption(img):
|
21 |
-
# Load caption bank
|
22 |
-
df = pd.read_parquet("files/captions.parquet")
|
23 |
-
caption_list = df["caption"].tolist()
|
24 |
-
|
25 |
-
# Load index
|
26 |
-
index = faiss.read_index("files/caption_bank.index")
|
27 |
-
|
28 |
-
# Encode the image and query the caption bank index
|
29 |
-
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
30 |
-
query_features = query_features / query_features.norm(dim=-1, keepdim=True)
|
31 |
-
query_features = query_features.cpu().detach().numpy().astype("float32")
|
32 |
-
|
33 |
-
# Get nearest captions
|
34 |
-
d, i = index.search(query_features, 1)
|
35 |
-
d, i = d[0], i[0]
|
36 |
-
idx = i[0]
|
37 |
-
distance = d[0]
|
38 |
-
|
39 |
-
# Start with a description of the image
|
40 |
-
caption = caption_list[idx]
|
41 |
-
|
42 |
-
print(f"Index: {idx} - Distance: {distance:.2f}")
|
43 |
-
return caption
|
44 |
-
|
45 |
-
|
46 |
-
def predict_brand(img):
|
47 |
-
# Load brand bank
|
48 |
-
df = pd.read_parquet("files/brands.parquet")
|
49 |
-
brand_list = df["brands"].tolist()
|
50 |
-
|
51 |
-
# Load index
|
52 |
-
index = faiss.read_index("files/brand_bank.index")
|
53 |
-
|
54 |
-
# Encode the image and query the brand bank index
|
55 |
-
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
56 |
-
query_features = query_features / query_features.norm(dim=-1, keepdim=True)
|
57 |
-
query_features = query_features.cpu().detach().numpy().astype("float32")
|
58 |
-
|
59 |
-
# Get nearest brands
|
60 |
-
d, i = index.search(query_features, 1)
|
61 |
-
d, i = d[0], i[0]
|
62 |
-
idx = i[0]
|
63 |
-
distance = d[0]
|
64 |
-
|
65 |
-
brand = brand_list[idx]
|
66 |
-
print(f"Index: {idx} - Distance: {distance:.2f}")
|
67 |
-
return brand
|
68 |
-
|
69 |
-
|
70 |
-
def estimate_price_and_usage(img):
|
71 |
-
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
72 |
-
|
73 |
-
# Estimate usage
|
74 |
-
num_classes = 2
|
75 |
-
probe = torch.nn.Linear(
|
76 |
-
query_features.shape[-1],
|
77 |
-
num_classes,
|
78 |
-
dtype=torch.float16,
|
79 |
-
bias=False
|
80 |
-
)
|
81 |
-
# Load weights for the linear layer as a tensor
|
82 |
-
linear_data = torch.load("files/reuse_linear.pth", map_location="cpu")
|
83 |
-
probe.weight.data = linear_data["weight"]
|
84 |
-
|
85 |
-
# Do inference
|
86 |
-
with torch.autocast("cpu"):
|
87 |
-
probe.eval()
|
88 |
-
probe = probe.to(device)
|
89 |
-
output = probe(query_features)
|
90 |
-
output = torch.softmax(output, dim=-1)
|
91 |
-
#output = output.cpu().detach().numpy().astype("float32")
|
92 |
-
reuse = output.argmax(axis=-1)[0]
|
93 |
-
reuse_classes = ["Reuse", "Export"]
|
94 |
-
|
95 |
-
# Estimate price
|
96 |
-
num_classes = 4
|
97 |
-
probe = torch.nn.Linear(
|
98 |
-
query_features.shape[-1],
|
99 |
-
num_classes,
|
100 |
-
dtype=torch.float16,
|
101 |
-
bias=False
|
102 |
-
)
|
103 |
-
# Print output shape for the linear layer
|
104 |
-
# Load weights for the linear layer as a tensor
|
105 |
-
linear_data = torch.load("files/price_linear.pth", map_location="cpu")
|
106 |
-
probe.weight.data = linear_data["weight"]
|
107 |
-
|
108 |
-
# Do inference
|
109 |
-
with torch.autocast("cpu"):
|
110 |
-
probe.eval()
|
111 |
-
probe = probe.to(device)
|
112 |
-
output = probe(query_features)
|
113 |
-
output = torch.softmax(output, dim=-1)
|
114 |
-
#output = output.cpu().detach().numpy().astype("float32")
|
115 |
-
price = output.argmax(axis=-1)[0]
|
116 |
-
price_classes = ["<50", "50-100", "100-150", ">150"]
|
117 |
-
|
118 |
-
return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"
|
119 |
-
|
120 |
-
|
121 |
-
def retrieve(query):
|
122 |
-
index_folder = "files/index"
|
123 |
-
num_results = 3
|
124 |
-
|
125 |
-
# Read image metadata
|
126 |
-
metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
|
127 |
-
key_list = metadata_df["key"].tolist()
|
128 |
-
|
129 |
-
# Load the index
|
130 |
-
index = faiss.read_index(os.path.join(index_folder, "image.index"))
|
131 |
-
|
132 |
-
# Encode the query
|
133 |
-
if isinstance(query, str):
|
134 |
-
print("Query is a string")
|
135 |
-
text = clip.tokenize([query]).to(device)
|
136 |
-
query_features = model.encode_text(text)
|
137 |
-
else:
|
138 |
-
print("Query is an image")
|
139 |
-
query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))
|
140 |
-
query_features = query_features / query_features.norm(dim=-1, keepdim=True)
|
141 |
-
query_features = query_features.cpu().detach().numpy().astype("float32")
|
142 |
-
|
143 |
-
d, i = index.search(query_features, num_results)
|
144 |
-
print(f"Found {num_results} items with query '{query}'")
|
145 |
-
indices = i[0]
|
146 |
-
similarities = d[0]
|
147 |
-
|
148 |
-
min_d = min(similarities)
|
149 |
-
max_d = max(similarities)
|
150 |
-
print(f"The minimum similarity is {min_d:.2f} and the maximum is {max_d:.2f}")
|
151 |
-
|
152 |
-
# Uncomment to generate combined.tar, combine the image_tars into a single tarfile
|
153 |
-
"""
|
154 |
-
dataset_dir = "/fs/sefs1/circularfashion/wargon_webdataset/front_only"
|
155 |
-
image_tars = [os.path.join(dataset_dir, file) for file in sorted(braceexpand("{0000..0028}.tar"))]
|
156 |
-
with tarfile.open("files/combined.tar", "w") as combined_tar:
|
157 |
-
for tar in image_tars:
|
158 |
-
with tarfile.open(tar, "r") as tar_file:
|
159 |
-
for member in tar_file.getmembers():
|
160 |
-
combined_tar.addfile(member, tar_file.extractfile(member))
|
161 |
-
"""
|
162 |
-
|
163 |
-
images = []
|
164 |
-
for idx in indices:
|
165 |
-
image_name = key_list[idx]
|
166 |
-
with tarfile.open("files/combined.tar", "r") as tar_file:
|
167 |
-
image = tar_file.extractfile(f"{image_name}.jpg")
|
168 |
-
image = Image.open(image).copy()
|
169 |
-
# Center crop the image
|
170 |
-
width, height = image.size
|
171 |
-
new_size = min(width, height)
|
172 |
-
image = transforms.CenterCrop(new_size)(image)
|
173 |
-
# Resize the image
|
174 |
-
image = transforms.Resize((600, 600))(image)
|
175 |
-
images.append(image)
|
176 |
-
return images
|
177 |
-
|
178 |
-
|
179 |
-
theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
|
180 |
-
|
181 |
-
with gr.Blocks(
|
182 |
-
theme=theme,
|
183 |
-
css="footer {visibility: hidden}",
|
184 |
-
) as demo:
|
185 |
-
with gr.Tab("Captioning and Prediction"):
|
186 |
-
with gr.Row(variant="compact"):
|
187 |
-
input_img = gr.Image(type="pil", show_label=False)
|
188 |
-
with gr.Column(min_width="80"):
|
189 |
-
btn_generate_caption = gr.Button("Create Description"
|
190 |
-
generated_caption = gr.Textbox(label="Description", show_label=False)
|
191 |
-
gr.Examples(
|
192 |
-
examples=["files/examples/example_1.jpg", "files/examples/example_2.jpg"],
|
193 |
-
fn=generate_caption,
|
194 |
-
inputs=input_img,
|
195 |
-
outputs=generated_caption
|
196 |
-
)
|
197 |
-
|
198 |
-
with gr.Row(variant="compact"):
|
199 |
-
brand_img = gr.Image(type="pil", show_label=False)
|
200 |
-
with gr.Column(min_width="80"):
|
201 |
-
btn_predict_brand = gr.Button("Predict Brand"
|
202 |
-
predicted_brand = gr.Textbox(label="Brand", show_label=False)
|
203 |
-
gr.Examples(
|
204 |
-
examples=["files/examples/example_brand_1.jpg", "files/examples/example_brand_2.jpg"],
|
205 |
-
fn=predict_brand,
|
206 |
-
inputs=brand_img,
|
207 |
-
outputs=predicted_brand
|
208 |
-
)
|
209 |
-
|
210 |
-
with gr.Column(variant="compact"):
|
211 |
-
btn_estimate = gr.Button("Estimate Price and Reuse"
|
212 |
-
text_box = gr.Textbox(label="Estimates:", show_label=False)
|
213 |
-
|
214 |
-
with gr.Tab("Image Retrieval"):
|
215 |
-
with gr.Row(variant="compact"):
|
216 |
-
with gr.Column():
|
217 |
-
query_img = gr.Image(type="pil", label="Image Query")
|
218 |
-
btn_image_query = gr.Button("Retrieve Garments"
|
219 |
-
img_query_gallery = gr.Gallery(show_label=False
|
220 |
-
gr.Examples(
|
221 |
-
examples=["files/examples/example_retrieval_1.jpg", "files/examples/example_retrieval_2.jpg"],
|
222 |
-
fn=retrieve,
|
223 |
-
inputs=query_img,
|
224 |
-
outputs=img_query_gallery
|
225 |
-
)
|
226 |
-
|
227 |
-
with gr.Row(variant="compact"):
|
228 |
-
with gr.Column():
|
229 |
-
query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
|
230 |
-
btn_text_query = gr.Button("Retrieve Garments"
|
231 |
-
text_query_gallery = gr.Gallery(show_label=False
|
232 |
-
gr.Examples(
|
233 |
-
examples=["A purple sweater", "A dress with a floral pattern"],
|
234 |
-
fn=retrieve,
|
235 |
-
inputs=query_text,
|
236 |
-
outputs=text_query_gallery
|
237 |
-
)
|
238 |
-
|
239 |
-
# Listeners
|
240 |
-
btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
|
241 |
-
btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
|
242 |
-
btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
|
243 |
-
btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
|
244 |
-
btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)
|
245 |
-
|
246 |
-
|
247 |
-
if __name__ == "__main__":
|
248 |
-
demo.launch(
|
249 |
-
# inline=True
|
250 |
-
)
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import clip
|
4 |
+
import faiss
|
5 |
+
import torch
|
6 |
+
import tarfile
|
7 |
+
import gradio as gr
|
8 |
+
import pandas as pd
|
9 |
+
from PIL import Image
|
10 |
+
from braceexpand import braceexpand
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
|
14 |
+
# Load model
|
15 |
+
checkpoint_path = "ViT-B/16"
|
16 |
+
device = "cpu"
|
17 |
+
model, preprocess = clip.load(checkpoint_path, device=device, jit=False)
|
18 |
+
|
19 |
+
|
20 |
+
def generate_caption(img):
|
21 |
+
# Load caption bank
|
22 |
+
df = pd.read_parquet("files/captions.parquet")
|
23 |
+
caption_list = df["caption"].tolist()
|
24 |
+
|
25 |
+
# Load index
|
26 |
+
index = faiss.read_index("files/caption_bank.index")
|
27 |
+
|
28 |
+
# Encode the image and query the caption bank index
|
29 |
+
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
30 |
+
query_features = query_features / query_features.norm(dim=-1, keepdim=True)
|
31 |
+
query_features = query_features.cpu().detach().numpy().astype("float32")
|
32 |
+
|
33 |
+
# Get nearest captions
|
34 |
+
d, i = index.search(query_features, 1)
|
35 |
+
d, i = d[0], i[0]
|
36 |
+
idx = i[0]
|
37 |
+
distance = d[0]
|
38 |
+
|
39 |
+
# Start with a description of the image
|
40 |
+
caption = caption_list[idx]
|
41 |
+
|
42 |
+
print(f"Index: {idx} - Distance: {distance:.2f}")
|
43 |
+
return caption
|
44 |
+
|
45 |
+
|
46 |
+
def predict_brand(img):
|
47 |
+
# Load brand bank
|
48 |
+
df = pd.read_parquet("files/brands.parquet")
|
49 |
+
brand_list = df["brands"].tolist()
|
50 |
+
|
51 |
+
# Load index
|
52 |
+
index = faiss.read_index("files/brand_bank.index")
|
53 |
+
|
54 |
+
# Encode the image and query the brand bank index
|
55 |
+
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
56 |
+
query_features = query_features / query_features.norm(dim=-1, keepdim=True)
|
57 |
+
query_features = query_features.cpu().detach().numpy().astype("float32")
|
58 |
+
|
59 |
+
# Get nearest brands
|
60 |
+
d, i = index.search(query_features, 1)
|
61 |
+
d, i = d[0], i[0]
|
62 |
+
idx = i[0]
|
63 |
+
distance = d[0]
|
64 |
+
|
65 |
+
brand = brand_list[idx]
|
66 |
+
print(f"Index: {idx} - Distance: {distance:.2f}")
|
67 |
+
return brand
|
68 |
+
|
69 |
+
|
70 |
+
def estimate_price_and_usage(img):
|
71 |
+
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
72 |
+
|
73 |
+
# Estimate usage
|
74 |
+
num_classes = 2
|
75 |
+
probe = torch.nn.Linear(
|
76 |
+
query_features.shape[-1],
|
77 |
+
num_classes,
|
78 |
+
dtype=torch.float16,
|
79 |
+
bias=False
|
80 |
+
)
|
81 |
+
# Load weights for the linear layer as a tensor
|
82 |
+
linear_data = torch.load("files/reuse_linear.pth", map_location="cpu")
|
83 |
+
probe.weight.data = linear_data["weight"]
|
84 |
+
|
85 |
+
# Do inference
|
86 |
+
with torch.autocast("cpu"):
|
87 |
+
probe.eval()
|
88 |
+
probe = probe.to(device)
|
89 |
+
output = probe(query_features)
|
90 |
+
output = torch.softmax(output, dim=-1)
|
91 |
+
#output = output.cpu().detach().numpy().astype("float32")
|
92 |
+
reuse = output.argmax(axis=-1)[0]
|
93 |
+
reuse_classes = ["Reuse", "Export"]
|
94 |
+
|
95 |
+
# Estimate price
|
96 |
+
num_classes = 4
|
97 |
+
probe = torch.nn.Linear(
|
98 |
+
query_features.shape[-1],
|
99 |
+
num_classes,
|
100 |
+
dtype=torch.float16,
|
101 |
+
bias=False
|
102 |
+
)
|
103 |
+
# Print output shape for the linear layer
|
104 |
+
# Load weights for the linear layer as a tensor
|
105 |
+
linear_data = torch.load("files/price_linear.pth", map_location="cpu")
|
106 |
+
probe.weight.data = linear_data["weight"]
|
107 |
+
|
108 |
+
# Do inference
|
109 |
+
with torch.autocast("cpu"):
|
110 |
+
probe.eval()
|
111 |
+
probe = probe.to(device)
|
112 |
+
output = probe(query_features)
|
113 |
+
output = torch.softmax(output, dim=-1)
|
114 |
+
#output = output.cpu().detach().numpy().astype("float32")
|
115 |
+
price = output.argmax(axis=-1)[0]
|
116 |
+
price_classes = ["<50", "50-100", "100-150", ">150"]
|
117 |
+
|
118 |
+
return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"
|
119 |
+
|
120 |
+
|
121 |
+
def retrieve(query):
|
122 |
+
index_folder = "files/index"
|
123 |
+
num_results = 3
|
124 |
+
|
125 |
+
# Read image metadata
|
126 |
+
metadata_df = pd.read_parquet(os.path.join(index_folder, "metadata.parquet"))
|
127 |
+
key_list = metadata_df["key"].tolist()
|
128 |
+
|
129 |
+
# Load the index
|
130 |
+
index = faiss.read_index(os.path.join(index_folder, "image.index"))
|
131 |
+
|
132 |
+
# Encode the query
|
133 |
+
if isinstance(query, str):
|
134 |
+
print("Query is a string")
|
135 |
+
text = clip.tokenize([query]).to(device)
|
136 |
+
query_features = model.encode_text(text)
|
137 |
+
else:
|
138 |
+
print("Query is an image")
|
139 |
+
query_features = model.encode_image(preprocess(query).unsqueeze(0).to(device))
|
140 |
+
query_features = query_features / query_features.norm(dim=-1, keepdim=True)
|
141 |
+
query_features = query_features.cpu().detach().numpy().astype("float32")
|
142 |
+
|
143 |
+
d, i = index.search(query_features, num_results)
|
144 |
+
print(f"Found {num_results} items with query '{query}'")
|
145 |
+
indices = i[0]
|
146 |
+
similarities = d[0]
|
147 |
+
|
148 |
+
min_d = min(similarities)
|
149 |
+
max_d = max(similarities)
|
150 |
+
print(f"The minimum similarity is {min_d:.2f} and the maximum is {max_d:.2f}")
|
151 |
+
|
152 |
+
# Uncomment to generate combined.tar, combine the image_tars into a single tarfile
|
153 |
+
"""
|
154 |
+
dataset_dir = "/fs/sefs1/circularfashion/wargon_webdataset/front_only"
|
155 |
+
image_tars = [os.path.join(dataset_dir, file) for file in sorted(braceexpand("{0000..0028}.tar"))]
|
156 |
+
with tarfile.open("files/combined.tar", "w") as combined_tar:
|
157 |
+
for tar in image_tars:
|
158 |
+
with tarfile.open(tar, "r") as tar_file:
|
159 |
+
for member in tar_file.getmembers():
|
160 |
+
combined_tar.addfile(member, tar_file.extractfile(member))
|
161 |
+
"""
|
162 |
+
|
163 |
+
images = []
|
164 |
+
for idx in indices:
|
165 |
+
image_name = key_list[idx]
|
166 |
+
with tarfile.open("files/combined.tar", "r") as tar_file:
|
167 |
+
image = tar_file.extractfile(f"{image_name}.jpg")
|
168 |
+
image = Image.open(image).copy()
|
169 |
+
# Center crop the image
|
170 |
+
width, height = image.size
|
171 |
+
new_size = min(width, height)
|
172 |
+
image = transforms.CenterCrop(new_size)(image)
|
173 |
+
# Resize the image
|
174 |
+
image = transforms.Resize((600, 600))(image)
|
175 |
+
images.append(image)
|
176 |
+
return images
|
177 |
+
|
178 |
+
|
179 |
+
theme = gr.Theme.from_hub("JohnSmith9982/small_and_pretty")
|
180 |
+
|
181 |
+
with gr.Blocks(
|
182 |
+
theme=theme,
|
183 |
+
css="footer {visibility: hidden}",
|
184 |
+
) as demo:
|
185 |
+
with gr.Tab("Captioning and Prediction"):
|
186 |
+
with gr.Row(variant="compact"):
|
187 |
+
input_img = gr.Image(type="pil", show_label=False)
|
188 |
+
with gr.Column(min_width="80"):
|
189 |
+
btn_generate_caption = gr.Button("Create Description", size="sm")
|
190 |
+
generated_caption = gr.Textbox(label="Description", show_label=False)
|
191 |
+
gr.Examples(
|
192 |
+
examples=["files/examples/example_1.jpg", "files/examples/example_2.jpg"],
|
193 |
+
fn=generate_caption,
|
194 |
+
inputs=input_img,
|
195 |
+
outputs=generated_caption
|
196 |
+
)
|
197 |
+
|
198 |
+
with gr.Row(variant="compact"):
|
199 |
+
brand_img = gr.Image(type="pil", show_label=False)
|
200 |
+
with gr.Column(min_width="80"):
|
201 |
+
btn_predict_brand = gr.Button("Predict Brand", size="sm")
|
202 |
+
predicted_brand = gr.Textbox(label="Brand", show_label=False)
|
203 |
+
gr.Examples(
|
204 |
+
examples=["files/examples/example_brand_1.jpg", "files/examples/example_brand_2.jpg"],
|
205 |
+
fn=predict_brand,
|
206 |
+
inputs=brand_img,
|
207 |
+
outputs=predicted_brand
|
208 |
+
)
|
209 |
+
|
210 |
+
with gr.Column(variant="compact"):
|
211 |
+
btn_estimate = gr.Button("Estimate Price and Reuse", size="sm")
|
212 |
+
text_box = gr.Textbox(label="Estimates:", show_label=False)
|
213 |
+
|
214 |
+
with gr.Tab("Image Retrieval"):
|
215 |
+
with gr.Row(variant="compact"):
|
216 |
+
with gr.Column():
|
217 |
+
query_img = gr.Image(type="pil", label="Image Query")
|
218 |
+
btn_image_query = gr.Button("Retrieve Garments", size="sm")
|
219 |
+
img_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
|
220 |
+
gr.Examples(
|
221 |
+
examples=["files/examples/example_retrieval_1.jpg", "files/examples/example_retrieval_2.jpg"],
|
222 |
+
fn=retrieve,
|
223 |
+
inputs=query_img,
|
224 |
+
outputs=img_query_gallery
|
225 |
+
)
|
226 |
+
|
227 |
+
with gr.Row(variant="compact"):
|
228 |
+
with gr.Column():
|
229 |
+
query_text = gr.Textbox(label="Text Query", placeholder="Enter a description")
|
230 |
+
btn_text_query = gr.Button("Retrieve Garments", size="sm")
|
231 |
+
text_query_gallery = gr.Gallery(show_label=False, rows=1, columns=3)
|
232 |
+
gr.Examples(
|
233 |
+
examples=["A purple sweater", "A dress with a floral pattern"],
|
234 |
+
fn=retrieve,
|
235 |
+
inputs=query_text,
|
236 |
+
outputs=text_query_gallery
|
237 |
+
)
|
238 |
+
|
239 |
+
# Listeners
|
240 |
+
btn_generate_caption.click(fn=generate_caption, inputs=input_img, outputs=generated_caption)
|
241 |
+
btn_predict_brand.click(fn=predict_brand, inputs=brand_img, outputs=predicted_brand)
|
242 |
+
btn_estimate.click(fn=estimate_price_and_usage, inputs=input_img, outputs=text_box)
|
243 |
+
btn_image_query.click(fn=retrieve, inputs=query_img, outputs=img_query_gallery)
|
244 |
+
btn_text_query.click(fn=retrieve, inputs=query_text, outputs=text_query_gallery)
|
245 |
+
|
246 |
+
|
247 |
+
if __name__ == "__main__":
|
248 |
+
demo.launch(
|
249 |
+
# inline=True
|
250 |
+
)
|