# Hugging Face Space status: Running on Zero
# `spaces` stays the very first import, exactly as in the original ordering.
# NOTE(review): Hugging Face's ZeroGPU docs ask for `spaces` to be imported
# before torch/CUDA-using packages -- keep it at the top when reordering.
import spaces

import glob
import os
import shutil
import subprocess
import sys
import tempfile
import time
import uuid

import gradio as gr
from PIL import Image

# NOTE(review): `spaces` is imported but no function in this file is decorated
# with `@spaces.GPU`; confirm the GPU-bound pipeline steps actually get a device.

# Export the working directories. Child processes started later inherit
# os.environ (presumably the pixel3dmm scripts read these), and the step
# functions in this file read them back as well.
_launch_dir = os.getcwd()
os.environ.update(
    {
        "PIXEL3DMM_CODE_BASE": _launch_dir,
        "PIXEL3DMM_PREPROCESSED_DATA": f"{_launch_dir}/proprocess_results",
        "PIXEL3DMM_TRACKING_OUTPUT": f"{_launch_dir}/tracking_results",
    }
)
def sh(cmd, cwd=None):
    """Run a shell command and raise if it fails.

    Args:
        cmd: Command string, executed with ``shell=True``. Only pass trusted,
            constant strings -- the text is interpreted by the shell.
        cwd: Optional working directory for the command; ``None`` (the
            default) keeps the current directory, as before.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    subprocess.check_call(cmd, shell=True, cwd=cwd)


def _pip_install_editable(path):
    """Run ``pip install -e .`` inside *path* for the interpreter running this app."""
    # `sys.executable -m pip` installs into the same environment this process
    # (and the subprocesses it launches via sys.executable) runs in, rather
    # than into whichever `pip` happens to be first on PATH.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-e", "."], cwd=path
    )


# Install the project and its vendored face-parsing package in editable mode,
# then run make.sh in FaceBoxesV2/utils (presumably builds a native helper).
# This executes every time the script starts, i.e. once per process/Space restart.
_pip_install_editable(".")
_pip_install_editable("src/pixel3dmm/preprocessing/facer")
sh("sh make.sh", cwd="src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils")
def _prepend_env_path(name, entry):
    """Put *entry* in front of the ``os.pathsep``-separated variable *name*.

    If the variable is unset or empty it becomes just *entry*. Joining onto an
    empty value would leave a trailing separator, i.e. an empty list element,
    which search-path lookups (PATH, and glibc's loader for LD_LIBRARY_PATH)
    treat as the current working directory.
    """
    current = os.environ.get(name, "")
    os.environ[name] = f"{entry}{os.pathsep}{current}" if current else entry


def install_cuda_toolkit():
    """Download and silently install the CUDA 12.1 toolkit, then export its paths.

    Runs NVIDIA's runfile installer in ``--silent --toolkit`` mode and points
    ``CUDA_HOME``, ``PATH`` and ``LD_LIBRARY_PATH`` at ``/usr/local/cuda``.

    Best effort: a failing download/chmod/install step is reported on stdout
    and the remaining steps are skipped, but no exception is raised, so the app
    still starts. (A missing ``wget``/``chmod`` binary still raises
    ``FileNotFoundError`` from :func:`subprocess.call`, as before.)
    """
    url = (
        "https://developer.download.nvidia.com/compute/cuda/12.1.0/"
        "local_installers/cuda_12.1.0_530.30.02_linux.run"
    )
    installer = f"/tmp/{os.path.basename(url)}"
    steps = (
        ("download", ["wget", "-q", url, "-O", installer]),
        ("chmod", ["chmod", "+x", installer]),
        ("install", [installer, "--silent", "--toolkit"]),
    )

    failed_step = None
    try:
        for label, cmd in steps:
            returncode = subprocess.call(cmd)
            if returncode != 0:
                # Later steps cannot succeed without this one; stop here.
                failed_step = f"{label} (exit {returncode})"
                break
    finally:
        # The installer is only needed for this run; don't leave it in /tmp.
        try:
            os.remove(installer)
        except OSError:
            pass

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    _prepend_env_path("PATH", f"{cuda_home}/bin")
    # NOTE(review): on x86_64 the toolkit's libraries usually live in lib64;
    # confirm `lib` is the intended directory.
    _prepend_env_path("LD_LIBRARY_PATH", f"{cuda_home}/lib")

    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range.
    # That error presumably comes from torch's C++/CUDA extension builder when
    # it cannot derive any GPU architecture (no GPU visible at build time), so
    # the architecture list is pinned explicitly.
    # NOTE(review): "9.0" presumably matches the Space's GPU -- confirm.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"

    if failed_step is None:
        print("==> finished installation")
    else:
        print(f"==> CUDA toolkit installation failed at step: {failed_step}")


install_cuda_toolkit()
# Utility to select the first image from a folder
def first_image_from_dir(directory, extensions=(".jpg", ".png", ".jpeg")):
    """Return the path of the alphabetically first image file in *directory*.

    Args:
        directory: Folder to search (not recursively). A missing or unreadable
            folder is treated as empty. The name is used literally, so glob
            metacharacters such as ``[`` in it are not interpreted.
        extensions: File suffixes to accept, compared case-insensitively so
            that e.g. ``IMG_0001.JPG`` is found as well.

    Returns:
        ``os.path.join(directory, name)`` for the lexicographically smallest
        matching regular-file name, or ``None`` when nothing matches. Hidden
        files (leading ``.``) are skipped, as ``glob("*.png")`` would skip them.
    """
    wanted = tuple(ext.lower() for ext in extensions)
    try:
        entries = os.scandir(directory)
    except OSError:
        return None
    with entries:
        names = [
            entry.name
            for entry in entries
            if not entry.name.startswith(".")
            and entry.name.lower().endswith(wanted)
            and entry.is_file()
        ]
    # All candidates share the same directory prefix, so ordering by name is
    # the same as ordering by full path.
    return os.path.join(directory, min(names)) if names else None
# Step 1: Preprocess the input image (save and crop)
def preprocess_image(image_array, state):
    """Save the uploaded image into a new session folder and run preprocessing.

    Args:
        image_array: Array from the ``gr.Image(type="numpy")`` input, or
            ``None`` when nothing has been uploaded.
        state: Per-session dict from ``gr.State``. On success it is updated in
            place with ``session_id``, ``base_dir`` and ``image_path``. If the
            preprocessing script fails, those keys are removed, so later steps
            ask for a new upload instead of using a deleted (or stale) session.

    Returns:
        ``(status_message, cropped_image_path_or_None, state)``.
    """
    if image_array is None:
        return "β Please upload an image first.", None, state

    # Save the upload as <PREPROCESSED_DATA>/<session_id>/<session_id>.png.
    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)
    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    Image.fromarray(image_array).save(saved_image_path)

    # Run the preprocessing script with the same interpreter as this app,
    # addressed by absolute path so it does not depend on the current directory.
    script = os.path.join(
        os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "run_preprocessing.py"
    )
    try:
        subprocess.run(
            [sys.executable, script, "--video_or_images_path", saved_image_path],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"β Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        # Clean up the half-built session directory and forget the session.
        shutil.rmtree(base_dir, ignore_errors=True)
        for key in ("session_id", "base_dir", "image_path"):
            state.pop(key, None)
        return err, None, state

    # Commit the session only once preprocessing has succeeded.
    state.update(session_id=session_id, base_dir=base_dir, image_path=saved_image_path)
    crop_dir = os.path.join(base_dir, "cropped")
    return "β Preprocessing complete", first_image_from_dir(crop_dir), state
# Step 2: Normals inference → normals image
def step2_normals(state):
    """Run the normals network on the current session's preprocessed image.

    Args:
        state: Per-session dict from ``gr.State``; needs ``session_id`` and
            ``base_dir`` as set by step 1.

    Returns:
        ``(status_message, normals_image_path_or_None, state)``. On a failed
        run the message includes the script's exit code, stdout and stderr.
    """
    session_id = state.get("session_id")
    if not session_id:
        return "β Please preprocess an image first.", None, state

    script = os.path.join(
        os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "network_inference.py"
    )
    try:
        # Same interpreter as this app; the key=value arguments are
        # presumably Hydra-style config overrides.
        subprocess.run(
            [sys.executable, script,
             "model.prediction_type=normals", f"video_name={session_id}"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"β Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state

    normals_dir = os.path.join(state["base_dir"], "p3dmm", "normals")
    return "β Step 2: Normals inference complete", first_image_from_dir(normals_dir), state
# Step 3: UV map inference → UV map image
def step3_uv_map(state):
    """Run the UV-map network on the current session's preprocessed image.

    Args:
        state: Per-session dict from ``gr.State``; needs ``session_id`` and
            ``base_dir`` as set by step 1.

    Returns:
        ``(status_message, uv_map_image_path_or_None, state)``. On a failed
        run the message includes the script's exit code, stdout and stderr.
    """
    session_id = state.get("session_id")
    if not session_id:
        return "β Please preprocess an image first.", None, state

    script = os.path.join(
        os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "network_inference.py"
    )
    try:
        # Same interpreter as this app; the key=value arguments are
        # presumably Hydra-style config overrides.
        subprocess.run(
            [sys.executable, script,
             "model.prediction_type=uv_map", f"video_name={session_id}"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"β UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state

    uv_dir = os.path.join(state["base_dir"], "p3dmm", "uv_map")
    return "β Step 3: UV map inference complete", first_image_from_dir(uv_dir), state
# Step 4: Tracking → final tracking image
def step4_track(state):
    """Run the tracking script for the current session.

    Args:
        state: Per-session dict from ``gr.State``; needs ``session_id`` as set
            by step 1.

    Returns:
        ``(status_message, tracking_image_path_or_None, state)``. The image is
        the first frame found under
        ``$PIXEL3DMM_TRACKING_OUTPUT/<session_id>/frames``. On a failed run the
        message includes the script's exit code, stdout and stderr.
    """
    session_id = state.get("session_id")
    if not session_id:
        return "β Please preprocess an image first.", None, state

    script = os.path.join(os.environ["PIXEL3DMM_CODE_BASE"], "scripts", "track.py")
    try:
        # Same interpreter as this app rather than whatever `python` is on PATH.
        subprocess.run(
            [sys.executable, script, f"video_name={session_id}"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"β Tracking failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    return "β Step 4: Tracking complete", first_image_from_dir(tracking_dir), state
# Gradio UI: an upload/status column beside a 2x2 grid of stage outputs, and a
# row with one button per pipeline stage.
with gr.Blocks() as demo:
    gr.Markdown("## Image Processing Pipeline")

    with gr.Row():
        with gr.Column():
            image_in = gr.Image(label="Upload Image", type="numpy", height=512)
            status = gr.Textbox(label="Status", lines=2, interactive=False)
            # Per-session dict threaded through every step (session_id, base_dir, ...).
            state = gr.State({})
        with gr.Column():
            with gr.Row():
                crop_img = gr.Image(label="Preprocessed", height=256)
                normals_img = gr.Image(label="Normals", height=256)
            with gr.Row():
                uv_img = gr.Image(label="UV Map", height=256)
                track_img = gr.Image(label="Tracking", height=256)

    with gr.Row():
        preprocess_btn = gr.Button("Step 1: Preprocess")
        normals_btn = gr.Button("Step 2: Normals")
        uv_map_btn = gr.Button("Step 3: UV Map")
        track_btn = gr.Button("Step 4: Track")

    # Step 1 also consumes the uploaded image; steps 2-4 only need the session
    # state. Every step reports into the shared status box and passes the
    # state back out; only the image output differs.
    preprocess_btn.click(
        fn=preprocess_image,
        inputs=[image_in, state],
        outputs=[status, crop_img, state],
    )
    for button, handler, image_out in (
        (normals_btn, step2_normals, normals_img),
        (uv_map_btn, step3_uv_map, uv_img),
        (track_btn, step4_track, track_img),
    ):
        button.click(fn=handler, inputs=[state], outputs=[status, image_out, state])

# ------------------------------------------------------------------
# Start the Gradio server (queued, so long-running steps don't time out the request)
# ------------------------------------------------------------------
# NOTE(review): Gradio is expected to ignore share=True when running on Spaces;
# confirm it is only meant for local runs.
demo.queue()
demo.launch(share=True, ssr_mode=False)