pixel3dmm / app.py
alexnasa's picture
Create app.py
480e656 verified
raw
history blame
7.72 kB
import spaces
import os
import subprocess
import tempfile
import uuid
import glob
import shutil
import time
import gradio as gr
# Set environment variables
# Locations read back later in this file (and exported to the child processes
# it launches): the code base root, the preprocessing output directory and the
# tracking output directory.
os.environ.update(
    {
        "PIXEL3DMM_CODE_BASE": "./",
        "PIXEL3DMM_PREPROCESSED_DATA": "./proprocess_results",
        "PIXEL3DMM_TRACKING_OUTPUT": "./tracking_results",
    }
)
# Utility to stitch frames into a video
def make_video_from_frames(frames_dir, out_path, fps=15):
    """Encode the numbered image frames in ``frames_dir`` into a video file.

    Frames are handed to ffmpeg through a ``%05d.<ext>`` pattern, so they must
    be named with a zero-padded five-digit index (``00000.jpg``, ...). JPEG
    frames take precedence over PNG frames when both kinds are present.

    Args:
        frames_dir: Directory holding the ``.jpg`` or ``.png`` frames.
        out_path: Destination path of the encoded video.
        fps: Frame rate at which the frames are played back.

    Returns:
        ``out_path`` on success, or ``None`` when ``frames_dir`` is not a
        directory or holds no ``.jpg``/``.png`` files.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    if not os.path.isdir(frames_dir):
        return None

    # Pick the first extension that actually has frames on disk.
    for ext in ("jpg", "png"):
        if glob.glob(os.path.join(frames_dir, f"*.{ext}")):
            break
    else:
        return None

    pattern = os.path.join(frames_dir, f"%05d.{ext}")
    subprocess.run(
        [
            "ffmpeg", "-y",
            # Input option: without it ffmpeg reads image sequences at its
            # default 25 fps and the output "-r" then drops frames to reach
            # `fps`, changing both playback speed and video length.
            "-framerate", str(fps),
            "-i", pattern,
            "-r", str(fps),
            out_path,
        ],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return out_path
# Function to probe video for duration and frame rate
def get_video_info(video_path):
    """Probe the uploaded video and return updated slider configurations.

    The video is inspected with ``ffprobe``; the result drives two sliders:

    - seconds slider: ``maximum = int(duration)``, value capped at 3
    - fps slider: ``maximum = int(orig_fps)``, value capped at 15

    Args:
        video_path: Path of the uploaded video, or a falsy value when no
            video is present.

    Returns:
        A ``(seconds_cfg, fps_cfg)`` tuple of ``gr.update`` objects. When no
        video is given, or probing fails, the limits fall back to a 10 second
        duration and 30 fps.
    """
    import json  # only needed here; the file has no module-level json import

    fallback_duration, fallback_fps = 10, 30

    if not video_path:
        # Return default slider updates when no video is uploaded
        return (
            gr.update(maximum=10, value=3, step=1),
            gr.update(maximum=30, value=15, step=1),
        )

    # "-show_format" is required for the container-level duration fallback
    # below; with "-show_streams" alone the JSON has no "format" section.
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_format", "-show_streams", video_path,
    ]
    try:
        # Launching ffprobe is inside the try so a missing binary degrades to
        # the fallback values instead of crashing the upload callback.
        res = subprocess.run(cmd, capture_output=True, text=True)
        data = json.loads(res.stdout)
        stream = next(
            s for s in data.get("streams", []) if s.get("codec_type") == "video"
        )
        duration = float(
            stream.get("duration") or data.get("format", {}).get("duration", 0)
        )
        num, den = stream.get("r_frame_rate", "0/1").split("/")
        orig_fps = float(num) / float(den) if float(den) else fallback_fps
    except Exception:
        # Deliberate best effort: any probing/parsing problem means "unknown".
        duration, orig_fps = fallback_duration, fallback_fps

    # A non-positive (or NaN) reading means the value is unknown; using it
    # would produce a slider whose maximum is 0.
    if not duration > 0:
        duration = fallback_duration
    if not orig_fps > 0:
        orig_fps = fallback_fps

    # Configure sliders based on actual video properties
    seconds_cfg = gr.update(maximum=int(duration), value=min(int(duration), 3), step=1)
    fps_cfg = gr.update(maximum=int(orig_fps), value=min(int(orig_fps), 15), step=1)
    return seconds_cfg, fps_cfg
# Step 1: Trim video based on user-defined duration and fps
@spaces.GPU()
def step1_trim(video_path, seconds, fps, state):
    """Trim the uploaded video to the requested duration and frame rate.

    A fresh session id is generated and stored in ``state`` together with the
    session's preprocessing directory and the path of the trimmed clip.

    Args:
        video_path: Path of the uploaded video.
        seconds: Number of seconds to keep from the start of the video.
        fps: Frame rate of the trimmed clip.
        state: Per-session dict; updated in place with ``session_id``,
            ``base_dir`` and ``trimmed_path``.

    Returns:
        A ``(status_message, state)`` tuple.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    state.update({"session_id": session_id, "base_dir": base_dir})

    # The temporary directory is intentionally not removed here: the trimmed
    # clip's path is kept in `state` for use after this function returns.
    tmp = tempfile.mkdtemp()
    trimmed = os.path.join(tmp, f"{session_id}.mp4")
    subprocess.run(
        [
            "ffmpeg", "-y", "-i", video_path,
            "-t", str(seconds),  # user-specified duration
            "-r", str(fps),      # user-specified fps
            trimmed,
        ],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    state["trimmed_path"] = trimmed
    return f"✅ Step 1: Trimmed to {seconds}s @{fps}fps", state
# Step 2: Preprocessing → cropped video
@spaces.GPU()
def step2_preprocess(state):
    """Run the preprocessing script on the trimmed clip and preview the crop.

    Args:
        state: Per-session dict; must contain ``session_id``, ``base_dir`` and
            ``trimmed_path``.

    Returns:
        A ``(status_message, video_path_or_None, state)`` tuple. The video is
        stitched from the frames in ``<base_dir>/cropped`` (presumably written
        by ``scripts/run_preprocessing.py`` -- confirm against that script) and
        is ``None`` when no frames are found there.

    Raises:
        subprocess.CalledProcessError: If the preprocessing script, or the
            ffmpeg call that stitches the preview, fails.
    """
    session_id = state["session_id"]
    base_dir = state["base_dir"]
    trimmed = state["trimmed_path"]

    subprocess.run(
        [
            "python", "scripts/run_preprocessing.py",
            "--video_or_images_path", trimmed,
        ],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )

    crop_dir = os.path.join(base_dir, "cropped")
    # The preview is written next to the trimmed clip.
    out = os.path.join(os.path.dirname(trimmed), f"crop_{session_id}.mp4")
    video = make_video_from_frames(crop_dir, out)
    return "✅ Step 2: Preprocessing complete", video, state
# Step 3: Normals inference → normals video
@spaces.GPU()
def step3_normals(state):
    """Run normals inference for the session and preview the result.

    Args:
        state: Per-session dict; must contain ``session_id``, ``base_dir`` and
            ``trimmed_path``.

    Returns:
        A ``(status_message, video_path_or_None, state)`` tuple. The video is
        stitched from the frames in ``<base_dir>/p3dmm/normals`` (presumably
        written by ``scripts/network_inference.py`` -- confirm against that
        script) and is ``None`` when no frames are found there.

    Raises:
        subprocess.CalledProcessError: If the inference script, or the ffmpeg
            call that stitches the preview, fails.
    """
    session_id = state["session_id"]
    base_dir = state["base_dir"]

    subprocess.run(
        [
            "python", "scripts/network_inference.py",
            "model.prediction_type=normals", f"video_name={session_id}",
        ],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )

    normals_dir = os.path.join(base_dir, "p3dmm", "normals")
    # The preview is written next to the trimmed clip.
    out = os.path.join(
        os.path.dirname(state["trimmed_path"]), f"normals_{session_id}.mp4"
    )
    video = make_video_from_frames(normals_dir, out)
    return "✅ Step 3: Normals inference complete", video, state
# Step 4: UV map inference → uv map video
@spaces.GPU()
def step4_uv_map(state):
    """Run UV-map inference for the session and preview the result.

    Args:
        state: Per-session dict; must contain ``session_id``, ``base_dir`` and
            ``trimmed_path``.

    Returns:
        A ``(status_message, video_path_or_None, state)`` tuple. The video is
        stitched from the frames in ``<base_dir>/p3dmm/uv_map`` (presumably
        written by ``scripts/network_inference.py`` -- confirm against that
        script) and is ``None`` when no frames are found there.

    Raises:
        subprocess.CalledProcessError: If the inference script, or the ffmpeg
            call that stitches the preview, fails.
    """
    session_id = state["session_id"]
    base_dir = state["base_dir"]

    subprocess.run(
        [
            "python", "scripts/network_inference.py",
            "model.prediction_type=uv_map", f"video_name={session_id}",
        ],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )

    uv_dir = os.path.join(base_dir, "p3dmm", "uv_map")
    # The preview is written next to the trimmed clip.
    out = os.path.join(
        os.path.dirname(state["trimmed_path"]), f"uv_map_{session_id}.mp4"
    )
    video = make_video_from_frames(uv_dir, out)
    return "✅ Step 4: UV map inference complete", video, state
# Step 5: Tracking → final tracking video
@spaces.GPU()
def step5_track(state):
    """Run the tracking script for the session and preview the result.

    Unlike a raised error, a tracking failure is reported through the status
    message so the captured script output is visible to the user.

    Args:
        state: Per-session dict; must contain ``session_id`` and
            ``trimmed_path``.

    Returns:
        A ``(status_message, video_path_or_None, state)`` tuple. On failure
        the message carries the exit code and the script's combined
        stdout/stderr, and the video is ``None``. On success the video is
        stitched from ``<PIXEL3DMM_TRACKING_OUTPUT>/<session_id>/frames`` and
        is ``None`` when no frames are found there.

    Raises:
        subprocess.CalledProcessError: If the ffmpeg call that stitches the
            preview fails (tracking-script failures are not raised).
    """
    session_id = state["session_id"]
    script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
    cmd = ["python", script, f"video_name={session_id}"]

    try:
        # stderr is merged into stdout so one stream carries all the output
        subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything
        err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    # if we get here, it succeeded:
    tracking_dir = os.path.join(
        os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames"
    )
    out = os.path.join(
        os.path.dirname(state["trimmed_path"]), f"result_{session_id}.mp4"
    )
    video = make_video_from_frames(tracking_dir, out)
    return "✅ Step 5: Tracking complete", video, state
# Build Gradio UI
# NOTE(review): the original indentation was lost, so the nesting of the
# layout below (which widgets sit in which row/column) is a reconstruction --
# confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Video Processing Pipeline")

    with gr.Row():
        # Left column: input video, trimming controls and status readout.
        with gr.Column():
            video_in = gr.Video(label="Upload video", height=512)
            seconds_slider = gr.Slider(
                label="Duration (seconds)", minimum=2, maximum=10, step=1, value=3
            )
            fps_slider = gr.Slider(
                label="Frame Rate (fps)", minimum=15, maximum=30, step=1, value=15
            )
            status = gr.Textbox(label="Status", lines=2, interactive=False)
            state = gr.State({})

        # Right column: 2x2 grid of per-step preview videos.
        with gr.Column():
            with gr.Row():
                crop_vid = gr.Video(label="Preprocessed", height=256)
                normals_vid = gr.Video(label="Normals", height=256)
            with gr.Row():
                uv_vid = gr.Video(label="UV Map", height=256)
                track_vid = gr.Video(label="Tracking", height=256)

    run_btn = gr.Button("Run Pipeline")

    # Re-derive the slider limits whenever the uploaded video changes.
    video_in.change(
        fn=get_video_info,
        inputs=video_in,
        outputs=[seconds_slider, fps_slider],
    )

    # Run the five steps back to back; each later step reads the session
    # state and fills in its own preview video.
    pipeline = run_btn.click(
        fn=step1_trim,
        inputs=[video_in, seconds_slider, fps_slider, state],
        outputs=[status, state],
    )
    for step_fn, preview in (
        (step2_preprocess, crop_vid),
        (step3_normals, normals_vid),
        (step4_uv_map, uv_vid),
        (step5_track, track_vid),
    ):
        pipeline = pipeline.then(
            fn=step_fn, inputs=[state], outputs=[status, preview, state]
        )
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
# Enable request queuing, then start the server with a public share link.
demo.queue().launch(share=True)