Spaces:
Runtime error
Runtime error
Fix
Browse files
app.py
CHANGED
@@ -20,7 +20,7 @@ import cv2
|
|
20 |
import gradio as gr
|
21 |
import numpy as np
|
22 |
|
23 |
-
from model import
|
24 |
|
25 |
TITLE = '# MMDetection'
|
26 |
DESCRIPTION = '''
|
@@ -67,7 +67,7 @@ def update_input_image(image: np.ndarray) -> dict:
|
|
67 |
|
68 |
|
69 |
def update_model_name(model_type: str) -> dict:
|
70 |
-
model_dict = getattr(
|
71 |
model_names = list(model_dict.keys())
|
72 |
model_name = DEFAULT_MODEL_NAMES[model_type]
|
73 |
return gr.Dropdown.update(choices=model_names, value=model_name)
|
@@ -88,7 +88,7 @@ def set_example_image(example: list) -> dict:
|
|
88 |
def main():
|
89 |
args = parse_args()
|
90 |
extract_tar()
|
91 |
-
model =
|
92 |
|
93 |
with gr.Blocks(theme=args.theme, css='style.css') as demo:
|
94 |
gr.Markdown(TITLE)
|
@@ -147,8 +147,9 @@ def main():
|
|
147 |
outputs=redraw_button)
|
148 |
|
149 |
model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
|
150 |
-
run_button.click(fn=model.
|
151 |
inputs=[
|
|
|
152 |
input_image,
|
153 |
visualization_score_threshold,
|
154 |
],
|
|
|
20 |
import gradio as gr
|
21 |
import numpy as np
|
22 |
|
23 |
+
from model import AppModel
|
24 |
|
25 |
TITLE = '# MMDetection'
|
26 |
DESCRIPTION = '''
|
|
|
67 |
|
68 |
|
69 |
def update_model_name(model_type: str) -> dict:
|
70 |
+
model_dict = getattr(AppModel, f'{model_type.upper()}_MODEL_DICT')
|
71 |
model_names = list(model_dict.keys())
|
72 |
model_name = DEFAULT_MODEL_NAMES[model_type]
|
73 |
return gr.Dropdown.update(choices=model_names, value=model_name)
|
|
|
88 |
def main():
|
89 |
args = parse_args()
|
90 |
extract_tar()
|
91 |
+
model = AppModel(DEFAULT_MODEL_NAME, args.device)
|
92 |
|
93 |
with gr.Blocks(theme=args.theme, css='style.css') as demo:
|
94 |
gr.Markdown(TITLE)
|
|
|
147 |
outputs=redraw_button)
|
148 |
|
149 |
model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
|
150 |
+
run_button.click(fn=model.run,
|
151 |
inputs=[
|
152 |
+
model_name,
|
153 |
input_image,
|
154 |
visualization_score_threshold,
|
155 |
],
|
model.py
CHANGED
@@ -51,6 +51,7 @@ class Model:
|
|
51 |
def __init__(self, model_name: str, device: str | torch.device):
|
52 |
self.device = torch.device(device)
|
53 |
self._load_all_models_once()
|
|
|
54 |
self.model = self._load_model(model_name)
|
55 |
|
56 |
def _load_all_models_once(self) -> None:
|
@@ -62,6 +63,9 @@ class Model:
|
|
62 |
return init_detector(dic['config'], dic['model'], device=self.device)
|
63 |
|
64 |
def set_model(self, name: str) -> None:
|
|
|
|
|
|
|
65 |
self.model = self._load_model(name)
|
66 |
|
67 |
def detect_and_visualize(
|
@@ -96,3 +100,13 @@ class Model:
|
|
96 |
text_color=(200, 200, 200),
|
97 |
mask_color=None)
|
98 |
return vis[:, :, ::-1] # BGR -> RGB
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def __init__(self, model_name: str, device: str | torch.device):
|
52 |
self.device = torch.device(device)
|
53 |
self._load_all_models_once()
|
54 |
+
self.model_name = model_name
|
55 |
self.model = self._load_model(model_name)
|
56 |
|
57 |
def _load_all_models_once(self) -> None:
|
|
|
63 |
return init_detector(dic['config'], dic['model'], device=self.device)
|
64 |
|
65 |
def set_model(self, name: str) -> None:
|
66 |
+
if name == self.model_name:
|
67 |
+
return
|
68 |
+
self.model_name = name
|
69 |
self.model = self._load_model(name)
|
70 |
|
71 |
def detect_and_visualize(
|
|
|
100 |
text_color=(200, 200, 200),
|
101 |
mask_color=None)
|
102 |
return vis[:, :, ::-1] # BGR -> RGB
|
103 |
+
|
104 |
+
|
105 |
+
class AppModel(Model):
|
106 |
+
def run(
|
107 |
+
self, model_name: str, image: np.ndarray, score_threshold: float
|
108 |
+
) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
|
109 |
+
list[list[np.ndarray]]]
|
110 |
+
| dict[str, np.ndarray], np.ndarray]:
|
111 |
+
self.set_model(model_name)
|
112 |
+
return self.detect_and_visualize(image, score_threshold)
|