alexnasa commited on
Commit
e2cec11
Β·
verified Β·
1 Parent(s): 5dae933

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -164
app.py CHANGED
@@ -8,13 +8,14 @@ import shutil
8
  import time
9
  import gradio as gr
10
  import sys
11
-
12
 
13
  # Set environment variables
14
  os.environ["PIXEL3DMM_CODE_BASE"] = f"{os.getcwd()}"
15
  os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = f"{os.getcwd()}/proprocess_results"
16
  os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"
17
 
 
18
  def sh(cmd): subprocess.check_call(cmd, shell=True)
19
 
20
  # only do this once per VM restart
@@ -44,224 +45,149 @@ def install_cuda_toolkit():
44
  install_cuda_toolkit()
45
 
46
 
47
- # Utility to stitch frames into a video
48
- def make_video_from_frames(frames_dir, out_path, fps=15):
49
- if not os.path.isdir(frames_dir):
50
- return None
51
- files = glob.glob(os.path.join(frames_dir, "*.jpg")) + glob.glob(os.path.join(frames_dir, "*.png"))
 
52
  if not files:
53
  return None
54
- ext = files[0].split('.')[-1]
55
- pattern = os.path.join(frames_dir, f"%05d.{ext}")
56
- subprocess.run([
57
- "ffmpeg", "-y", "-i", pattern,
58
- "-r", str(fps), out_path
59
- ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
60
- return out_path
61
-
62
- # Function to probe video for duration and frame rate
63
- def get_video_info(video_path):
64
- """
65
- Probes the uploaded video and returns updated slider configs:
66
- - seconds slider: max = int(duration)
67
- - fps slider: max = int(orig_fps)
68
- """
69
- if not video_path:
70
- # Return default slider updates when no video is uploaded
71
- return gr.update(maximum=10, value=3, step=1), gr.update(maximum=30, value=15, step=1)
72
 
73
- # Use ffprobe to get JSON metadata
74
- cmd = [
75
- "ffprobe", "-v", "quiet",
76
- "-print_format", "json",
77
- "-show_streams", video_path
78
- ]
79
- res = subprocess.run(cmd, capture_output=True, text=True)
80
- try:
81
- import json
82
- data = json.loads(res.stdout)
83
- stream = next(s for s in data.get('streams', []) if s.get('codec_type') == 'video')
84
- duration = float(stream.get('duration') or data.get('format', {}).get('duration', 0))
85
- fr = stream.get('r_frame_rate', '0/1')
86
- num, den = fr.split('/')
87
- orig_fps = float(num) / float(den) if float(den) else 30
88
- except Exception:
89
- duration, orig_fps = 10, 30
90
-
91
- # Configure sliders based on actual video properties
92
- seconds_cfg = gr.update(maximum=int(duration), value=min(int(duration), 3), step=1)
93
- fps_cfg = gr.update(maximum=int(orig_fps), value=min(int(orig_fps), 15), step=1)
94
- return seconds_cfg, fps_cfg
95
 
96
- # Step 1: Trim video based on user-defined duration and fps based on user-defined duration and fps
97
- def step1_trim(video_path, seconds, fps, state):
98
  session_id = str(uuid.uuid4())
99
  base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
 
100
  state.update({"session_id": session_id, "base_dir": base_dir})
101
 
102
- tmp = tempfile.mkdtemp()
103
- trimmed = os.path.join(tmp, f"{session_id}.mp4")
 
 
104
 
 
105
  try:
106
- # capture both stdout & stderr
107
  p = subprocess.run([
108
- "ffmpeg", "-y", "-i", video_path,
109
- "-t", str(seconds), # user-specified duration
110
- "-r", str(fps), # user-specified fps
111
- trimmed
112
- ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
113
-
114
- all_output = []
115
-
116
- for line in p.stdout:
117
- print(line, end="") # real-time echo
118
- all_output.append(line)
119
-
120
  except subprocess.CalledProcessError as e:
121
- # e.stdout contains everything
122
- err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}"
123
- return err, None, state
124
-
125
-
126
- state["trimmed_path"] = trimmed
127
- return f"βœ… Step 1: Trimmed to {seconds}s @{fps}fps", state
128
-
129
- # Step 2: Preprocessing β†’ cropped video
130
- @spaces.GPU()
131
- def step2_preprocess(state):
132
- session_id = state["session_id"]
133
- base_dir = state["base_dir"]
134
- trimmed = state["trimmed_path"]
135
-
136
- try:
137
- # capture both stdout & stderr
138
- p = subprocess.run([
139
- "python", "scripts/run_preprocessing.py",
140
- "--video_or_images_path", trimmed
141
- ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
142
- except subprocess.CalledProcessError as e:
143
- # e.stdout contains everything
144
- err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}"
145
  return err, None, state
146
 
147
  crop_dir = os.path.join(base_dir, "cropped")
148
- out = os.path.join(os.path.dirname(trimmed), f"crop_{session_id}.mp4")
149
- video = make_video_from_frames(crop_dir, out)
150
- return "βœ… Step 2: Preprocessing complete", video, state
151
 
152
- # Step 3: Normals inference β†’ normals video
 
153
  @spaces.GPU()
154
- def step3_normals(state):
155
- session_id = state["session_id"]
156
- base_dir = state["base_dir"]
 
157
 
158
  try:
159
- # capture both stdout & stderr
160
  p = subprocess.run([
161
- "python", "scripts/network_inference.py",
162
- "model.prediction_type=normals", f"video_name={session_id}"
163
- ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
164
  except subprocess.CalledProcessError as e:
165
- # e.stdout contains everything
166
- err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}"
167
  return err, None, state
168
-
169
- normals_dir = os.path.join(base_dir, "p3dmm", "normals")
170
- out = os.path.join(os.path.dirname(state["trimmed_path"]), f"normals_{session_id}.mp4")
171
- video = make_video_from_frames(normals_dir, out)
172
- return "βœ… Step 3: Normals inference complete", video, state
173
 
174
- # Step 4: UV map inference β†’ uv map video
 
 
 
 
175
  @spaces.GPU()
176
- def step4_uv_map(state):
177
- session_id = state["session_id"]
178
- base_dir = state["base_dir"]
 
179
 
180
  try:
181
- # capture both stdout & stderr
182
  p = subprocess.run([
183
- "python", "scripts/network_inference.py",
184
- "model.prediction_type=uv_map", f"video_name={session_id}"
185
- ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
186
  except subprocess.CalledProcessError as e:
187
- # e.stdout contains everything
188
- err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}"
189
  return err, None, state
190
 
191
- uv_dir = os.path.join(base_dir, "p3dmm", "uv_map")
192
- out = os.path.join(os.path.dirname(state["trimmed_path"]), f"uv_map_{session_id}.mp4")
193
- video = make_video_from_frames(uv_dir, out)
194
- return "βœ… Step 4: UV map inference complete", video, state
195
 
196
- # Step 5: Tracking β†’ final tracking video
197
  @spaces.GPU()
198
- def step5_track(state):
199
- session_id = state["session_id"]
 
 
 
200
  script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
201
- cmd = [
202
- "python", script,
203
- f"video_name={session_id}"
204
- ]
205
  try:
206
- # capture both stdout & stderr
207
- p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=True)
 
 
 
208
  except subprocess.CalledProcessError as e:
209
- # e.stdout contains everything
210
- err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}"
211
  return err, None, state
212
 
213
- # if we get here, it succeeded:
214
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
215
- out = os.path.join(os.path.dirname(state["trimmed_path"]), f"result_{session_id}.mp4")
216
- video = make_video_from_frames(tracking_dir, out)
217
- return "βœ… Step 5: Tracking complete", video, state
218
 
219
  # Build Gradio UI
220
  demo = gr.Blocks()
221
 
222
  with demo:
223
- gr.Markdown("## Video Processing Pipeline")
224
  with gr.Row():
225
  with gr.Column():
226
- video_in = gr.Video(label="Upload video", height=512)
227
- # Sliders for duration and fps
228
- seconds_slider = gr.Slider(label="Duration (seconds)", minimum=2, maximum=10, step=1, value=3)
229
- fps_slider = gr.Slider(label="Frame Rate (fps)", minimum=15, maximum=30, step=1, value=15)
230
- status = gr.Textbox(label="Status", lines=2, interactive=False)
231
- state = gr.State({})
232
  with gr.Column():
233
  with gr.Row():
234
- crop_vid = gr.Video(label="Preprocessed", height=256)
235
- normals_vid = gr.Video(label="Normals", height=256)
236
  with gr.Row():
237
- uv_vid = gr.Video(label="UV Map", height=256)
238
- track_vid = gr.Video(label="Tracking", height=256)
239
- run_btn_1 = gr.Button("Run Pipeline 1")
240
- run_btn_2 = gr.Button("Run Pipeline 2")
241
- run_btn_3 = gr.Button("Run Pipeline 3")
242
- run_btn_4 = gr.Button("Run Pipeline 4")
243
- run_btn_5 = gr.Button("Run Pipeline 5")
244
 
245
- # Update sliders after video upload
246
- video_in.change(fn=get_video_info, inputs=video_in, outputs=[seconds_slider, fps_slider])
 
 
 
247
 
248
  # Pipeline execution
249
- run_btn_1.click(fn=step1_trim, inputs=[video_in, seconds_slider, fps_slider, state], outputs=[status, state])
250
- run_btn_2.click(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state])
251
- run_btn_3.click(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state])
252
- run_btn_4.click(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state])
253
- run_btn_5.click(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])
254
-
255
- # .then(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state])
256
- # .then(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state])
257
- # .then(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state])
258
- # .then(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])
259
-
260
 
261
  # ------------------------------------------------------------------
262
  # START THE GRADIO SERVER
263
  # ------------------------------------------------------------------
264
  demo.queue()
265
-
266
  demo.launch(share=True, ssr_mode=False)
267
 
 
8
  import time
9
  import gradio as gr
10
  import sys
11
+ from PIL import Image
12
 
13
  # Set environment variables
14
  os.environ["PIXEL3DMM_CODE_BASE"] = f"{os.getcwd()}"
15
  os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = f"{os.getcwd()}/proprocess_results"
16
  os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"
17
 
18
+
19
  def sh(cmd): subprocess.check_call(cmd, shell=True)
20
 
21
  # only do this once per VM restart
 
45
  install_cuda_toolkit()
46
 
47
 
48
+ # Utility to select first image from a folder
49
+ def first_image_from_dir(directory):
50
+ patterns = ["*.jpg", "*.png", "*.jpeg"]
51
+ files = []
52
+ for p in patterns:
53
+ files.extend(glob.glob(os.path.join(directory, p)))
54
  if not files:
55
  return None
56
+ return sorted(files)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ # Step 1: Preprocess the input image (Save and Crop)
59
+ @spaces.GPU()
60
+ def preprocess_image(image_array, state):
61
+ # Check if an image was uploaded
62
+ if image_array is None:
63
+ return "❌ Please upload an image first.", None, state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ # Step 1a: Save the uploaded image
 
66
  session_id = str(uuid.uuid4())
67
  base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
68
+ os.makedirs(base_dir, exist_ok=True)
69
  state.update({"session_id": session_id, "base_dir": base_dir})
70
 
71
+ img = Image.fromarray(image_array)
72
+ saved_image_path = os.path.join(base_dir, f"{session_id}.png")
73
+ img.save(saved_image_path)
74
+ state["image_path"] = saved_image_path
75
 
76
+ # Step 1b: Run the preprocessing script
77
  try:
 
78
  p = subprocess.run([
79
+ "python", "scripts/run_preprocessing.py",
80
+ "--video_or_images_path", saved_image_path
81
+ ], check=True, capture_output=True, text=True)
 
 
 
 
 
 
 
 
 
82
  except subprocess.CalledProcessError as e:
83
+ err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
84
+ # Clean up created directory on failure
85
+ shutil.rmtree(base_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  return err, None, state
87
 
88
  crop_dir = os.path.join(base_dir, "cropped")
89
+ image = first_image_from_dir(crop_dir)
90
+ return "βœ… Preprocessing complete", image, state
 
91
 
92
+
93
+ # Step 2: Normals inference β†’ normals image
94
  @spaces.GPU()
95
+ def step2_normals(state):
96
+ session_id = state.get("session_id")
97
+ if not session_id:
98
+ return "❌ Please preprocess an image first.", None, state
99
 
100
  try:
101
+ # Execute the network inference for normals
102
  p = subprocess.run([
103
+ "python", "scripts/network_inference.py",
104
+ "model.prediction_type=normals", f"video_name={session_id}"
105
+ ], check=True, capture_output=True, text=True)
106
  except subprocess.CalledProcessError as e:
107
+ err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
 
108
  return err, None, state
 
 
 
 
 
109
 
110
+ normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
111
+ image = first_image_from_dir(normals_dir)
112
+ return "βœ… Step 2: Normals inference complete", image, state
113
+
114
+ # Step 3: UV map inference β†’ uv map image
115
  @spaces.GPU()
116
+ def step3_uv_map(state):
117
+ session_id = state.get("session_id")
118
+ if not session_id:
119
+ return "❌ Please preprocess an image first.", None, state
120
 
121
  try:
122
+ # Execute the network inference for UV map
123
  p = subprocess.run([
124
+ "python", "scripts/network_inference.py",
125
+ "model.prediction_type=uv_map", f"video_name={session_id}"
126
+ ], check=True, capture_output=True, text=True)
127
  except subprocess.CalledProcessError as e:
128
+ err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
 
129
  return err, None, state
130
 
131
+ uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
132
+ image = first_image_from_dir(uv_dir)
133
+ return "βœ… Step 3: UV map inference complete", image, state
 
134
 
135
+ # Step 4: Tracking β†’ final tracking image
136
  @spaces.GPU()
137
+ def step4_track(state):
138
+ session_id = state.get("session_id")
139
+ if not session_id:
140
+ return "❌ Please preprocess an image first.", None, state
141
+
142
  script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
 
 
 
 
143
  try:
144
+ # Execute the tracking script
145
+ p = subprocess.run([
146
+ "python", script,
147
+ f"video_name={session_id}"
148
+ ], check=True, capture_output=True, text=True)
149
  except subprocess.CalledProcessError as e:
150
+ err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
 
151
  return err, None, state
152
 
 
153
  tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
154
+ image = first_image_from_dir(tracking_dir)
155
+ return "βœ… Step 4: Tracking complete", image, state
 
156
 
157
  # Build Gradio UI
158
  demo = gr.Blocks()
159
 
160
  with demo:
161
+ gr.Markdown("## Image Processing Pipeline")
162
  with gr.Row():
163
  with gr.Column():
164
+ image_in = gr.Image(label="Upload Image", type="numpy", height=512)
165
+ status = gr.Textbox(label="Status", lines=2, interactive=False)
166
+ state = gr.State({})
 
 
 
167
  with gr.Column():
168
  with gr.Row():
169
+ crop_img = gr.Image(label="Preprocessed", height=256)
170
+ normals_img = gr.Image(label="Normals", height=256)
171
  with gr.Row():
172
+ uv_img = gr.Image(label="UV Map", height=256)
173
+ track_img = gr.Image(label="Tracking", height=256)
 
 
 
 
 
174
 
175
+ with gr.Row():
176
+ preprocess_btn = gr.Button("Step 1: Preprocess")
177
+ normals_btn = gr.Button("Step 2: Normals")
178
+ uv_map_btn = gr.Button("Step 3: UV Map")
179
+ track_btn = gr.Button("Step 4: Track")
180
 
181
  # Pipeline execution
182
+ preprocess_btn.click(fn=preprocess_image, inputs=[image_in, state], outputs=[status, crop_img, state])
183
+ normals_btn.click(fn=step2_normals, inputs=[state], outputs=[status, normals_img, state])
184
+ uv_map_btn.click(fn=step3_uv_map, inputs=[state], outputs=[status, uv_img, state])
185
+ track_btn.click(fn=step4_track, inputs=[state], outputs=[status, track_img, state])
186
+
 
 
 
 
 
 
187
 
188
  # ------------------------------------------------------------------
189
  # START THE GRADIO SERVER
190
  # ------------------------------------------------------------------
191
  demo.queue()
 
192
  demo.launch(share=True, ssr_mode=False)
193