# pixel3dmm / app.py
import spaces
import os
import subprocess
import tempfile
import uuid
import glob
import json
import shutil
import time
import gradio as gr
import sys

# Point Pixel3DMM at this checkout and at local output directories
os.environ["PIXEL3DMM_CODE_BASE"] = f"{os.getcwd()}"
os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = f"{os.getcwd()}/preprocess_results"
os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"
def sh(cmd): subprocess.check_call(cmd, shell=True)
# these editable installs run once per Space restart (at import time)
sh("pip install -e .")
sh("cd src/pixel3dmm/preprocessing/facer && pip install -e .")
sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh")
def install_cuda_toolkit():
    # Download and silently install the CUDA 12.1 toolkit so GPU extensions can be built
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
    )
    # Pin the CUDA arch list; with an empty list torch's extension builder fails with
    # "arch_list[-1] += '+PTX'; IndexError: list index out of range"
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
    print("==> finished CUDA toolkit installation")


install_cuda_toolkit()
# Utility to stitch frames into a video
def make_video_from_frames(frames_dir, out_path, fps=15):
    if not os.path.isdir(frames_dir):
        return None
    files = glob.glob(os.path.join(frames_dir, "*.jpg")) + glob.glob(os.path.join(frames_dir, "*.png"))
    if not files:
        return None
    ext = files[0].split('.')[-1]
    pattern = os.path.join(frames_dir, f"%05d.{ext}")
    subprocess.run([
        "ffmpeg", "-y", "-i", pattern,
        "-r", str(fps), out_path
    ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    return out_path
# Function to probe video for duration and frame rate
def get_video_info(video_path):
    """
    Probes the uploaded video and returns updated slider configs:
      - seconds slider: max = int(duration)
      - fps slider:     max = int(orig_fps)
    """
    if not video_path:
        # Return default slider updates when no video is uploaded
        return gr.update(maximum=10, value=3, step=1), gr.update(maximum=30, value=15, step=1)

    # Use ffprobe to get JSON metadata (format info is needed for the duration fallback)
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_format", "-show_streams", video_path
    ]
    res = subprocess.run(cmd, capture_output=True, text=True)
    try:
        data = json.loads(res.stdout)
        stream = next(s for s in data.get('streams', []) if s.get('codec_type') == 'video')
        duration = float(stream.get('duration') or data.get('format', {}).get('duration', 0))
        fr = stream.get('r_frame_rate', '0/1')
        num, den = fr.split('/')
        orig_fps = float(num) / float(den) if float(den) else 30
    except Exception:
        # Fall back to sensible defaults if probing fails
        duration, orig_fps = 10, 30

    # Configure sliders based on actual video properties
    seconds_cfg = gr.update(maximum=int(duration), value=min(int(duration), 3), step=1)
    fps_cfg = gr.update(maximum=int(orig_fps), value=min(int(orig_fps), 15), step=1)
    return seconds_cfg, fps_cfg
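
# For reference, the fields read above come from ffprobe's JSON output, roughly of the form
# (abridged, illustrative values only):
#   {"streams": [{"codec_type": "video", "duration": "8.33", "r_frame_rate": "30000/1001", ...}],
#    "format": {"duration": "8.33", ...}}
# so an 8.3 s, 29.97 fps upload yields slider maxima of 8 seconds and 29 fps.
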
# Step 1: Trim the video to the user-defined duration and fps
def step1_trim(video_path, seconds, fps, state):
    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    state.update({"session_id": session_id, "base_dir": base_dir})

    tmp = tempfile.mkdtemp()
    trimmed = os.path.join(tmp, f"{session_id}.mp4")
    try:
        # capture both stdout & stderr; run() blocks until ffmpeg has finished
        p = subprocess.run([
            "ffmpeg", "-y", "-i", video_path,
            "-t", str(seconds),  # user-specified duration
            "-r", str(fps),      # user-specified fps
            trimmed
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        print(p.stdout)  # echo ffmpeg output to the Space logs
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything because stderr was merged into stdout
        err = f"❌ Trimming failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, state

    state["trimmed_path"] = trimmed
    return f"✅ Step 1: Trimmed to {seconds}s @{fps}fps", state
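
# The trim above is equivalent to running, for example (hypothetical paths):
#   ffmpeg -y -i /path/to/upload.mp4 -t 3 -r 15 /tmp/<session_id>.mp4
# i.e. keep only the first `seconds` seconds and resample the output to `fps` frames per second.
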
# Step 2: Preprocessing → cropped video
@spaces.GPU()
def step2_preprocess(state):
    session_id = state["session_id"]
    base_dir = state["base_dir"]
    trimmed = state["trimmed_path"]
    try:
        # capture both stdout & stderr
        subprocess.run([
            "python", "scripts/run_preprocessing.py",
            "--video_or_images_path", trimmed
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything
        err = f"❌ Preprocessing failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    crop_dir = os.path.join(base_dir, "cropped")
    out = os.path.join(os.path.dirname(trimmed), f"crop_{session_id}.mp4")
    video = make_video_from_frames(crop_dir, out)
    return "✅ Step 2: Preprocessing complete", video, state
# Step 3: Normals inference → normals video
@spaces.GPU()
def step3_normals(state):
    session_id = state["session_id"]
    base_dir = state["base_dir"]
    try:
        # capture both stdout & stderr
        subprocess.run([
            "python", "scripts/network_inference.py",
            "model.prediction_type=normals", f"video_name={session_id}"
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything
        err = f"❌ Normals inference failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    normals_dir = os.path.join(base_dir, "p3dmm", "normals")
    out = os.path.join(os.path.dirname(state["trimmed_path"]), f"normals_{session_id}.mp4")
    video = make_video_from_frames(normals_dir, out)
    return "✅ Step 3: Normals inference complete", video, state
# Step 4: UV map inference → UV map video
@spaces.GPU()
def step4_uv_map(state):
    session_id = state["session_id"]
    base_dir = state["base_dir"]
    try:
        # capture both stdout & stderr
        subprocess.run([
            "python", "scripts/network_inference.py",
            "model.prediction_type=uv_map", f"video_name={session_id}"
        ], check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything
        err = f"❌ UV map inference failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    uv_dir = os.path.join(base_dir, "p3dmm", "uv_map")
    out = os.path.join(os.path.dirname(state["trimmed_path"]), f"uv_map_{session_id}.mp4")
    video = make_video_from_frames(uv_dir, out)
    return "✅ Step 4: UV map inference complete", video, state
# Step 5: Tracking → final tracking video
@spaces.GPU()
def step5_track(state):
    session_id = state["session_id"]
    script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
    cmd = [
        "python", script,
        f"video_name={session_id}"
    ]
    try:
        # capture both stdout & stderr
        subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # e.stdout contains everything
        err = f"❌ Tracking failed (exit {e.returncode}).\n\n{e.stdout}"
        return err, None, state

    # if we get here, it succeeded:
    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    out = os.path.join(os.path.dirname(state["trimmed_path"]), f"result_{session_id}.mp4")
    video = make_video_from_frames(tracking_dir, out)
    return "✅ Step 5: Tracking complete", video, state
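
# Frame directory layout assumed by steps 2-5 (inferred from the paths used above, not from
# external documentation):
#   $PIXEL3DMM_PREPROCESSED_DATA/<session_id>/cropped/        -> frames for step 2
#   $PIXEL3DMM_PREPROCESSED_DATA/<session_id>/p3dmm/normals/  -> frames for step 3
#   $PIXEL3DMM_PREPROCESSED_DATA/<session_id>/p3dmm/uv_map/   -> frames for step 4
#   $PIXEL3DMM_TRACKING_OUTPUT/<session_id>/frames/           -> frames for step 5
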
# Build Gradio UI
demo = gr.Blocks()
with demo:
    gr.Markdown("## Video Processing Pipeline")
    with gr.Row():
        with gr.Column():
            video_in = gr.Video(label="Upload video", height=512)
            # Sliders for duration and fps
            seconds_slider = gr.Slider(label="Duration (seconds)", minimum=2, maximum=10, step=1, value=3)
            fps_slider = gr.Slider(label="Frame Rate (fps)", minimum=15, maximum=30, step=1, value=15)
            status = gr.Textbox(label="Status", lines=2, interactive=False)
            state = gr.State({})
        with gr.Column():
            with gr.Row():
                crop_vid = gr.Video(label="Preprocessed", height=256)
                normals_vid = gr.Video(label="Normals", height=256)
            with gr.Row():
                uv_vid = gr.Video(label="UV Map", height=256)
                track_vid = gr.Video(label="Tracking", height=256)
    run_btn_1 = gr.Button("Step 1: Trim")
    run_btn_2 = gr.Button("Step 2: Preprocess")
    run_btn_3 = gr.Button("Step 3: Normals")
    run_btn_4 = gr.Button("Step 4: UV Map")
    run_btn_5 = gr.Button("Step 5: Tracking")
    # Update sliders after video upload
    video_in.change(fn=get_video_info, inputs=video_in, outputs=[seconds_slider, fps_slider])

    # Pipeline execution
    run_btn_1.click(fn=step1_trim, inputs=[video_in, seconds_slider, fps_slider, state], outputs=[status, state])
    run_btn_2.click(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state])
    run_btn_3.click(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state])
    run_btn_4.click(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state])
    run_btn_5.click(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])
    # .then(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state])
    # .then(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state])
    # .then(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state])
    # .then(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])
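
    # The commented-out .then() chain above hints at a one-click flow. A minimal sketch of how
    # that could look with Gradio's event chaining (not enabled here; `run_all_btn` is hypothetical):
    #
    #   run_all_btn = gr.Button("Run full pipeline")
    #   run_all_btn.click(fn=step1_trim, inputs=[video_in, seconds_slider, fps_slider, state],
    #                     outputs=[status, state]) \
    #       .then(fn=step2_preprocess, inputs=[state], outputs=[status, crop_vid, state]) \
    #       .then(fn=step3_normals, inputs=[state], outputs=[status, normals_vid, state]) \
    #       .then(fn=step4_uv_map, inputs=[state], outputs=[status, uv_vid, state]) \
    #       .then(fn=step5_track, inputs=[state], outputs=[status, track_vid, state])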
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
demo.queue()
demo.launch(share=True, ssr_mode=False)