fix threshold slider and matplotlib warning
Browse files
app.py
CHANGED
@@ -12,7 +12,7 @@ from torchvision.transforms import InterpolationMode
|
|
12 |
import torchvision.transforms.functional as TF
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
import numpy as np
|
15 |
-
import matplotlib.
|
16 |
|
17 |
class Fit(torch.nn.Module):
|
18 |
def __init__(
|
@@ -170,7 +170,7 @@ def run_classifier(image: Image.Image, threshold):
|
|
170 |
tag_score[allowed_tags[indices[i]]] = values[i].item()
|
171 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
172 |
|
173 |
-
return *create_tags(threshold, sorted_tag_score), img
|
174 |
|
175 |
def create_tags(threshold, sorted_tag_score: dict):
|
176 |
filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
|
@@ -178,7 +178,7 @@ def create_tags(threshold, sorted_tag_score: dict):
|
|
178 |
return text_no_impl, filtered_tag_score
|
179 |
|
180 |
def clear_image():
|
181 |
-
return "", {}, None
|
182 |
|
183 |
def cam_inference(img, threshold, evt: gr.SelectData):
|
184 |
target_tag = evt.value
|
@@ -274,6 +274,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
274 |
Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
|
275 |
""")
|
276 |
original_image_state = gr.State() # stash a copy of the input image
|
|
|
277 |
with gr.Row():
|
278 |
with gr.Column():
|
279 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
@@ -285,18 +286,18 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
|
|
285 |
image_input.upload(
|
286 |
fn=run_classifier,
|
287 |
inputs=[image_input, threshold_slider],
|
288 |
-
outputs=[tag_string, label_box, original_image_state]
|
289 |
)
|
290 |
|
291 |
image_input.clear(
|
292 |
fn=clear_image,
|
293 |
inputs=[],
|
294 |
-
outputs=[tag_string, label_box, original_image_state]
|
295 |
)
|
296 |
|
297 |
threshold_slider.input(
|
298 |
fn=create_tags,
|
299 |
-
inputs=[threshold_slider],
|
300 |
outputs=[tag_string, label_box]
|
301 |
)
|
302 |
|
|
|
12 |
import torchvision.transforms.functional as TF
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
import numpy as np
|
15 |
+
import matplotlib.colormaps as cm
|
16 |
|
17 |
class Fit(torch.nn.Module):
|
18 |
def __init__(
|
|
|
170 |
tag_score[allowed_tags[indices[i]]] = values[i].item()
|
171 |
sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
|
172 |
|
173 |
+
return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
|
174 |
|
175 |
def create_tags(threshold, sorted_tag_score: dict):
|
176 |
filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
|
|
|
178 |
return text_no_impl, filtered_tag_score
|
179 |
|
180 |
def clear_image():
|
181 |
+
return "", {}, None, {}
|
182 |
|
183 |
def cam_inference(img, threshold, evt: gr.SelectData):
|
184 |
target_tag = evt.value
|
|
|
274 |
Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
|
275 |
""")
|
276 |
original_image_state = gr.State() # stash a copy of the input image
|
277 |
+
sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
|
278 |
with gr.Row():
|
279 |
with gr.Column():
|
280 |
image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
|
|
|
286 |
image_input.upload(
|
287 |
fn=run_classifier,
|
288 |
inputs=[image_input, threshold_slider],
|
289 |
+
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
|
290 |
)
|
291 |
|
292 |
image_input.clear(
|
293 |
fn=clear_image,
|
294 |
inputs=[],
|
295 |
+
outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
|
296 |
)
|
297 |
|
298 |
threshold_slider.input(
|
299 |
fn=create_tags,
|
300 |
+
inputs=[threshold_slider, sorted_tag_score_state],
|
301 |
outputs=[tag_string, label_box]
|
302 |
)
|
303 |
|