kevinconka commited on
Commit
9b3bb2e
·
1 Parent(s): fb78bf6

Refactor flagging mechanism in app.py to use dedicated function for image flagging and clean up commented code.

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -7,7 +7,7 @@ Any new model should implement the following functions:
7
 
8
  import os
9
  import glob
10
- import spaces
11
  import gradio as gr
12
  from huggingface_hub import get_token
13
  from utils import (
@@ -50,12 +50,16 @@ model = load_model("experimental/ahoy6-MIX-1280-b1.onnx")
50
  model.det_conf_thresh = 0.1
51
  model.hor_conf_thresh = 0.1
52
 
53
- @spaces.GPU
54
  def inference(image):
55
  """Run inference on image and return annotated image."""
56
  results = model(image)
57
  return results.draw(image)
58
 
 
 
 
 
59
  # Flagging
60
  dataset_name = "SEA-AI/crowdsourced-sea-images"
61
  hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
@@ -131,14 +135,14 @@ with gr.Blocks(theme=theme, css=css, title="SEA.AI Vision Demo") as demo:
131
  [flag],
132
  show_api=False,
133
  ).then(
134
- lambda img_input, flag: hf_writer.flag(img_input, flag),
135
- [img_input, flag],
136
  [],
137
  preprocess=False,
138
  show_api=True,
139
  api_name="flag_misdetection"
140
  ).then(
141
- lambda: load_badges(flagged_counter.count()), [], badges, show_api=False
142
  )
143
 
144
  # called during initial load in browser
 
7
 
8
  import os
9
  import glob
10
+ #import spaces
11
  import gradio as gr
12
  from huggingface_hub import get_token
13
  from utils import (
 
50
  model.det_conf_thresh = 0.1
51
  model.hor_conf_thresh = 0.1
52
 
53
+ # @spaces.GPU
54
  def inference(image):
55
  """Run inference on image and return annotated image."""
56
  results = model(image)
57
  return results.draw(image)
58
 
59
+ def flag_img_input(image: gr.Image, flag_option: str = "misdetection", username: str = "anonymous"):
60
+ """Wrapper for flagging"""
61
+ return hf_writer.flag([image], flag_option=flag_option, username=username)
62
+
63
  # Flagging
64
  dataset_name = "SEA-AI/crowdsourced-sea-images"
65
  hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
 
135
  [flag],
136
  show_api=False,
137
  ).then(
138
+ flag_img_input,
139
+ [img_input],
140
  [],
141
  preprocess=False,
142
  show_api=True,
143
  api_name="flag_misdetection"
144
  ).then(
145
+ lambda: load_badges(flagged_counter.count()), [], badges, show_api=False,
146
  )
147
 
148
  # called during initial load in browser