hysts HF staff commited on
Commit
e6d1870
·
1 Parent(s): 7db2ef4
Files changed (4) hide show
  1. .pre-commit-config.yaml +59 -36
  2. README.md +1 -1
  3. app.py +74 -77
  4. model.py +91 -116
.pre-commit-config.yaml CHANGED
@@ -1,37 +1,60 @@
1
- exclude: ^(ViTPose/|mmdet_configs/configs/)
2
  repos:
3
- - repo: https://github.com/pre-commit/pre-commit-hooks
4
- rev: v4.2.0
5
- hooks:
6
- - id: check-executables-have-shebangs
7
- - id: check-json
8
- - id: check-merge-conflict
9
- - id: check-shebang-scripts-are-executable
10
- - id: check-toml
11
- - id: check-yaml
12
- - id: double-quote-string-fixer
13
- - id: end-of-file-fixer
14
- - id: mixed-line-ending
15
- args: ['--fix=lf']
16
- - id: requirements-txt-fixer
17
- - id: trailing-whitespace
18
- - repo: https://github.com/myint/docformatter
19
- rev: v1.4
20
- hooks:
21
- - id: docformatter
22
- args: ['--in-place']
23
- - repo: https://github.com/pycqa/isort
24
- rev: 5.12.0
25
- hooks:
26
- - id: isort
27
- - repo: https://github.com/pre-commit/mirrors-mypy
28
- rev: v0.991
29
- hooks:
30
- - id: mypy
31
- args: ['--ignore-missing-imports']
32
- additional_dependencies: ['types-python-slugify']
33
- - repo: https://github.com/google/yapf
34
- rev: v0.32.0
35
- hooks:
36
- - id: yapf
37
- args: ['--parallel', '--in-place']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.6.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.13.2
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.10.0
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ [
33
+ "types-python-slugify",
34
+ "types-requests",
35
+ "types-PyYAML",
36
+ "types-pytz",
37
+ ]
38
+ - repo: https://github.com/psf/black
39
+ rev: 24.4.2
40
+ hooks:
41
+ - id: black
42
+ language_version: python3.10
43
+ args: ["--line-length", "119"]
44
+ - repo: https://github.com/kynan/nbstripout
45
+ rev: 0.7.1
46
+ hooks:
47
+ - id: nbstripout
48
+ args:
49
+ [
50
+ "--extra-keys",
51
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
52
+ ]
53
+ - repo: https://github.com/nbQA-dev/nbQA
54
+ rev: 1.8.5
55
+ hooks:
56
+ - id: nbqa-black
57
+ - id: nbqa-pyupgrade
58
+ args: ["--py37-plus"]
59
+ - id: nbqa-isort
60
+ args: ["--float-to-top"]
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  suggested_hardware: t4-small
app.py CHANGED
@@ -2,109 +2,106 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  import pathlib
 
 
6
  import tarfile
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import gradio as gr
9
 
10
  from model import AppModel
11
 
12
- DESCRIPTION = '''# [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)
13
 
14
  Related app: [https://huggingface.co/spaces/Gradio-Blocks/ViTPose](https://huggingface.co/spaces/Gradio-Blocks/ViTPose)
15
- '''
16
 
17
 
18
  def extract_tar() -> None:
19
- if pathlib.Path('mmdet_configs/configs').exists():
20
  return
21
- with tarfile.open('mmdet_configs/configs.tar') as f:
22
- f.extractall('mmdet_configs')
23
 
24
 
25
  extract_tar()
26
 
27
  model = AppModel()
28
 
29
- with gr.Blocks(css='style.css') as demo:
30
  gr.Markdown(DESCRIPTION)
31
 
32
  with gr.Row():
33
  with gr.Column():
34
- input_video = gr.Video(label='Input Video',
35
- format='mp4',
36
- elem_id='input_video')
37
- detector_name = gr.Dropdown(label='Detector',
38
- choices=list(
39
- model.det_model.MODEL_DICT.keys()),
40
- value=model.det_model.model_name)
41
  pose_model_name = gr.Dropdown(
42
- label='Pose Model',
43
- choices=list(model.pose_model.MODEL_DICT.keys()),
44
- value=model.pose_model.model_name)
45
- det_score_threshold = gr.Slider(label='Box Score Threshold',
46
- minimum=0,
47
- maximum=1,
48
- step=0.05,
49
- value=0.5)
50
- max_num_frames = gr.Slider(label='Maximum Number of Frames',
51
- minimum=1,
52
- maximum=300,
53
- step=1,
54
- value=60)
55
- predict_button = gr.Button('Predict')
56
- pose_preds = gr.Variable()
57
-
58
- paths = sorted(pathlib.Path('videos').rglob('*.mp4'))
59
- gr.Examples(examples=[[path.as_posix()] for path in paths],
60
- inputs=input_video)
61
 
62
  with gr.Column():
63
- result = gr.Video(label='Result', format='mp4', elem_id='result')
64
  vis_kpt_score_threshold = gr.Slider(
65
- label='Visualization Score Threshold',
66
- minimum=0,
67
- maximum=1,
68
- step=0.05,
69
- value=0.3)
70
- vis_dot_radius = gr.Slider(label='Dot Radius',
71
- minimum=1,
72
- maximum=10,
73
- step=1,
74
- value=4)
75
- vis_line_thickness = gr.Slider(label='Line Thickness',
76
- minimum=1,
77
- maximum=10,
78
- step=1,
79
- value=2)
80
- redraw_button = gr.Button('Redraw')
81
 
82
  detector_name.change(fn=model.det_model.set_model, inputs=detector_name)
83
- pose_model_name.change(fn=model.pose_model.set_model,
84
- inputs=pose_model_name)
85
- predict_button.click(fn=model.run,
86
- inputs=[
87
- input_video,
88
- detector_name,
89
- pose_model_name,
90
- det_score_threshold,
91
- max_num_frames,
92
- vis_kpt_score_threshold,
93
- vis_dot_radius,
94
- vis_line_thickness,
95
- ],
96
- outputs=[
97
- result,
98
- pose_preds,
99
- ])
100
- redraw_button.click(fn=model.visualize_pose_results,
101
- inputs=[
102
- input_video,
103
- pose_preds,
104
- vis_kpt_score_threshold,
105
- vis_dot_radius,
106
- vis_line_thickness,
107
- ],
108
- outputs=result)
109
-
110
- demo.queue(max_size=10).launch()
 
 
 
 
 
2
 
3
  from __future__ import annotations
4
 
5
+ import os
6
  import pathlib
7
+ import shlex
8
+ import subprocess
9
  import tarfile
10
 
11
+ if os.getenv("SYSTEM") == "spaces":
12
+ subprocess.run(shlex.split("pip install click==7.1.2"))
13
+ subprocess.run(shlex.split("pip install typer==0.9.4"))
14
+
15
+ import mim
16
+
17
+ mim.uninstall("mmcv-full", confirm_yes=True)
18
+ mim.install("mmcv-full==1.5.0", is_yes=True)
19
+
20
+ subprocess.call(shlex.split("pip uninstall -y opencv-python"))
21
+ subprocess.call(shlex.split("pip uninstall -y opencv-python-headless"))
22
+ subprocess.call(shlex.split("pip install opencv-python-headless==4.8.0.74"))
23
+
24
+
25
  import gradio as gr
26
 
27
  from model import AppModel
28
 
29
+ DESCRIPTION = """# [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)
30
 
31
  Related app: [https://huggingface.co/spaces/Gradio-Blocks/ViTPose](https://huggingface.co/spaces/Gradio-Blocks/ViTPose)
32
+ """
33
 
34
 
35
  def extract_tar() -> None:
36
+ if pathlib.Path("mmdet_configs/configs").exists():
37
  return
38
+ with tarfile.open("mmdet_configs/configs.tar") as f:
39
+ f.extractall("mmdet_configs")
40
 
41
 
42
  extract_tar()
43
 
44
  model = AppModel()
45
 
46
+ with gr.Blocks(css="style.css") as demo:
47
  gr.Markdown(DESCRIPTION)
48
 
49
  with gr.Row():
50
  with gr.Column():
51
+ input_video = gr.Video(label="Input Video", format="mp4", elem_id="input_video")
52
+ detector_name = gr.Dropdown(
53
+ label="Detector", choices=list(model.det_model.MODEL_DICT.keys()), value=model.det_model.model_name
54
+ )
 
 
 
55
  pose_model_name = gr.Dropdown(
56
+ label="Pose Model", choices=list(model.pose_model.MODEL_DICT.keys()), value=model.pose_model.model_name
57
+ )
58
+ det_score_threshold = gr.Slider(label="Box Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5)
59
+ max_num_frames = gr.Slider(label="Maximum Number of Frames", minimum=1, maximum=300, step=1, value=60)
60
+ predict_button = gr.Button("Predict")
61
+ pose_preds = gr.State()
62
+
63
+ paths = sorted(pathlib.Path("videos").rglob("*.mp4"))
64
+ gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_video)
 
 
 
 
 
 
 
 
 
 
65
 
66
  with gr.Column():
67
+ result = gr.Video(label="Result", format="mp4", elem_id="result")
68
  vis_kpt_score_threshold = gr.Slider(
69
+ label="Visualization Score Threshold", minimum=0, maximum=1, step=0.05, value=0.3
70
+ )
71
+ vis_dot_radius = gr.Slider(label="Dot Radius", minimum=1, maximum=10, step=1, value=4)
72
+ vis_line_thickness = gr.Slider(label="Line Thickness", minimum=1, maximum=10, step=1, value=2)
73
+ redraw_button = gr.Button("Redraw")
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  detector_name.change(fn=model.det_model.set_model, inputs=detector_name)
76
+ pose_model_name.change(fn=model.pose_model.set_model, inputs=pose_model_name)
77
+ predict_button.click(
78
+ fn=model.run,
79
+ inputs=[
80
+ input_video,
81
+ detector_name,
82
+ pose_model_name,
83
+ det_score_threshold,
84
+ max_num_frames,
85
+ vis_kpt_score_threshold,
86
+ vis_dot_radius,
87
+ vis_line_thickness,
88
+ ],
89
+ outputs=[
90
+ result,
91
+ pose_preds,
92
+ ],
93
+ )
94
+ redraw_button.click(
95
+ fn=model.visualize_pose_results,
96
+ inputs=[
97
+ input_video,
98
+ pose_preds,
99
+ vis_kpt_score_threshold,
100
+ vis_dot_radius,
101
+ vis_line_thickness,
102
+ ],
103
+ outputs=result,
104
+ )
105
+
106
+ if __name__ == "__main__":
107
+ demo.queue(max_size=10).launch()
model.py CHANGED
@@ -1,68 +1,49 @@
1
  from __future__ import annotations
2
 
3
- import os
4
- import shlex
5
- import subprocess
6
  import sys
7
  import tempfile
8
 
9
- if os.getenv('SYSTEM') == 'spaces':
10
- import mim
11
-
12
- mim.uninstall('mmcv-full', confirm_yes=True)
13
- mim.install('mmcv-full==1.5.0', is_yes=True)
14
-
15
- subprocess.call(shlex.split('pip uninstall -y opencv-python'))
16
- subprocess.call(shlex.split('pip uninstall -y opencv-python-headless'))
17
- subprocess.call(
18
- shlex.split('pip install opencv-python-headless==4.8.0.74'))
19
-
20
  import cv2
21
  import huggingface_hub
22
  import numpy as np
23
  import torch
24
  import torch.nn as nn
25
 
26
- sys.path.insert(0, 'ViTPose/')
27
 
28
  from mmdet.apis import inference_detector, init_detector
29
- from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
30
- process_mmdet_results, vis_pose_result)
 
 
 
 
31
 
32
 
33
  class DetModel:
34
  MODEL_DICT = {
35
- 'YOLOX-tiny': {
36
- 'config':
37
- 'mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py',
38
- 'model':
39
- 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth',
40
  },
41
- 'YOLOX-s': {
42
- 'config':
43
- 'mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py',
44
- 'model':
45
- 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth',
46
  },
47
- 'YOLOX-l': {
48
- 'config':
49
- 'mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py',
50
- 'model':
51
- 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
52
  },
53
- 'YOLOX-x': {
54
- 'config':
55
- 'mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py',
56
- 'model':
57
- 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth',
58
  },
59
  }
60
 
61
  def __init__(self):
62
- self.device = torch.device(
63
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
64
  self._load_all_models_once()
65
- self.model_name = 'YOLOX-l'
66
  self.model = self._load_model(self.model_name)
67
 
68
  def _load_all_models_once(self) -> None:
@@ -71,7 +52,7 @@ class DetModel:
71
 
72
  def _load_model(self, name: str) -> nn.Module:
73
  d = self.MODEL_DICT[name]
74
- return init_detector(d['config'], d['model'], device=self.device)
75
 
76
  def set_model(self, name: str) -> None:
77
  if name == self.model_name:
@@ -79,9 +60,7 @@ class DetModel:
79
  self.model_name = name
80
  self.model = self._load_model(name)
81
 
82
- def detect_and_visualize(
83
- self, image: np.ndarray,
84
- score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
85
  out = self.detect(image)
86
  vis = self.visualize_detection_results(image, out, score_threshold)
87
  return out, vis
@@ -92,50 +71,40 @@ class DetModel:
92
  return out
93
 
94
  def visualize_detection_results(
95
- self,
96
- image: np.ndarray,
97
- detection_results: list[np.ndarray],
98
- score_threshold: float = 0.3) -> np.ndarray:
99
  person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79
100
 
101
  image = image[:, :, ::-1] # RGB -> BGR
102
- vis = self.model.show_result(image,
103
- person_det,
104
- score_thr=score_threshold,
105
- bbox_color=None,
106
- text_color=(200, 200, 200),
107
- mask_color=None)
108
  return vis[:, :, ::-1] # BGR -> RGB
109
 
110
 
111
  class PoseModel:
112
  MODEL_DICT = {
113
- 'ViTPose-B (single-task train)': {
114
- 'config':
115
- 'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
116
- 'model': 'models/vitpose-b.pth',
117
  },
118
- 'ViTPose-L (single-task train)': {
119
- 'config':
120
- 'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
121
- 'model': 'models/vitpose-l.pth',
122
  },
123
- 'ViTPose-B (multi-task train, COCO)': {
124
- 'config':
125
- 'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py',
126
- 'model': 'models/vitpose-b-multi-coco.pth',
127
  },
128
- 'ViTPose-L (multi-task train, COCO)': {
129
- 'config':
130
- 'ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py',
131
- 'model': 'models/vitpose-l-multi-coco.pth',
132
  },
133
  }
134
 
135
  def __init__(self):
136
- self.device = torch.device(
137
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
138
- self.model_name = 'ViTPose-B (multi-task train, COCO)'
139
  self.model = self._load_model(self.model_name)
140
 
141
  def _load_all_models_once(self) -> None:
@@ -144,9 +113,8 @@ class PoseModel:
144
 
145
  def _load_model(self, name: str) -> nn.Module:
146
  d = self.MODEL_DICT[name]
147
- ckpt_path = huggingface_hub.hf_hub_download('public-data/ViTPose',
148
- d['model'])
149
- model = init_pose_model(d['config'], ckpt_path, device=self.device)
150
  return model
151
 
152
  def set_model(self, name: str) -> None:
@@ -165,37 +133,36 @@ class PoseModel:
165
  vis_line_thickness: int,
166
  ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
167
  out = self.predict_pose(image, det_results, box_score_threshold)
168
- vis = self.visualize_pose_results(image, out, kpt_score_threshold,
169
- vis_dot_radius, vis_line_thickness)
170
  return out, vis
171
 
172
  def predict_pose(
173
- self,
174
- image: np.ndarray,
175
- det_results: list[np.ndarray],
176
- box_score_threshold: float = 0.5) -> list[dict[str, np.ndarray]]:
177
  image = image[:, :, ::-1] # RGB -> BGR
178
  person_results = process_mmdet_results(det_results, 1)
179
- out, _ = inference_top_down_pose_model(self.model,
180
- image,
181
- person_results=person_results,
182
- bbox_thr=box_score_threshold,
183
- format='xyxy')
184
  return out
185
 
186
- def visualize_pose_results(self,
187
- image: np.ndarray,
188
- pose_results: list[dict[str, np.ndarray]],
189
- kpt_score_threshold: float = 0.3,
190
- vis_dot_radius: int = 4,
191
- vis_line_thickness: int = 1) -> np.ndarray:
 
 
192
  image = image[:, :, ::-1] # RGB -> BGR
193
- vis = vis_pose_result(self.model,
194
- image,
195
- pose_results,
196
- kpt_score_thr=kpt_score_threshold,
197
- radius=vis_dot_radius,
198
- thickness=vis_line_thickness)
 
 
199
  return vis[:, :, ::-1] # BGR -> RGB
200
 
201
 
@@ -205,10 +172,15 @@ class AppModel:
205
  self.pose_model = PoseModel()
206
 
207
  def run(
208
- self, video_path: str, det_model_name: str, pose_model_name: str,
209
- box_score_threshold: float, max_num_frames: int,
210
- kpt_score_threshold: float, vis_dot_radius: int,
211
- vis_line_thickness: int
 
 
 
 
 
212
  ) -> tuple[str, list[list[dict[str, np.ndarray]]]]:
213
  if video_path is None:
214
  return
@@ -222,8 +194,8 @@ class AppModel:
222
 
223
  preds_all = []
224
 
225
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
226
- out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
227
  writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
228
  for _ in range(max_num_frames):
229
  ok, frame = cap.read()
@@ -232,8 +204,8 @@ class AppModel:
232
  rgb_frame = frame[:, :, ::-1]
233
  det_preds = self.det_model.detect(rgb_frame)
234
  preds, vis = self.pose_model.predict_pose_and_visualize(
235
- rgb_frame, det_preds, box_score_threshold, kpt_score_threshold,
236
- vis_dot_radius, vis_line_thickness)
237
  preds_all.append(preds)
238
  writer.write(vis[:, :, ::-1])
239
  cap.release()
@@ -241,11 +213,14 @@ class AppModel:
241
 
242
  return out_file.name, preds_all
243
 
244
- def visualize_pose_results(self, video_path: str,
245
- pose_preds_all: list[list[dict[str,
246
- np.ndarray]]],
247
- kpt_score_threshold: float, vis_dot_radius: int,
248
- vis_line_thickness: int) -> str:
 
 
 
249
  if video_path is None or pose_preds_all is None:
250
  return
251
  cap = cv2.VideoCapture(video_path)
@@ -253,8 +228,8 @@ class AppModel:
253
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
254
  fps = cap.get(cv2.CAP_PROP_FPS)
255
 
256
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
257
- out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
258
  writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
259
  for pose_preds in pose_preds_all:
260
  ok, frame = cap.read()
@@ -262,8 +237,8 @@ class AppModel:
262
  break
263
  rgb_frame = frame[:, :, ::-1]
264
  vis = self.pose_model.visualize_pose_results(
265
- rgb_frame, pose_preds, kpt_score_threshold, vis_dot_radius,
266
- vis_line_thickness)
267
  writer.write(vis[:, :, ::-1])
268
  cap.release()
269
  writer.release()
 
1
  from __future__ import annotations
2
 
 
 
 
3
  import sys
4
  import tempfile
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  import cv2
7
  import huggingface_hub
8
  import numpy as np
9
  import torch
10
  import torch.nn as nn
11
 
12
+ sys.path.insert(0, "ViTPose/")
13
 
14
  from mmdet.apis import inference_detector, init_detector
15
+ from mmpose.apis import (
16
+ inference_top_down_pose_model,
17
+ init_pose_model,
18
+ process_mmdet_results,
19
+ vis_pose_result,
20
+ )
21
 
22
 
23
  class DetModel:
24
  MODEL_DICT = {
25
+ "YOLOX-tiny": {
26
+ "config": "mmdet_configs/configs/yolox/yolox_tiny_8x8_300e_coco.py",
27
+ "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20211124_171234-b4047906.pth",
 
 
28
  },
29
+ "YOLOX-s": {
30
+ "config": "mmdet_configs/configs/yolox/yolox_s_8x8_300e_coco.py",
31
+ "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth",
 
 
32
  },
33
+ "YOLOX-l": {
34
+ "config": "mmdet_configs/configs/yolox/yolox_l_8x8_300e_coco.py",
35
+ "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth",
 
 
36
  },
37
+ "YOLOX-x": {
38
+ "config": "mmdet_configs/configs/yolox/yolox_x_8x8_300e_coco.py",
39
+ "model": "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth",
 
 
40
  },
41
  }
42
 
43
  def __init__(self):
44
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
45
  self._load_all_models_once()
46
+ self.model_name = "YOLOX-l"
47
  self.model = self._load_model(self.model_name)
48
 
49
  def _load_all_models_once(self) -> None:
 
52
 
53
  def _load_model(self, name: str) -> nn.Module:
54
  d = self.MODEL_DICT[name]
55
+ return init_detector(d["config"], d["model"], device=self.device)
56
 
57
  def set_model(self, name: str) -> None:
58
  if name == self.model_name:
 
60
  self.model_name = name
61
  self.model = self._load_model(name)
62
 
63
+ def detect_and_visualize(self, image: np.ndarray, score_threshold: float) -> tuple[list[np.ndarray], np.ndarray]:
 
 
64
  out = self.detect(image)
65
  vis = self.visualize_detection_results(image, out, score_threshold)
66
  return out, vis
 
71
  return out
72
 
73
  def visualize_detection_results(
74
+ self, image: np.ndarray, detection_results: list[np.ndarray], score_threshold: float = 0.3
75
+ ) -> np.ndarray:
 
 
76
  person_det = [detection_results[0]] + [np.array([]).reshape(0, 5)] * 79
77
 
78
  image = image[:, :, ::-1] # RGB -> BGR
79
+ vis = self.model.show_result(
80
+ image, person_det, score_thr=score_threshold, bbox_color=None, text_color=(200, 200, 200), mask_color=None
81
+ )
 
 
 
82
  return vis[:, :, ::-1] # BGR -> RGB
83
 
84
 
85
  class PoseModel:
86
  MODEL_DICT = {
87
+ "ViTPose-B (single-task train)": {
88
+ "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py",
89
+ "model": "models/vitpose-b.pth",
 
90
  },
91
+ "ViTPose-L (single-task train)": {
92
+ "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py",
93
+ "model": "models/vitpose-l.pth",
 
94
  },
95
+ "ViTPose-B (multi-task train, COCO)": {
96
+ "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_base_coco_256x192.py",
97
+ "model": "models/vitpose-b-multi-coco.pth",
 
98
  },
99
+ "ViTPose-L (multi-task train, COCO)": {
100
+ "config": "ViTPose/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_large_coco_256x192.py",
101
+ "model": "models/vitpose-l-multi-coco.pth",
 
102
  },
103
  }
104
 
105
  def __init__(self):
106
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
107
+ self.model_name = "ViTPose-B (multi-task train, COCO)"
 
108
  self.model = self._load_model(self.model_name)
109
 
110
  def _load_all_models_once(self) -> None:
 
113
 
114
  def _load_model(self, name: str) -> nn.Module:
115
  d = self.MODEL_DICT[name]
116
+ ckpt_path = huggingface_hub.hf_hub_download("public-data/ViTPose", d["model"])
117
+ model = init_pose_model(d["config"], ckpt_path, device=self.device)
 
118
  return model
119
 
120
  def set_model(self, name: str) -> None:
 
133
  vis_line_thickness: int,
134
  ) -> tuple[list[dict[str, np.ndarray]], np.ndarray]:
135
  out = self.predict_pose(image, det_results, box_score_threshold)
136
+ vis = self.visualize_pose_results(image, out, kpt_score_threshold, vis_dot_radius, vis_line_thickness)
 
137
  return out, vis
138
 
139
  def predict_pose(
140
+ self, image: np.ndarray, det_results: list[np.ndarray], box_score_threshold: float = 0.5
141
+ ) -> list[dict[str, np.ndarray]]:
 
 
142
  image = image[:, :, ::-1] # RGB -> BGR
143
  person_results = process_mmdet_results(det_results, 1)
144
+ out, _ = inference_top_down_pose_model(
145
+ self.model, image, person_results=person_results, bbox_thr=box_score_threshold, format="xyxy"
146
+ )
 
 
147
  return out
148
 
149
+ def visualize_pose_results(
150
+ self,
151
+ image: np.ndarray,
152
+ pose_results: list[dict[str, np.ndarray]],
153
+ kpt_score_threshold: float = 0.3,
154
+ vis_dot_radius: int = 4,
155
+ vis_line_thickness: int = 1,
156
+ ) -> np.ndarray:
157
  image = image[:, :, ::-1] # RGB -> BGR
158
+ vis = vis_pose_result(
159
+ self.model,
160
+ image,
161
+ pose_results,
162
+ kpt_score_thr=kpt_score_threshold,
163
+ radius=vis_dot_radius,
164
+ thickness=vis_line_thickness,
165
+ )
166
  return vis[:, :, ::-1] # BGR -> RGB
167
 
168
 
 
172
  self.pose_model = PoseModel()
173
 
174
  def run(
175
+ self,
176
+ video_path: str,
177
+ det_model_name: str,
178
+ pose_model_name: str,
179
+ box_score_threshold: float,
180
+ max_num_frames: int,
181
+ kpt_score_threshold: float,
182
+ vis_dot_radius: int,
183
+ vis_line_thickness: int,
184
  ) -> tuple[str, list[list[dict[str, np.ndarray]]]]:
185
  if video_path is None:
186
  return
 
194
 
195
  preds_all = []
196
 
197
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
198
+ out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
199
  writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
200
  for _ in range(max_num_frames):
201
  ok, frame = cap.read()
 
204
  rgb_frame = frame[:, :, ::-1]
205
  det_preds = self.det_model.detect(rgb_frame)
206
  preds, vis = self.pose_model.predict_pose_and_visualize(
207
+ rgb_frame, det_preds, box_score_threshold, kpt_score_threshold, vis_dot_radius, vis_line_thickness
208
+ )
209
  preds_all.append(preds)
210
  writer.write(vis[:, :, ::-1])
211
  cap.release()
 
213
 
214
  return out_file.name, preds_all
215
 
216
+ def visualize_pose_results(
217
+ self,
218
+ video_path: str,
219
+ pose_preds_all: list[list[dict[str, np.ndarray]]],
220
+ kpt_score_threshold: float,
221
+ vis_dot_radius: int,
222
+ vis_line_thickness: int,
223
+ ) -> str:
224
  if video_path is None or pose_preds_all is None:
225
  return
226
  cap = cv2.VideoCapture(video_path)
 
228
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
229
  fps = cap.get(cv2.CAP_PROP_FPS)
230
 
231
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
232
+ out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
233
  writer = cv2.VideoWriter(out_file.name, fourcc, fps, (width, height))
234
  for pose_preds in pose_preds_all:
235
  ok, frame = cap.read()
 
237
  break
238
  rgb_frame = frame[:, :, ::-1]
239
  vis = self.pose_model.visualize_pose_results(
240
+ rgb_frame, pose_preds, kpt_score_threshold, vis_dot_radius, vis_line_thickness
241
+ )
242
  writer.write(vis[:, :, ::-1])
243
  cap.release()
244
  writer.release()