# Spaces: Running on Zero
import spaces | |
import os | |
import subprocess | |
import tempfile | |
import uuid | |
import glob | |
import shutil | |
import time | |
import gradio as gr | |
import sys | |
# Export the PIXEL3DMM_* locations, all rooted at the current working directory.
# The step functions below read these back from os.environ.
os.environ.update(
    {
        "PIXEL3DMM_CODE_BASE": os.getcwd(),
        "PIXEL3DMM_PREPROCESSED_DATA": f"{os.getcwd()}/proprocess_results",
        "PIXEL3DMM_TRACKING_OUTPUT": f"{os.getcwd()}/tracking_results",
    }
)
def sh(cmd):
    """Run *cmd* through the shell.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    subprocess.run(cmd, shell=True, check=True)
# Only needed once per VM restart: editable installs of this repo and of
# `facer`, followed by running FaceBoxesV2's make.sh.
for _setup_cmd in (
    "pip install -e .",
    "cd src/pixel3dmm/preprocessing/facer && pip install -e .",
    "cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh",
):
    sh(_setup_cmd)
def install_cuda_toolkit():
    """Download and run the CUDA 12.1 toolkit installer, then export the CUDA environment.

    The installer is fetched with wget into /tmp, made executable and run with
    ``--silent --toolkit``. The exit codes of those three commands are not
    checked. Afterwards CUDA_HOME, PATH, LD_LIBRARY_PATH and
    TORCH_CUDA_ARCH_LIST are set in ``os.environ``.
    """
    toolkit_url = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    installer = f"/tmp/{os.path.basename(toolkit_url)}"

    # Fetch, mark executable, run -- in that order.
    for argv in (
        ["wget", "-q", toolkit_url, "-O", installer],
        ["chmod", "+x", installer],
        [installer, "--silent", "--toolkit"],
    ):
        subprocess.call(argv)

    cuda_home = "/usr/local/cuda"
    previous_path = os.environ["PATH"]
    previous_ld_path = os.environ.get("LD_LIBRARY_PATH", "")

    os.environ["CUDA_HOME"] = cuda_home
    os.environ["PATH"] = f"{cuda_home}/bin:{previous_path}"
    os.environ["LD_LIBRARY_PATH"] = f"{cuda_home}/lib:{previous_ld_path}"
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
    print("==> finished installation")
# Run the installer at import time so CUDA_HOME, PATH, LD_LIBRARY_PATH and
# TORCH_CUDA_ARCH_LIST are set before anything below executes.
install_cuda_toolkit()
# Utility to stitch frames into a video | |
def make_video_from_frames(frames_dir, out_path, fps=15):
    """Encode the numbered frames found in *frames_dir* into a video at *out_path*.

    Frames are passed to ffmpeg as the pattern ``%05d.<ext>``, where ``<ext>``
    is taken from the first ``.jpg``/``.png`` file found in the directory.

    Args:
        frames_dir: Directory holding the frame images.
        out_path: Destination video file (overwritten if present).
        fps: Value passed to ffmpeg's ``-r`` output option.

    Returns:
        *out_path*, or ``None`` when *frames_dir* is not a directory or
        contains no ``.jpg``/``.png`` files.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    if not os.path.isdir(frames_dir):
        return None

    frame_files = [
        path
        for suffix in ("*.jpg", "*.png")
        for path in glob.glob(os.path.join(frames_dir, suffix))
    ]
    if not frame_files:
        return None

    extension = frame_files[0].rsplit(".", 1)[-1]
    frame_pattern = os.path.join(frames_dir, f"%05d.{extension}")
    ffmpeg_cmd = ["ffmpeg", "-y", "-i", frame_pattern, "-r", str(fps), out_path]
    subprocess.run(
        ffmpeg_cmd,
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return out_path
# Function to probe video for duration and frame rate | |
def get_video_info(video_path):
    """
    Probe the uploaded video and return updated slider configs.

    - seconds slider: max = int(duration), but never below the slider minimum
    - fps slider: max = int(orig_fps), but never below the slider minimum

    Args:
        video_path: Path of the uploaded video, or a falsy value when the
            upload was cleared.

    Returns:
        A ``(seconds_cfg, fps_cfg)`` pair of ``gr.update(...)`` objects for
        the duration and frame-rate sliders.

    Any failure to probe the file (ffprobe missing, unparsable output, no
    video stream) falls back to a 10 s / 30 fps assumption.
    """
    import json

    # These mirror the `minimum=` values of the duration and fps sliders built
    # in the UI section of this file; a maximum below them would leave the
    # slider with an invalid range.
    min_seconds, min_fps = 2, 15

    if not video_path:
        # Return default slider updates when no video is uploaded
        return gr.update(maximum=10, value=3, step=1), gr.update(maximum=30, value=15, step=1)

    # Use ffprobe to get JSON metadata. `-show_format` is required for the
    # container-level duration fallback below: without it the JSON has no
    # "format" section at all.
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_format", "-show_streams", video_path,
    ]
    try:
        res = subprocess.run(cmd, capture_output=True, text=True)
        data = json.loads(res.stdout)
        stream = next(s for s in data.get("streams", []) if s.get("codec_type") == "video")
        # Prefer the stream duration; fall back to the container duration.
        duration = float(stream.get("duration") or data.get("format", {}).get("duration", 0))
        num, den = stream.get("r_frame_rate", "0/1").split("/")
        orig_fps = float(num) / float(den) if float(den) else 30
    except (OSError, ValueError, StopIteration, TypeError, AttributeError):
        # OSError: ffprobe not runnable; ValueError: bad JSON / bad numbers;
        # StopIteration: no video stream; Type/AttributeError: unexpected shape.
        duration, orig_fps = 10, 30

    # Configure sliders based on actual video properties
    max_seconds = max(int(duration), min_seconds)
    max_fps = max(int(orig_fps), min_fps)
    seconds_cfg = gr.update(maximum=max_seconds, value=min(max_seconds, 3), step=1)
    fps_cfg = gr.update(maximum=max_fps, value=min(max_fps, 15), step=1)
    return seconds_cfg, fps_cfg
# Step 1: Trim video based on user-defined duration and fps
def step1_trim(video_path, seconds, fps, state):
    """Trim the uploaded video to *seconds* at *fps* and start a new session.

    A fresh session id is generated and the trimmed clip is written to a new
    temporary directory as ``<session_id>.mp4``. ``state`` is only updated
    (``session_id``, ``base_dir``, ``trimmed_path``) once ffmpeg has
    succeeded, so a failed run never leaves a half-initialised session behind.

    Args:
        video_path: Path of the uploaded video (falsy when nothing is uploaded).
        seconds: Duration passed to ffmpeg's ``-t``.
        fps: Frame rate passed to ffmpeg's ``-r``.
        state: Per-session dict; mutated in place on success.

    Returns:
        ``(status_message, state)`` -- always two values, matching the two
        outputs (`status`, `state`) this callback is wired to.
    """
    if not video_path:
        return "Step 1 skipped: upload a video first.", state

    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    tmp = tempfile.mkdtemp()
    trimmed = os.path.join(tmp, f"{session_id}.mp4")
    try:
        # capture both stdout & stderr, decoded as text so they can be shown
        p = subprocess.run([
            "ffmpeg", "-y", "-i", video_path,
            "-t", str(seconds),  # user-specified duration
            "-r", str(fps),      # user-specified fps
            trimmed
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # Nothing usable was produced: drop the temp dir instead of leaking it.
        shutil.rmtree(tmp, ignore_errors=True)
        # e.stdout contains everything (stderr is merged into it).
        # NOTE(review): the leading 'β' looks like a mis-encoded status glyph -- confirm.
        err = f"β Trim failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, state

    # Echo ffmpeg's log to the server console (available only after it exits).
    print(p.stdout, end="")

    state.update({
        "session_id": session_id,
        "base_dir": base_dir,
        "trimmed_path": trimmed,
    })
    return f"β Step 1: Trimmed to {seconds}s @{fps}fps", state
# Step 2: Preprocessing → cropped video
def step2_preprocess(state):
    """Run ``scripts/run_preprocessing.py`` on the trimmed clip and build the cropped video.

    Args:
        state: Per-session dict; must contain ``session_id``, ``base_dir`` and
            ``trimmed_path`` (written by Step 1).

    Returns:
        ``(status_message, video_path_or_None, state)``. The video is ``None``
        when Step 1 has not run, when the script fails, or when no frames are
        found under ``<base_dir>/cropped``.
    """
    # The buttons can be clicked in any order; without Step 1 the state is empty.
    if any(key not in state for key in ("session_id", "base_dir", "trimmed_path")):
        return "Step 2 skipped: run Step 1 first.", None, state

    session_id = state["session_id"]
    base_dir = state["base_dir"]
    trimmed = state["trimmed_path"]
    try:
        # capture both stdout & stderr, decoded as text so the error is readable
        subprocess.run([
            "python", "scripts/run_preprocessing.py",
            "--video_or_images_path", trimmed
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything (stderr is merged into it).
        # NOTE(review): the leading 'β' looks like a mis-encoded status glyph -- confirm.
        err = f"β Preprocess failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    crop_dir = os.path.join(base_dir, "cropped")
    out = os.path.join(os.path.dirname(trimmed), f"crop_{session_id}.mp4")
    video = make_video_from_frames(crop_dir, out)
    return "β Step 2: Preprocessing complete", video, state
# Step 3: Normals inference → normals video
def step3_normals(state):
    """Run ``scripts/network_inference.py`` for normals and build the normals video.

    Args:
        state: Per-session dict; must contain ``session_id``, ``base_dir`` and
            ``trimmed_path`` (written by Step 1).

    Returns:
        ``(status_message, video_path_or_None, state)``. The video is ``None``
        when Step 1 has not run, when the script fails, or when no frames are
        found under ``<base_dir>/p3dmm/normals``.
    """
    # The buttons can be clicked in any order; without Step 1 the state is empty.
    if any(key not in state for key in ("session_id", "base_dir", "trimmed_path")):
        return "Step 3 skipped: run Step 1 first.", None, state

    session_id = state["session_id"]
    base_dir = state["base_dir"]
    try:
        # capture both stdout & stderr, decoded as text so the error is readable
        subprocess.run([
            "python", "scripts/network_inference.py",
            "model.prediction_type=normals", f"video_name={session_id}"
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything (stderr is merged into it).
        # NOTE(review): the leading 'β' looks like a mis-encoded status glyph -- confirm.
        err = f"β Normal map failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    normals_dir = os.path.join(base_dir, "p3dmm", "normals")
    # The result video is written next to the trimmed clip.
    out = os.path.join(os.path.dirname(state["trimmed_path"]), f"normals_{session_id}.mp4")
    video = make_video_from_frames(normals_dir, out)
    return "β Step 3: Normals inference complete", video, state
# Step 4: UV map inference → uv map video
def step4_uv_map(state):
    """Run ``scripts/network_inference.py`` for the UV map and build the UV-map video.

    Args:
        state: Per-session dict; must contain ``session_id``, ``base_dir`` and
            ``trimmed_path`` (written by Step 1).

    Returns:
        ``(status_message, video_path_or_None, state)``. The video is ``None``
        when Step 1 has not run, when the script fails, or when no frames are
        found under ``<base_dir>/p3dmm/uv_map``.
    """
    # The buttons can be clicked in any order; without Step 1 the state is empty.
    if any(key not in state for key in ("session_id", "base_dir", "trimmed_path")):
        return "Step 4 skipped: run Step 1 first.", None, state

    session_id = state["session_id"]
    base_dir = state["base_dir"]
    try:
        # capture both stdout & stderr, decoded as text so the error is readable
        subprocess.run([
            "python", "scripts/network_inference.py",
            "model.prediction_type=uv_map", f"video_name={session_id}"
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything (stderr is merged into it).
        # NOTE(review): the leading 'β' looks like a mis-encoded status glyph -- confirm.
        err = f"β UV map failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    uv_dir = os.path.join(base_dir, "p3dmm", "uv_map")
    # The result video is written next to the trimmed clip.
    out = os.path.join(os.path.dirname(state["trimmed_path"]), f"uv_map_{session_id}.mp4")
    video = make_video_from_frames(uv_dir, out)
    return "β Step 4: UV map inference complete", video, state
# Step 5: Tracking → final tracking video
def step5_track(state):
    """Run ``scripts/track.py`` for the session and build the final tracking video.

    Args:
        state: Per-session dict; must contain ``session_id`` and
            ``trimmed_path`` (written by Step 1).

    Returns:
        ``(status_message, video_path_or_None, state)``. The video is ``None``
        when Step 1 has not run, when the script fails, or when no frames are
        found under ``$PIXEL3DMM_TRACKING_OUTPUT/<session_id>/frames``.
    """
    # The buttons can be clicked in any order; without Step 1 the state is empty.
    if any(key not in state for key in ("session_id", "trimmed_path")):
        return "Step 5 skipped: run Step 1 first.", None, state

    session_id = state["session_id"]
    script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
    cmd = [
        "python", script,
        f"video_name={session_id}"
    ]
    try:
        # capture both stdout & stderr
        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything (stderr is merged into it).
        # NOTE(review): the leading 'β' looks like a mis-encoded status glyph -- confirm.
        err = f"β Tracking failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    # if we get here, it succeeded:
    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    # The result video is written next to the trimmed clip.
    out = os.path.join(os.path.dirname(state["trimmed_path"]), f"result_{session_id}.mp4")
    video = make_video_from_frames(tracking_dir, out)
    return "β Step 5: Tracking complete", video, state
# Build Gradio UI
# NOTE(review): the original indentation of this section was lost; the nesting
# below (two columns in one row, buttons and event wiring directly under the
# Blocks context) is a reconstruction -- confirm against the intended layout.
demo = gr.Blocks()
with demo:
    gr.Markdown("## Video Processing Pipeline")
    with gr.Row():
        # Left column: input video, trim controls, status line and session state.
        with gr.Column():
            video_in = gr.Video(label="Upload video", height=512)
            # Sliders for duration and fps
            seconds_slider = gr.Slider(label="Duration (seconds)", minimum=2, maximum=10, step=1, value=3)
            fps_slider = gr.Slider(label="Frame Rate (fps)", minimum=15, maximum=30, step=1, value=15)
            status = gr.Textbox(label="Status", lines=2, interactive=False)
            # Dict passed into and returned from every step callback.
            state = gr.State({})
        # Right column: one output video per pipeline step (2x2 grid).
        with gr.Column():
            with gr.Row():
                crop_vid = gr.Video(label="Preprocessed", height=256)
                normals_vid = gr.Video(label="Normals", height=256)
            with gr.Row():
                uv_vid = gr.Video(label="UV Map", height=256)
                track_vid = gr.Video(label="Tracking", height=256)
    # One button per pipeline step; each step is triggered manually.
    run_btn_1 = gr.Button("Run Pipeline 1")
    run_btn_2 = gr.Button("Run Pipeline 2")
    run_btn_3 = gr.Button("Run Pipeline 3")
    run_btn_4 = gr.Button("Run Pipeline 4")
    run_btn_5 = gr.Button("Run Pipeline 5")
    # Update sliders after video upload
    video_in.change(fn=get_video_info, inputs=video_in, outputs=[seconds_slider, fps_slider])
    # Pipeline execution
    # Step 1 has two outputs (status, state); steps 2-5 have three
    # (status, their own video component, state).
    run_btn_1.click(fn=step1_trim, inputs=[video_in, seconds_slider, fps_slider, state], outputs=[status, state])
    run_btn_2.click(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state])
    run_btn_3.click(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state])
    run_btn_4.click(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state])
    run_btn_5.click(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])
    # Disabled alternative: chain steps 2-5 with .then(...) after a single click.
    # .then(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state])
    # .then(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state])
    # .then(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state])
    # .then(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
demo.queue()
demo.launch(share=True, ssr_mode=False)