hysts (HF staff) committed
Commit 5c22f66
1 parent: afea94c
Files changed (3)
  1. .gitignore +1 -0
  2. app.py +143 -0
  3. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ images
app.py ADDED
@@ -0,0 +1,143 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import pathlib
+ import subprocess
+ import tarfile
+
+ if os.environ.get('SYSTEM') == 'spaces':
+     subprocess.call('pip uninstall -y opencv-python'.split())
+     subprocess.call('pip uninstall -y opencv-python-headless'.split())
+     subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())
+
+ import gradio as gr
+ import huggingface_hub
+ import mediapipe as mp
+ import numpy as np
+
+ mp_drawing = mp.solutions.drawing_utils
+ mp_drawing_styles = mp.solutions.drawing_styles
+ mp_pose = mp.solutions.pose
+
+ TITLE = 'MediaPipe Human Pose Estimation'
+ DESCRIPTION = 'https://google.github.io/mediapipe/'
+ ARTICLE = None
+
+ TOKEN = os.environ['TOKEN']
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--live', action='store_true')
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     parser.add_argument('--allow-flagging', type=str, default='never')
+     parser.add_argument('--allow-screenshot', action='store_true')
+     return parser.parse_args()
+
+
+ def load_sample_images() -> list[pathlib.Path]:
+     image_dir = pathlib.Path('images')
+     if not image_dir.exists():
+         image_dir.mkdir()
+         dataset_repo = 'hysts/input-images'
+         filenames = ['002.tar']
+         for name in filenames:
+             path = huggingface_hub.hf_hub_download(dataset_repo,
+                                                    name,
+                                                    repo_type='dataset',
+                                                    use_auth_token=TOKEN)
+             with tarfile.open(path) as f:
+                 f.extractall(image_dir.as_posix())
+     return sorted(image_dir.rglob('*.jpg'))
+
+
+ def run(image: np.ndarray, model_complexity: int, enable_segmentation: bool,
+         min_detection_confidence: float, background_color: str) -> np.ndarray:
+     with mp_pose.Pose(
+             static_image_mode=True,
+             model_complexity=model_complexity,
+             enable_segmentation=enable_segmentation,
+             min_detection_confidence=min_detection_confidence) as pose:
+         results = pose.process(image)
+
+     res = image[:, :, ::-1].copy()
+     if enable_segmentation:
+         if background_color == 'white':
+             bg_color = 255
+         elif background_color == 'black':
+             bg_color = 0
+         elif background_color == 'green':
+             bg_color = (0, 255, 0)
+         else:
+             raise ValueError
+
+         if results.segmentation_mask is not None:
+             res[results.segmentation_mask <= 0.1] = bg_color
+         else:
+             res[:] = bg_color
+
+     mp_drawing.draw_landmarks(res,
+                               results.pose_landmarks,
+                               mp_pose.POSE_CONNECTIONS,
+                               landmark_drawing_spec=mp_drawing_styles.
+                               get_default_pose_landmarks_style())
+
+     return res[:, :, ::-1]
+
+
+ def main():
+     args = parse_args()
+
+     model_complexities = list(range(3))
+     background_colors = ['white', 'black', 'green']
+
+     image_paths = load_sample_images()
+     examples = [[
+         path.as_posix(), model_complexities[1], True, 0.5, background_colors[0]
+     ] for path in image_paths]
+
+     gr.Interface(
+         run,
+         [
+             gr.inputs.Image(type='numpy', label='Input'),
+             gr.inputs.Radio(model_complexities,
+                             type='index',
+                             default=model_complexities[1],
+                             label='Model Complexity'),
+             gr.inputs.Checkbox(default=True, label='Enable Segmentation'),
+             gr.inputs.Slider(0,
+                              1,
+                              step=0.05,
+                              default=0.5,
+                              label='Minimum Detection Confidence'),
+             gr.inputs.Radio(background_colors,
+                             type='value',
+                             default=background_colors[0],
+                             label='Background Color'),
+         ],
+         gr.outputs.Image(type='numpy', label='Output'),
+         examples=examples,
+         title=TITLE,
+         description=DESCRIPTION,
+         article=ARTICLE,
+         theme=args.theme,
+         allow_screenshot=args.allow_screenshot,
+         allow_flagging=args.allow_flagging,
+         live=args.live,
+     ).launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
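
The core of run() is a single MediaPipe Pose pass followed by landmark drawing. A minimal self-contained sketch of the same calls, without the Gradio wrapper, might look like this; sample.jpg and output.jpg are placeholder paths, and the versions assumed are the ones pinned in requirements.txt:

# A sketch of the MediaPipe Pose calls used in run(), outside the Gradio app.
# 'sample.jpg' is a placeholder input path, not a file shipped with this Space.
import cv2
import mediapipe as mp

mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles

image_bgr = cv2.imread('sample.jpg')  # OpenCV reads images as BGR
with mp_pose.Pose(static_image_mode=True,
                  model_complexity=1,
                  enable_segmentation=True,
                  min_detection_confidence=0.5) as pose:
    # MediaPipe expects RGB input, so convert before processing.
    results = pose.process(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

annotated = image_bgr.copy()
mp_drawing.draw_landmarks(
    annotated,
    results.pose_landmarks,
    mp_pose.POSE_CONNECTIONS,
    landmark_drawing_spec=mp_drawing_styles.get_default_pose_landmarks_style())
cv2.imwrite('output.jpg', annotated)

In the app itself, Gradio hands run() an RGB array, which is why the function flips the channels to BGR for drawing and flips them back before returning.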
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ mediapipe==0.8.9.1
+ numpy==1.22.3
+ opencv-python-headless==4.5.5.64
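
The example thumbnails are not committed (the images directory is gitignored); load_sample_images() pulls 002.tar from the hysts/input-images dataset at startup using the TOKEN secret. A rough standalone sketch of that step, with HF_TOKEN as a placeholder for the secret, is:

# A standalone sketch of the download performed by load_sample_images().
# 'HF_TOKEN' is a placeholder for the Space secret that app.py reads as TOKEN.
import os
import pathlib
import tarfile

import huggingface_hub

image_dir = pathlib.Path('images')
image_dir.mkdir(exist_ok=True)

path = huggingface_hub.hf_hub_download('hysts/input-images',
                                       '002.tar',
                                       repo_type='dataset',
                                       use_auth_token=os.environ['HF_TOKEN'])
with tarfile.open(path) as f:
    f.extractall(image_dir.as_posix())

print(sorted(image_dir.rglob('*.jpg')))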