shvardhan commited on
Commit
7f4efbd
·
verified ·
1 Parent(s): ba8abbe

Upload 5 files

Browse files

Add application file

Files changed (3) hide show
  1. app.py +136 -0
  2. model.py +74 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import pathlib
7
+ import subprocess
8
+ import tarfile
9
+
10
+ import cv2
11
+ import gradio as gr
12
+ import numpy as np
13
+
14
+ from model import AppModel
15
+
16
# Markdown rendered at the top of the demo page: title, link to the
# upstream project, and an overview image.
DESCRIPTION = '''# MMDetection
This is an unofficial demo for [https://github.com/open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection).
<img id="overview" alt="overview" src="https://user-images.githubusercontent.com/12907710/137271636-56ba1cd2-b110-4812-8221-b4c120320aa9.png" />
'''

# Task type selected when the app first loads.
DEFAULT_MODEL_TYPE = 'detection'
# Default model per task type.  Keys must match the
# `<TYPE>_MODEL_DICT` attributes that `update_model_name` looks up on
# AppModel via getattr.
DEFAULT_MODEL_NAMES = {
    'detection': 'YOLOX-l',
    'instance_segmentation': 'QueryInst (R-50-FPN)',
    'panoptic_segmentation': 'MaskFormer (R-50)',
}
DEFAULT_MODEL_NAME = DEFAULT_MODEL_NAMES[DEFAULT_MODEL_TYPE]
28
+
29
+
30
+
31
def update_input_image(image: np.ndarray) -> dict:
    """Downscale an uploaded image so its longest side is at most 1500 px.

    Returns an update for the input image component; a cleared input
    passes through as None.
    """
    if image is None:
        return gr.Image.update(value=None)
    longest_side = max(image.shape[:2])
    ratio = 1500 / longest_side
    if ratio < 1:
        image = cv2.resize(image, None, fx=ratio, fy=ratio)
    return gr.Image.update(value=image)
38
+
39
+
40
def update_model_name(model_type: str) -> dict:
    """Refresh the model dropdown with the models of the selected task type."""
    available = getattr(AppModel, f'{model_type.upper()}_MODEL_DICT')
    choices = list(available)
    default = DEFAULT_MODEL_NAMES[model_type]
    return gr.Dropdown.update(choices=choices, value=default)
45
+
46
+
47
def update_visualization_score_threshold(model_type: str) -> dict:
    """Hide the score-threshold slider for panoptic segmentation."""
    show = model_type != 'panoptic_segmentation'
    return gr.Slider.update(visible=show)
49
+
50
+
51
def update_redraw_button(model_type: str) -> dict:
    """Hide the Redraw button for panoptic segmentation."""
    show = model_type != 'panoptic_segmentation'
    return gr.Button.update(visible=show)
53
+
54
+
55
def set_example_image(example: list) -> dict:
    """Load the clicked gallery example into the input image component."""
    image_path = example[0]
    return gr.Image.update(value=image_path)
57
+
58
+
59
# Single model instance shared by every callback; `model.set_model`
# swaps the underlying detector in place when the dropdown changes.
model = AppModel(DEFAULT_MODEL_NAME)

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_image = gr.Image(label='Input Image', type='numpy')
            with gr.Group():
                with gr.Row():
                    # Task selector; choices are the keys of DEFAULT_MODEL_NAMES.
                    model_type = gr.Radio(list(DEFAULT_MODEL_NAMES.keys()),
                                          value=DEFAULT_MODEL_TYPE,
                                          label='Model Type')
                with gr.Row():
                    # Initially lists detection models; refreshed by
                    # update_model_name when the task type changes.
                    model_name = gr.Dropdown(list(
                        model.DETECTION_MODEL_DICT.keys()),
                                             value=DEFAULT_MODEL_NAME,
                                             label='Model')
            with gr.Row():
                run_button = gr.Button(value='Run')
                # Hidden state holding the raw detection output so the
                # Redraw button can re-render without re-running inference.
                prediction_results = gr.Variable()
        with gr.Column():
            with gr.Row():
                visualization = gr.Image(label='Result', type='numpy')
            with gr.Row():
                visualization_score_threshold = gr.Slider(
                    0,
                    1,
                    step=0.05,
                    value=0.3,
                    label='Visualization Score Threshold')
            with gr.Row():
                redraw_button = gr.Button(value='Redraw')

    with gr.Row():
        # Every *.jpg found under images/ becomes a clickable example.
        paths = sorted(pathlib.Path('images').rglob('*.jpg'))
        example_images = gr.Dataset(components=[input_image],
                                    samples=[[path.as_posix()]
                                             for path in paths])

    # Downscale very large uploads before display/processing.
    input_image.change(fn=update_input_image,
                       inputs=input_image,
                       outputs=input_image)

    # Changing the task type refreshes the model list and hides the
    # threshold slider / Redraw button for panoptic segmentation.
    model_type.change(fn=update_model_name,
                      inputs=model_type,
                      outputs=model_name)
    model_type.change(fn=update_visualization_score_threshold,
                      inputs=model_type,
                      outputs=visualization_score_threshold)
    model_type.change(fn=update_redraw_button,
                      inputs=model_type,
                      outputs=redraw_button)

    model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
    run_button.click(fn=model.run,
                     inputs=[
                         model_name,
                         input_image,
                         visualization_score_threshold,
                     ],
                     outputs=[
                         prediction_results,
                         visualization,
                     ])
    # Re-render the cached results at a new threshold (no new inference).
    redraw_button.click(fn=model.visualize_detection_results,
                        inputs=[
                            input_image,
                            prediction_results,
                            visualization_score_threshold,
                        ],
                        outputs=visualization)
    example_images.click(fn=set_example_image,
                         inputs=example_images,
                         outputs=input_image)

demo.queue().launch(show_api=False)
model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ import huggingface_hub
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import yaml # type: ignore
10
+ from mmdet.apis import inference_detector, init_detector
11
+
12
+
13
+
14
+
15
+
16
class Model:
    """Thin wrapper around an MMDetection detector.

    Holds the torch device, the currently selected model name and the
    loaded detector, and exposes detection/visualization helpers.

    NOTE(review): ``MODEL_DICT`` is not defined in this class; it appears
    to be expected from a subclass or elsewhere in the project — confirm.
    """

    def __init__(self, model_name: str):
        # Use the first GPU when available, otherwise fall back to CPU.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_name = model_name
        self.model = self._load_model(model_name)

    def _load_model(self, name: str) -> nn.Module:
        """Build and return the detector for *name*.

        NOTE(review): the entry looked up in ``MODEL_DICT`` is currently
        discarded and a single hard-coded config/checkpoint pair is loaded
        regardless of *name*.  The lookup is kept only so an unknown name
        still raises ``KeyError``; wire the dict entry into
        ``init_detector`` if per-model loading is intended.
        """
        self.MODEL_DICT[name]  # validate *name*; raises KeyError if unknown
        return init_detector(
            'configs/_base_/faster-rcnn_r50_fpn_1x_coco.py',
            'models/orgaquanT-pretarined.pth',
            device=self.device)

    def set_model(self, name: str) -> None:
        """Switch to model *name*, reloading only when it actually changes."""
        if name == self.model_name:
            return
        self.model_name = name
        self.model = self._load_model(name)

    def detect_and_visualize(
        self, image: np.ndarray, score_threshold: float
    ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
                                        list[list[np.ndarray]]]
               | dict[str, np.ndarray], np.ndarray]:
        """Run detection on *image* and render the results.

        Returns the raw detection output together with the rendered image
        so callers can re-visualize later without re-running inference.
        """
        out = self.detect(image)
        vis = self.visualize_detection_results(image, out, score_threshold)
        return out, vis

    def detect(
        self, image: np.ndarray
    ) -> list[np.ndarray] | tuple[
            list[np.ndarray], list[list[np.ndarray]]] | dict[str, np.ndarray]:
        """Run the current detector on *image* and return its raw output."""
        return inference_detector(self.model, image)

    def visualize_detection_results(
            self,
            image: np.ndarray,
            detection_results: list[np.ndarray]
            | tuple[list[np.ndarray], list[list[np.ndarray]]]
            | dict[str, np.ndarray],
            score_threshold: float = 0.3) -> np.ndarray:
        """Draw *detection_results* onto *image*.

        Detections scoring below *score_threshold* are dropped; the
        rendered image is returned.
        """
        return self.model.show_result(image,
                                      detection_results,
                                      score_thr=score_threshold,
                                      bbox_color=None,
                                      text_color=(200, 200, 200),
                                      mask_color=None)
65
+
66
+
67
class AppModel(Model):
    """Model wrapper wired into the Gradio UI callbacks."""

    def run(
        self, model_name: str, image: np.ndarray, score_threshold: float
    ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
                                        list[list[np.ndarray]]]
               | dict[str, np.ndarray], np.ndarray]:
        """Switch to *model_name* if needed, then detect and visualize."""
        self.set_model(model_name)
        results = self.detect_and_visualize(image, score_threshold)
        return results
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ mmcv-full==1.5.2
2
+ mmdet==2.25.0
3
+ numpy==1.22.4
4
+ opencv-python-headless==4.5.5.64
5
+ openmim==0.1.5
6
+ torch==1.11.0
7
+ torchvision==0.12.0