wenkun commited on
Commit
e4580db
·
1 Parent(s): d05dfd9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import subprocess
10
+ import sys
11
+ import urllib.request
12
+
13
+ if os.environ.get('SYSTEM') == 'spaces':
14
+ import mim
15
+ mim.install('mmcv-full==1.3.3', is_yes=True)
16
+
17
+ subprocess.call('pip uninstall -y opencv-python'.split())
18
+ subprocess.call('pip uninstall -y opencv-python-headless'.split())
19
+ subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
20
+ subprocess.call('pip install terminaltables==3.1.0'.split())
21
+ subprocess.call('pip install mmpycocotools==12.0.3'.split())
22
+
23
+ subprocess.call('pip install insightface==0.6.2'.split())
24
+
25
+
26
+ import cv2
27
+ import gradio as gr
28
+ import huggingface_hub
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+
33
+ sys.path.insert(0, 'insightface/detection/scrfd')
34
+
35
+ from mmdet.apis import inference_detector, init_detector, show_result_pyplot
36
+
37
+ TITLE = 'insightface Face Detection (SCRFD)'
38
+ DESCRIPTION = 'This is an unofficial demo for https://github.com/deepinsight/insightface/tree/master/detection/scrfd.'
39
+ ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.insightface-scrfd" alt="visitor badge"/></center>'
40
+
41
+ TOKEN = os.environ['TOKEN']
42
+
43
+
44
+ def parse_args() -> argparse.Namespace:
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument('--face-score-slider-step', type=float, default=0.05)
47
+ parser.add_argument('--face-score-threshold', type=float, default=0.3)
48
+ parser.add_argument('--device', type=str, default='cpu')
49
+ parser.add_argument('--theme', type=str)
50
+ parser.add_argument('--live', action='store_true')
51
+ parser.add_argument('--share', action='store_true')
52
+ parser.add_argument('--port', type=int)
53
+ parser.add_argument('--disable-queue',
54
+ dest='enable_queue',
55
+ action='store_false')
56
+ parser.add_argument('--allow-flagging', type=str, default='never')
57
+ return parser.parse_args()
58
+
59
+
60
+ def load_model(model_size: str, device) -> nn.Module:
61
+ ckpt_path = huggingface_hub.hf_hub_download(
62
+ 'hysts/insightface',
63
+ f'models/scrfd_{model_size}/model.pth',
64
+ use_auth_token=TOKEN)
65
+ scrfd_dir = 'insightface/detection/scrfd'
66
+ config_path = f'{scrfd_dir}/configs/scrfd/scrfd_{model_size}.py'
67
+ model = init_detector(config_path, ckpt_path, device.type)
68
+ return model
69
+
70
+
71
+ def update_test_pipeline(model: nn.Module, mode: int):
72
+ cfg = model.cfg
73
+ pipelines = cfg.data.test.pipeline
74
+ for pipeline in pipelines:
75
+ if pipeline.type == 'MultiScaleFlipAug':
76
+ if mode == 0: #640 scale
77
+ pipeline.img_scale = (640, 640)
78
+ if hasattr(pipeline, 'scale_factor'):
79
+ del pipeline.scale_factor
80
+ elif mode == 1: #for single scale in other pages
81
+ pipeline.img_scale = (1100, 1650)
82
+ if hasattr(pipeline, 'scale_factor'):
83
+ del pipeline.scale_factor
84
+ elif mode == 2: #original scale
85
+ pipeline.img_scale = None
86
+ pipeline.scale_factor = 1.0
87
+ transforms = pipeline.transforms
88
+ for transform in transforms:
89
+ if transform.type == 'Pad':
90
+ if mode != 2:
91
+ transform.size = pipeline.img_scale
92
+ if hasattr(transform, 'size_divisor'):
93
+ del transform.size_divisor
94
+ else:
95
+ transform.size = None
96
+ transform.size_divisor = 32
97
+
98
+
99
+ def detect(image: np.ndarray, model_size: str, mode: int,
100
+ face_score_threshold: float,
101
+ detectors: dict[str, nn.Module]) -> np.ndarray:
102
+ model = detectors[model_size]
103
+ update_test_pipeline(model, mode)
104
+
105
+ # RGB -> BGR
106
+ image = image[:, :, ::-1]
107
+ preds = inference_detector(model, image)
108
+ boxes = preds[0]
109
+
110
+ res = image.copy()
111
+ for box in boxes:
112
+ box, score = box[:4], box[4]
113
+ if score < face_score_threshold:
114
+ continue
115
+ box = np.round(box).astype(int)
116
+
117
+ line_width = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
118
+ cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0),
119
+ line_width)
120
+
121
+ res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
122
+ return res
123
+
124
+
125
+ def main():
126
+ args = parse_args()
127
+ device = torch.device(args.device)
128
+
129
+ model_sizes = [
130
+ '500m',
131
+ '1g',
132
+ '2.5g',
133
+ '10g',
134
+ '34g',
135
+ ]
136
+ detectors = {
137
+ model_size: load_model(model_size, device=device)
138
+ for model_size in model_sizes
139
+ }
140
+ modes = [
141
+ '(640, 640)',
142
+ '(1100, 1650)',
143
+ 'original',
144
+ ]
145
+
146
+ func = functools.partial(detect, detectors=detectors)
147
+ func = functools.update_wrapper(func, detect)
148
+
149
+ image_path = pathlib.Path('selfie.jpg')
150
+ if not image_path.exists():
151
+ url = 'https://raw.githubusercontent.com/peiyunh/tiny/master/data/demo/selfie.jpg'
152
+ urllib.request.urlretrieve(url, image_path)
153
+ examples = [[image_path.as_posix(), '10g', modes[0], 0.3]]
154
+
155
+ gr.Interface(
156
+ func,
157
+ [
158
+ gr.inputs.Image(type='numpy', label='Input'),
159
+ gr.inputs.Radio(
160
+ model_sizes, type='value', default='10g', label='Model'),
161
+ gr.inputs.Radio(
162
+ modes, type='index', default=modes[0], label='Mode'),
163
+ gr.inputs.Slider(0,
164
+ 1,
165
+ step=args.face_score_slider_step,
166
+ default=args.face_score_threshold,
167
+ label='Face Score Threshold'),
168
+ ],
169
+ gr.outputs.Image(type='numpy', label='Output'),
170
+ examples=examples,
171
+ title=TITLE,
172
+ description=DESCRIPTION,
173
+ article=ARTICLE,
174
+ theme=args.theme,
175
+ allow_flagging=args.allow_flagging,
176
+ live=args.live,
177
+ ).launch(
178
+ enable_queue=args.enable_queue,
179
+ server_port=args.port,
180
+ share=args.share,
181
+ )
182
+
183
+
184
+ if __name__ == '__main__':
185
+ main()