hysts HF Staff commited on
Commit
197bb86
·
1 Parent(s): 8f4581c
Files changed (2) hide show
  1. app.py +5 -4
  2. model.py +14 -0
app.py CHANGED
@@ -20,7 +20,7 @@ import cv2
20
  import gradio as gr
21
  import numpy as np
22
 
23
- from model import Model
24
 
25
  TITLE = '# MMDetection'
26
  DESCRIPTION = '''
@@ -67,7 +67,7 @@ def update_input_image(image: np.ndarray) -> dict:
67
 
68
 
69
  def update_model_name(model_type: str) -> dict:
70
- model_dict = getattr(Model, f'{model_type.upper()}_MODEL_DICT')
71
  model_names = list(model_dict.keys())
72
  model_name = DEFAULT_MODEL_NAMES[model_type]
73
  return gr.Dropdown.update(choices=model_names, value=model_name)
@@ -88,7 +88,7 @@ def set_example_image(example: list) -> dict:
88
  def main():
89
  args = parse_args()
90
  extract_tar()
91
- model = Model(DEFAULT_MODEL_NAME, args.device)
92
 
93
  with gr.Blocks(theme=args.theme, css='style.css') as demo:
94
  gr.Markdown(TITLE)
@@ -147,8 +147,9 @@ def main():
147
  outputs=redraw_button)
148
 
149
  model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
150
- run_button.click(fn=model.detect_and_visualize,
151
  inputs=[
 
152
  input_image,
153
  visualization_score_threshold,
154
  ],
 
20
  import gradio as gr
21
  import numpy as np
22
 
23
+ from model import AppModel
24
 
25
  TITLE = '# MMDetection'
26
  DESCRIPTION = '''
 
67
 
68
 
69
  def update_model_name(model_type: str) -> dict:
70
+ model_dict = getattr(AppModel, f'{model_type.upper()}_MODEL_DICT')
71
  model_names = list(model_dict.keys())
72
  model_name = DEFAULT_MODEL_NAMES[model_type]
73
  return gr.Dropdown.update(choices=model_names, value=model_name)
 
88
  def main():
89
  args = parse_args()
90
  extract_tar()
91
+ model = AppModel(DEFAULT_MODEL_NAME, args.device)
92
 
93
  with gr.Blocks(theme=args.theme, css='style.css') as demo:
94
  gr.Markdown(TITLE)
 
147
  outputs=redraw_button)
148
 
149
  model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
150
+ run_button.click(fn=model.run,
151
  inputs=[
152
+ model_name,
153
  input_image,
154
  visualization_score_threshold,
155
  ],
model.py CHANGED
@@ -51,6 +51,7 @@ class Model:
51
  def __init__(self, model_name: str, device: str | torch.device):
52
  self.device = torch.device(device)
53
  self._load_all_models_once()
 
54
  self.model = self._load_model(model_name)
55
 
56
  def _load_all_models_once(self) -> None:
@@ -62,6 +63,9 @@ class Model:
62
  return init_detector(dic['config'], dic['model'], device=self.device)
63
 
64
  def set_model(self, name: str) -> None:
 
 
 
65
  self.model = self._load_model(name)
66
 
67
  def detect_and_visualize(
@@ -96,3 +100,13 @@ class Model:
96
  text_color=(200, 200, 200),
97
  mask_color=None)
98
  return vis[:, :, ::-1] # BGR -> RGB
 
 
 
 
 
 
 
 
 
 
 
51
  def __init__(self, model_name: str, device: str | torch.device):
52
  self.device = torch.device(device)
53
  self._load_all_models_once()
54
+ self.model_name = model_name
55
  self.model = self._load_model(model_name)
56
 
57
  def _load_all_models_once(self) -> None:
 
63
  return init_detector(dic['config'], dic['model'], device=self.device)
64
 
65
  def set_model(self, name: str) -> None:
66
+ if name == self.model_name:
67
+ return
68
+ self.model_name = name
69
  self.model = self._load_model(name)
70
 
71
  def detect_and_visualize(
 
100
  text_color=(200, 200, 200),
101
  mask_color=None)
102
  return vis[:, :, ::-1] # BGR -> RGB
103
+
104
+
105
+ class AppModel(Model):
106
+ def run(
107
+ self, model_name: str, image: np.ndarray, score_threshold: float
108
+ ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
109
+ list[list[np.ndarray]]]
110
+ | dict[str, np.ndarray], np.ndarray]:
111
+ self.set_model(model_name)
112
+ return self.detect_and_visualize(image, score_threshold)