drhead commited on
Commit
369ecce
·
verified ·
1 Parent(s): 3cb1c16

fix threshold slider and matplotlib warning

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -12,7 +12,7 @@ from torchvision.transforms import InterpolationMode
12
  import torchvision.transforms.functional as TF
13
  from huggingface_hub import hf_hub_download
14
  import numpy as np
15
- import matplotlib.cm as cm
16
 
17
  class Fit(torch.nn.Module):
18
  def __init__(
@@ -170,7 +170,7 @@ def run_classifier(image: Image.Image, threshold):
170
  tag_score[allowed_tags[indices[i]]] = values[i].item()
171
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
172
 
173
- return *create_tags(threshold, sorted_tag_score), img
174
 
175
  def create_tags(threshold, sorted_tag_score: dict):
176
  filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
@@ -178,7 +178,7 @@ def create_tags(threshold, sorted_tag_score: dict):
178
  return text_no_impl, filtered_tag_score
179
 
180
  def clear_image():
181
- return "", {}, None
182
 
183
  def cam_inference(img, threshold, evt: gr.SelectData):
184
  target_tag = evt.value
@@ -274,6 +274,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
274
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
275
  """)
276
  original_image_state = gr.State() # stash a copy of the input image
 
277
  with gr.Row():
278
  with gr.Column():
279
  image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
@@ -285,18 +286,18 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
285
  image_input.upload(
286
  fn=run_classifier,
287
  inputs=[image_input, threshold_slider],
288
- outputs=[tag_string, label_box, original_image_state]
289
  )
290
 
291
  image_input.clear(
292
  fn=clear_image,
293
  inputs=[],
294
- outputs=[tag_string, label_box, original_image_state]
295
  )
296
 
297
  threshold_slider.input(
298
  fn=create_tags,
299
- inputs=[threshold_slider],
300
  outputs=[tag_string, label_box]
301
  )
302
 
 
12
  import torchvision.transforms.functional as TF
13
  from huggingface_hub import hf_hub_download
14
  import numpy as np
15
+ import matplotlib.colormaps as cm
16
 
17
  class Fit(torch.nn.Module):
18
  def __init__(
 
170
  tag_score[allowed_tags[indices[i]]] = values[i].item()
171
  sorted_tag_score = dict(sorted(tag_score.items(), key=lambda item: item[1], reverse=True))
172
 
173
+ return *create_tags(threshold, sorted_tag_score), img, sorted_tag_score
174
 
175
  def create_tags(threshold, sorted_tag_score: dict):
176
  filtered_tag_score = {key: value for key, value in sorted_tag_score.items() if value > threshold}
 
178
  return text_no_impl, filtered_tag_score
179
 
180
  def clear_image():
181
+ return "", {}, None, {}
182
 
183
  def cam_inference(img, threshold, evt: gr.SelectData):
184
  target_tag = evt.value
 
274
  Special thanks to Minotoro at frosting.ai for providing the compute power for this project.
275
  """)
276
  original_image_state = gr.State() # stash a copy of the input image
277
+ sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
278
  with gr.Row():
279
  with gr.Column():
280
  image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
 
286
  image_input.upload(
287
  fn=run_classifier,
288
  inputs=[image_input, threshold_slider],
289
+ outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
290
  )
291
 
292
  image_input.clear(
293
  fn=clear_image,
294
  inputs=[],
295
+ outputs=[tag_string, label_box, original_image_state, sorted_tag_score_state]
296
  )
297
 
298
  threshold_slider.input(
299
  fn=create_tags,
300
+ inputs=[threshold_slider, sorted_tag_score_state],
301
  outputs=[tag_string, label_box]
302
  )
303