hysts HF staff commited on
Commit
1fa5a67
·
1 Parent(s): 09e24f9
Files changed (5) hide show
  1. .pre-commit-config.yaml +1 -0
  2. README.md +4 -1
  3. app.py +36 -50
  4. model.py +7 -10
  5. requirements.txt +1 -1
.pre-commit-config.yaml CHANGED
@@ -29,6 +29,7 @@ repos:
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
 
32
  - repo: https://github.com/google/yapf
33
  rev: v0.32.0
34
  hooks:
 
29
  hooks:
30
  - id: mypy
31
  args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
  - repo: https://github.com/google/yapf
34
  rev: v0.32.0
35
  hooks:
README.md CHANGED
@@ -4,9 +4,12 @@ emoji: 🦀
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
+ suggested_hardware: t4-small
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ https://arxiv.org/abs/2204.12484
app.py CHANGED
@@ -9,18 +9,12 @@ import gradio as gr
9
 
10
  from model import AppModel
11
 
12
- DESCRIPTION = '''# ViTPose
13
-
14
- This is an unofficial demo for [https://github.com/ViTAE-Transformer/ViTPose](https://github.com/ViTAE-Transformer/ViTPose).
15
 
16
  Related app: [https://huggingface.co/spaces/Gradio-Blocks/ViTPose](https://huggingface.co/spaces/Gradio-Blocks/ViTPose)
17
  '''
18
 
19
 
20
- def set_example_video(example: list) -> dict:
21
- return gr.Video.update(value=example[0])
22
-
23
-
24
  def extract_tar() -> None:
25
  if pathlib.Path('mmdet_configs/configs').exists():
26
  return
@@ -40,58 +34,54 @@ with gr.Blocks(css='style.css') as demo:
40
  input_video = gr.Video(label='Input Video',
41
  format='mp4',
42
  elem_id='input_video')
43
- detector_name = gr.Dropdown(list(
44
- model.det_model.MODEL_DICT.keys()),
45
- value=model.det_model.model_name,
46
- label='Detector')
47
- pose_model_name = gr.Dropdown(list(
48
- model.pose_model.MODEL_DICT.keys()),
49
- value=model.pose_model.model_name,
50
- label='Pose Model')
51
- det_score_threshold = gr.Slider(0,
52
- 1,
 
53
  step=0.05,
54
- value=0.5,
55
- label='Box Score Threshold')
56
- max_num_frames = gr.Slider(1,
57
- 300,
58
  step=1,
59
- value=60,
60
- label='Maximum Number of Frames')
61
- predict_button = gr.Button(value='Predict')
62
  pose_preds = gr.Variable()
63
 
64
  paths = sorted(pathlib.Path('videos').rglob('*.mp4'))
65
- example_videos = gr.Dataset(components=[input_video],
66
- samples=[[path.as_posix()]
67
- for path in paths])
68
 
69
  with gr.Column():
70
  result = gr.Video(label='Result', format='mp4', elem_id='result')
71
  vis_kpt_score_threshold = gr.Slider(
72
- 0,
73
- 1,
 
74
  step=0.05,
75
- value=0.3,
76
- label='Visualization Score Threshold')
77
- vis_dot_radius = gr.Slider(1,
78
- 10,
79
  step=1,
80
- value=4,
81
- label='Dot Radius')
82
- vis_line_thickness = gr.Slider(1,
83
- 10,
84
  step=1,
85
- value=2,
86
- label='Line Thickness')
87
- redraw_button = gr.Button(value='Redraw')
88
 
89
- detector_name.change(fn=model.det_model.set_model,
90
- inputs=detector_name,
91
- outputs=None)
92
  pose_model_name.change(fn=model.pose_model.set_model,
93
- inputs=pose_model_name,
94
- outputs=None)
95
  predict_button.click(fn=model.run,
96
  inputs=[
97
  input_video,
@@ -117,8 +107,4 @@ with gr.Blocks(css='style.css') as demo:
117
  ],
118
  outputs=result)
119
 
120
- example_videos.click(fn=set_example_video,
121
- inputs=example_videos,
122
- outputs=input_video)
123
-
124
- demo.queue().launch(show_api=False)
 
9
 
10
  from model import AppModel
11
 
12
+ DESCRIPTION = '''# [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)
 
 
13
 
14
  Related app: [https://huggingface.co/spaces/Gradio-Blocks/ViTPose](https://huggingface.co/spaces/Gradio-Blocks/ViTPose)
15
  '''
16
 
17
 
 
 
 
 
18
  def extract_tar() -> None:
19
  if pathlib.Path('mmdet_configs/configs').exists():
20
  return
 
34
  input_video = gr.Video(label='Input Video',
35
  format='mp4',
36
  elem_id='input_video')
37
+ detector_name = gr.Dropdown(label='Detector',
38
+ choices=list(
39
+ model.det_model.MODEL_DICT.keys()),
40
+ value=model.det_model.model_name)
41
+ pose_model_name = gr.Dropdown(
42
+ label='Pose Model',
43
+ choices=list(model.pose_model.MODEL_DICT.keys()),
44
+ value=model.pose_model.model_name)
45
+ det_score_threshold = gr.Slider(label='Box Score Threshold',
46
+ minimum=0,
47
+ maximum=1,
48
  step=0.05,
49
+ value=0.5)
50
+ max_num_frames = gr.Slider(label='Maximum Number of Frames',
51
+ minimum=1,
52
+ maximum=300,
53
  step=1,
54
+ value=60)
55
+ predict_button = gr.Button('Predict')
 
56
  pose_preds = gr.Variable()
57
 
58
  paths = sorted(pathlib.Path('videos').rglob('*.mp4'))
59
+ gr.Examples(examples=[[path.as_posix()] for path in paths],
60
+ inputs=input_video)
 
61
 
62
  with gr.Column():
63
  result = gr.Video(label='Result', format='mp4', elem_id='result')
64
  vis_kpt_score_threshold = gr.Slider(
65
+ label='Visualization Score Threshold',
66
+ minimum=0,
67
+ maximum=1,
68
  step=0.05,
69
+ value=0.3)
70
+ vis_dot_radius = gr.Slider(label='Dot Radius',
71
+ minimum=1,
72
+ maximum=10,
73
  step=1,
74
+ value=4)
75
+ vis_line_thickness = gr.Slider(label='Line Thickness',
76
+ minimum=1,
77
+ maximum=10,
78
  step=1,
79
+ value=2)
80
+ redraw_button = gr.Button('Redraw')
 
81
 
82
+ detector_name.change(fn=model.det_model.set_model, inputs=detector_name)
 
 
83
  pose_model_name.change(fn=model.pose_model.set_model,
84
+ inputs=pose_model_name)
 
85
  predict_button.click(fn=model.run,
86
  inputs=[
87
  input_video,
 
107
  ],
108
  outputs=result)
109
 
110
+ demo.queue(max_size=10).launch()
 
 
 
 
model.py CHANGED
@@ -15,7 +15,7 @@ if os.getenv('SYSTEM') == 'spaces':
15
  subprocess.call(shlex.split('pip uninstall -y opencv-python'))
16
  subprocess.call(shlex.split('pip uninstall -y opencv-python-headless'))
17
  subprocess.call(
18
- shlex.split('pip install opencv-python-headless==4.5.5.64'))
19
 
20
  import cv2
21
  import huggingface_hub
@@ -29,8 +29,6 @@ from mmdet.apis import inference_detector, init_detector
29
  from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
30
  process_mmdet_results, vis_pose_result)
31
 
32
- HF_TOKEN = os.getenv('HF_TOKEN')
33
-
34
 
35
  class DetModel:
36
  MODEL_DICT = {
@@ -72,8 +70,8 @@ class DetModel:
72
  self._load_model(name)
73
 
74
  def _load_model(self, name: str) -> nn.Module:
75
- dic = self.MODEL_DICT[name]
76
- return init_detector(dic['config'], dic['model'], device=self.device)
77
 
78
  def set_model(self, name: str) -> None:
79
  if name == self.model_name:
@@ -145,11 +143,10 @@ class PoseModel:
145
  self._load_model(name)
146
 
147
  def _load_model(self, name: str) -> nn.Module:
148
- dic = self.MODEL_DICT[name]
149
- ckpt_path = huggingface_hub.hf_hub_download('hysts/ViTPose',
150
- dic['model'],
151
- use_auth_token=HF_TOKEN)
152
- model = init_pose_model(dic['config'], ckpt_path, device=self.device)
153
  return model
154
 
155
  def set_model(self, name: str) -> None:
 
15
  subprocess.call(shlex.split('pip uninstall -y opencv-python'))
16
  subprocess.call(shlex.split('pip uninstall -y opencv-python-headless'))
17
  subprocess.call(
18
+ shlex.split('pip install opencv-python-headless==4.8.0.74'))
19
 
20
  import cv2
21
  import huggingface_hub
 
29
  from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
30
  process_mmdet_results, vis_pose_result)
31
 
 
 
32
 
33
  class DetModel:
34
  MODEL_DICT = {
 
70
  self._load_model(name)
71
 
72
  def _load_model(self, name: str) -> nn.Module:
73
+ d = self.MODEL_DICT[name]
74
+ return init_detector(d['config'], d['model'], device=self.device)
75
 
76
  def set_model(self, name: str) -> None:
77
  if name == self.model_name:
 
143
  self._load_model(name)
144
 
145
  def _load_model(self, name: str) -> nn.Module:
146
+ d = self.MODEL_DICT[name]
147
+ ckpt_path = huggingface_hub.hf_hub_download('public-data/ViTPose',
148
+ d['model'])
149
+ model = init_pose_model(d['config'], ckpt_path, device=self.device)
 
150
  return model
151
 
152
  def set_model(self, name: str) -> None:
requirements.txt CHANGED
@@ -2,7 +2,7 @@ mmcv-full==1.5.0
2
  mmdet==2.24.1
3
  mmpose==0.25.1
4
  numpy==1.23.5
5
- opencv-python-headless==4.5.5.64
6
  openmim==0.1.5
7
  timm==0.5.4
8
  torch==1.11.0
 
2
  mmdet==2.24.1
3
  mmpose==0.25.1
4
  numpy==1.23.5
5
+ opencv-python-headless==4.8.0.74
6
  openmim==0.1.5
7
  timm==0.5.4
8
  torch==1.11.0