alexnasa commited on
Commit
480e656
Β·
verified Β·
1 Parent(s): 48dcbfb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import subprocess
4
+ import tempfile
5
+ import uuid
6
+ import glob
7
+ import shutil
8
+ import time
9
+
10
+ import gradio as gr
11
+
12
+ # Set environment variables
13
+ os.environ["PIXEL3DMM_CODE_BASE"] = "./"
14
+ os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = "./proprocess_results"
15
+ os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = "./tracking_results"
16
+
17
+ # Utility to stitch frames into a video
18
+ def make_video_from_frames(frames_dir, out_path, fps=15):
19
+ if not os.path.isdir(frames_dir):
20
+ return None
21
+ files = glob.glob(os.path.join(frames_dir, "*.jpg")) + glob.glob(os.path.join(frames_dir, "*.png"))
22
+ if not files:
23
+ return None
24
+ ext = files[0].split('.')[-1]
25
+ pattern = os.path.join(frames_dir, f"%05d.{ext}")
26
+ subprocess.run([
27
+ "ffmpeg", "-y", "-i", pattern,
28
+ "-r", str(fps), out_path
29
+ ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
30
+ return out_path
31
+
32
+ # Function to probe video for duration and frame rate
33
+ def get_video_info(video_path):
34
+ """
35
+ Probes the uploaded video and returns updated slider configs:
36
+ - seconds slider: max = int(duration)
37
+ - fps slider: max = int(orig_fps)
38
+ """
39
+ if not video_path:
40
+ # Return default slider updates when no video is uploaded
41
+ return gr.update(maximum=10, value=3, step=1), gr.update(maximum=30, value=15, step=1)
42
+
43
+ # Use ffprobe to get JSON metadata
44
+ cmd = [
45
+ "ffprobe", "-v", "quiet",
46
+ "-print_format", "json",
47
+ "-show_streams", video_path
48
+ ]
49
+ res = subprocess.run(cmd, capture_output=True, text=True)
50
+ try:
51
+ import json
52
+ data = json.loads(res.stdout)
53
+ stream = next(s for s in data.get('streams', []) if s.get('codec_type') == 'video')
54
+ duration = float(stream.get('duration') or data.get('format', {}).get('duration', 0))
55
+ fr = stream.get('r_frame_rate', '0/1')
56
+ num, den = fr.split('/')
57
+ orig_fps = float(num) / float(den) if float(den) else 30
58
+ except Exception:
59
+ duration, orig_fps = 10, 30
60
+
61
+ # Configure sliders based on actual video properties
62
+ seconds_cfg = gr.update(maximum=int(duration), value=min(int(duration), 3), step=1)
63
+ fps_cfg = gr.update(maximum=int(orig_fps), value=min(int(orig_fps), 15), step=1)
64
+ return seconds_cfg, fps_cfg
65
+
66
+ # Step 1: Trim video based on user-defined duration and fps based on user-defined duration and fps
67
+ @space.GPU()
68
+ def step1_trim(video_path, seconds, fps, state):
69
+ session_id = str(uuid.uuid4())
70
+ base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
71
+ state.update({"session_id": session_id, "base_dir": base_dir})
72
+
73
+ tmp = tempfile.mkdtemp()
74
+ trimmed = os.path.join(tmp, f"{session_id}.mp4")
75
+ subprocess.run([
76
+ "ffmpeg", "-y", "-i", video_path,
77
+ "-t", str(seconds), # user-specified duration
78
+ "-r", str(fps), # user-specified fps
79
+ trimmed
80
+ ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
81
+ state["trimmed_path"] = trimmed
82
+ return f"βœ… Step 1: Trimmed to {seconds}s @{fps}fps", state
83
+
84
+ # Step 2: Preprocessing β†’ cropped video
85
+ @space.GPU()
86
+ def step2_preprocess(state):
87
+ session_id = state["session_id"]
88
+ base_dir = state["base_dir"]
89
+ trimmed = state["trimmed_path"]
90
+
91
+ subprocess.run([
92
+ "python", "scripts/run_preprocessing.py",
93
+ "--video_or_images_path", trimmed
94
+ ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
95
+
96
+ crop_dir = os.path.join(base_dir, "cropped")
97
+ out = os.path.join(os.path.dirname(trimmed), f"crop_{session_id}.mp4")
98
+ video = make_video_from_frames(crop_dir, out)
99
+ return "βœ… Step 2: Preprocessing complete", video, state
100
+
101
+ # Step 3: Normals inference β†’ normals video
102
+ @space.GPU()
103
+ def step3_normals(state):
104
+ session_id = state["session_id"]
105
+ base_dir = state["base_dir"]
106
+
107
+ subprocess.run([
108
+ "python", "scripts/network_inference.py",
109
+ "model.prediction_type=normals", f"video_name={session_id}"
110
+ ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
111
+
112
+ normals_dir = os.path.join(base_dir, "p3dmm", "normals")
113
+ out = os.path.join(os.path.dirname(state["trimmed_path"]), f"normals_{session_id}.mp4")
114
+ video = make_video_from_frames(normals_dir, out)
115
+ return "βœ… Step 3: Normals inference complete", video, state
116
+
117
+ # Step 4: UV map inference β†’ uv map video
118
+ @space.GPU()
119
+ def step4_uv_map(state):
120
+ session_id = state["session_id"]
121
+ base_dir = state["base_dir"]
122
+
123
+ subprocess.run([
124
+ "python", "scripts/network_inference.py",
125
+ "model.prediction_type=uv_map", f"video_name={session_id}"
126
+ ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
127
+
128
+ uv_dir = os.path.join(base_dir, "p3dmm", "uv_map")
129
+ out = os.path.join(os.path.dirname(state["trimmed_path"]), f"uv_map_{session_id}.mp4")
130
+ video = make_video_from_frames(uv_dir, out)
131
+ return "βœ… Step 4: UV map inference complete", video, state
132
+
133
+ # Step 5: Tracking β†’ final tracking video
134
+ @space.GPU()
135
+ def step5_track(state):
136
+ session_id = state["session_id"]
137
+ script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
138
+ cmd = [
139
+ "python", script,
140
+ f"video_name={session_id}"
141
+ ]
142
+ try:
143
+ # capture both stdout & stderr
144
+ p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=True)
145
+ except subprocess.CalledProcessError as e:
146
+ # e.stdout contains everything
147
+ err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}"
148
+ return err, None, state
149
+
150
+ # if we get here, it succeeded:
151
+ tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
152
+ out = os.path.join(os.path.dirname(state["trimmed_path"]), f"result_{session_id}.mp4")
153
+ video = make_video_from_frames(tracking_dir, out)
154
+ return "βœ… Step 5: Tracking complete", video, state
155
+
156
+ # Build Gradio UI
157
+ demo = gr.Blocks()
158
+
159
+ with demo:
160
+ gr.Markdown("## Video Processing Pipeline")
161
+ with gr.Row():
162
+ with gr.Column():
163
+ video_in = gr.Video(label="Upload video", height=512)
164
+ # Sliders for duration and fps
165
+ seconds_slider = gr.Slider(label="Duration (seconds)", minimum=2, maximum=10, step=1, value=3)
166
+ fps_slider = gr.Slider(label="Frame Rate (fps)", minimum=15, maximum=30, step=1, value=15)
167
+ status = gr.Textbox(label="Status", lines=2, interactive=False)
168
+ state = gr.State({})
169
+ with gr.Column():
170
+ with gr.Row():
171
+ crop_vid = gr.Video(label="Preprocessed", height=256)
172
+ normals_vid = gr.Video(label="Normals", height=256)
173
+ with gr.Row():
174
+ uv_vid = gr.Video(label="UV Map", height=256)
175
+ track_vid = gr.Video(label="Tracking", height=256)
176
+ run_btn = gr.Button("Run Pipeline")
177
+
178
+ # Update sliders after video upload
179
+ video_in.change(fn=get_video_info, inputs=video_in, outputs=[seconds_slider, fps_slider])
180
+
181
+ # Pipeline execution
182
+ (run_btn.click(fn=step1_trim, inputs=[video_in, seconds_slider, fps_slider, state], outputs=[status, state])
183
+ .then(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state])
184
+ .then(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state])
185
+ .then(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state])
186
+ .then(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])
187
+ )
188
+
189
+ # ------------------------------------------------------------------
190
+ # START THE GRADIO SERVER
191
+ # ------------------------------------------------------------------
192
+ demo.queue()
193
+
194
+ demo.launch(share=True)
195
+