pixel3dmm / app.py
alexnasa's picture
Update app.py
63529cc verified
raw
history blame
9 kB
import spaces
import os
import subprocess
import tempfile
import uuid
import glob
import shutil
import time
import gradio as gr
import sys
from PIL import Image
# Export the locations pixel3dmm reads from the environment: the code base is the
# current working directory, and the two data directories live directly below it.
# NOTE(review): "proprocess_results" looks like a typo for "preprocess_results";
# it is kept verbatim because other code may already depend on this exact path.
os.environ.update(
    {
        "PIXEL3DMM_CODE_BASE": os.getcwd(),
        "PIXEL3DMM_PREPROCESSED_DATA": f"{os.getcwd()}/proprocess_results",
        "PIXEL3DMM_TRACKING_OUTPUT": f"{os.getcwd()}/tracking_results",
    }
)
def sh(cmd):
    """Run ``cmd`` through the system shell and fail loudly on a non-zero exit.

    Args:
        cmd: The command line, handed to the shell as a single string.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    subprocess.run(cmd, shell=True, check=True)
from pixel3dmm import env_paths
# Set up the two in-tree preprocessing components: an editable pip install of
# `facer`, then FaceBoxesV2's make.sh. Every command goes through `sh`, i.e. its
# own shell process, so the `cd` calls never change this process's directory.
for _build_cmd in (
    "cd src/pixel3dmm/preprocessing/facer && pip install -e . && cd ../../../..",
    "cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh && cd ../../../../../..",
):
    sh(_build_cmd)
def install_cuda_toolkit():
    """Download and install the CUDA 12.1 toolkit, then expose it via the environment.

    The installer is fetched with ``wget`` into ``/tmp``, made executable and run
    with ``--silent --toolkit``. The steps run in order and stop at the first one
    that exits non-zero, so a failed download is never executed; the failure is
    reported on stdout instead of being silently ignored.

    Regardless of the install outcome, ``CUDA_HOME``, ``PATH``,
    ``LD_LIBRARY_PATH`` and ``TORCH_CUDA_ARCH_LIST`` are set for this process
    (and therefore for the subprocesses it spawns).
    """
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)

    install_steps = (
        ["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE],
        ["chmod", "+x", CUDA_TOOLKIT_FILE],
        [CUDA_TOOLKIT_FILE, "--silent", "--toolkit"],
    )
    installed = True
    for step in install_steps:
        exit_code = subprocess.call(step)
        if exit_code != 0:
            # Later steps depend on the earlier ones, so there is no point continuing.
            print(
                f"==> WARNING: {step[0]} exited with status {exit_code}; "
                "skipping the remaining CUDA toolkit install steps"
            )
            installed = False
            break

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    for var, subdir in (("PATH", "bin"), ("LD_LIBRARY_PATH", "lib")):
        entries = [f"{cuda_home}/{subdir}"]
        existing = os.environ.get(var)
        # Only append a non-empty previous value: a trailing ":" would add an
        # empty entry, which the loader/shell treats as the current directory.
        if existing:
            entries.append(existing)
        os.environ[var] = ":".join(entries)

    # Pin the architecture list explicitly. Original author's note: this fixes
    # "arch_list[-1] += '+PTX'; IndexError: list index out of range".
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
    if installed:
        print("==> finished installation")
# Runs at import time, ahead of the torch / pytorch3d / pixel3dmm.tracking imports below.
# NOTE(review): presumably so CUDA_HOME and TORCH_CUDA_ARCH_LIST are already set when
# those packages load — confirm before moving or guarding this call.
install_cuda_toolkit()
import os
import torch
import numpy as np
import trimesh
from pytorch3d.io import load_obj
from pixel3dmm import env_paths
from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
from pixel3dmm.tracking.flame.FLAME import FLAME
from pixel3dmm.tracking.tracker import Tracker
from omegaconf import OmegaConf
# Module-level tracking resources: built once at import time and reused by
# step4_track (base_conf, flame_model and diff_renderer are passed to Tracker).
DEVICE = "cuda"
# Tracking configuration, loaded from configs/tracking.yaml under the code base.
base_conf = OmegaConf.load(f'{env_paths.CODE_BASE}/configs/tracking.yaml')
# Head template mesh path; also handed to the renderer as obj_filename below.
_mesh_file = env_paths.head_template
flame_model = FLAME(base_conf).to(DEVICE)
# NOTE(review): _obj_faces is not referenced anywhere else in the visible file —
# possibly dead; confirm before removing.
_obj_faces = load_obj(_mesh_file)[1]
# Renderer sized from the config's `size` entry.
diff_renderer = NVDRenderer(
image_size=base_conf.size,
obj_filename=_mesh_file,
no_sh=False,
white_bg=True
).to(DEVICE)
# Utility to pick one representative image out of a folder
def first_image_from_dir(directory):
    """Return the alphabetically first image path in ``directory``.

    Only files matching ``*.jpg``, ``*.png`` or ``*.jpeg`` directly inside
    ``directory`` are considered.

    Args:
        directory: Folder to search.

    Returns:
        The smallest matching path in string order, or ``None`` when nothing
        matches (including when the folder does not exist).
    """
    candidates = [
        path
        for extension in ("jpg", "png", "jpeg")
        for path in glob.glob(os.path.join(directory, f"*.{extension}"))
    ]
    return min(candidates, default=None)
# Function to reset the UI and state
def reset_all():
    """Return the values that put every output widget back to its starting state.

    The tuple order matches ``outputs_for_reset``: the four preview images, the
    status text, the session state, then the four step buttons (only the
    preprocess button is re-enabled).
    """
    blank_previews = (None, None, None, None)  # crop_img, normals_img, uv_img, track_img
    step_buttons = (
        gr.update(interactive=True),   # preprocess_btn
        gr.update(interactive=False),  # normals_btn
        gr.update(interactive=False),  # uv_map_btn
        gr.update(interactive=False),  # track_btn
    )
    return (
        *blank_previews,
        "Awaiting new image upload...",  # status
        {},  # state
        *step_buttons,
    )
# Step 1: Preprocess the input image (Save and Crop)
# @spaces.GPU()
def preprocess_image(image_array, state):
    """Save the uploaded image in a fresh session folder and run the preprocessing script.

    Args:
        image_array: The uploaded image as an array accepted by
            ``PIL.Image.fromarray``, or ``None`` when nothing was uploaded.
        state: Per-session dict. On success it gains ``session_id``,
            ``base_dir`` and ``image_path``.

    Returns:
        A 5-tuple ``(status, crop_image, state, preprocess_btn, normals_btn)``.
        On success ``crop_image`` is the first image found in the session's
        ``cropped`` folder, the preprocess button is disabled and the normals
        button enabled. On any failure the session folder is removed, the
        returned state is an empty dict and only the preprocess button stays
        enabled.
    """
    if image_array is None:
        return "❌ Please upload an image first.", None, state, gr.update(interactive=True), gr.update(interactive=False)

    session_id = str(uuid.uuid4())
    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)
    state.update({"session_id": session_id, "base_dir": base_dir})

    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    Image.fromarray(image_array).save(saved_image_path)
    state["image_path"] = saved_image_path

    try:
        # sys.executable guarantees the script runs under the same interpreter
        # (and installed packages) as this app, whatever PATH currently holds.
        subprocess.run(
            [sys.executable, "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        # Best-effort cleanup: a cleanup problem must not hide the real error.
        shutil.rmtree(base_dir, ignore_errors=True)
        return err, None, {}, gr.update(interactive=True), gr.update(interactive=False)

    image = first_image_from_dir(os.path.join(base_dir, "cropped"))
    if image is None:
        # The script exited cleanly but left nothing to show or to feed the
        # next step, so treat it like a failed run instead of claiming success.
        shutil.rmtree(base_dir, ignore_errors=True)
        return (
            "❌ Preprocess finished but produced no cropped image.",
            None,
            {},
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    return "✅ Step 1 complete. Ready for Normals.", image, state, gr.update(interactive=False), gr.update(interactive=True)
# Step 2: Normals inference → normals image
@spaces.GPU()
def step2_normals(state):
    """Run normals inference for the current session and return the first result image.

    Args:
        state: Per-session dict; must contain ``session_id`` and ``base_dir``
            as written by step 1.

    Returns:
        A 5-tuple ``(status, normals_image, state, normals_btn, uv_map_btn)``.
        On success the normals button is disabled and the UV-map button
        enabled. If the script fails or leaves no image, the normals button
        stays enabled for a retry. If the session state is missing, both
        buttons are disabled.
    """
    session_id = state.get("session_id")
    base_dir = state.get("base_dir")
    # Check base_dir up front too, so a broken state is reported before the
    # inference runs rather than as a KeyError afterwards.
    if not session_id or not base_dir:
        return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False), gr.update(interactive=False)

    try:
        # sys.executable guarantees the script runs under the same interpreter
        # (and installed packages) as this app, whatever PATH currently holds.
        subprocess.run(
            [sys.executable, "scripts/network_inference.py", "model.prediction_type=normals", f"video_name={session_id}"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ Normal map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state, gr.update(interactive=True), gr.update(interactive=False)

    normals_dir = os.path.join(base_dir, "p3dmm", "normals")
    image = first_image_from_dir(normals_dir)
    if image is None:
        # Clean exit but nothing written: do not report success or unlock step 3.
        return (
            "❌ Normal map finished but produced no output image.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    return "✅ Step 2 complete. Ready for UV Map.", image, state, gr.update(interactive=False), gr.update(interactive=True)
# Step 3: UV map inference → uv map image
@spaces.GPU()
def step3_uv_map(state):
    """Run UV-map inference for the current session and return the first result image.

    Args:
        state: Per-session dict; must contain ``session_id`` and ``base_dir``
            as written by step 1.

    Returns:
        A 5-tuple ``(status, uv_image, state, uv_map_btn, track_btn)``.
        On success the UV-map button is disabled and the tracking button
        enabled. If the script fails or leaves no image, the UV-map button
        stays enabled for a retry. If the session state is missing, both
        buttons are disabled.
    """
    session_id = state.get("session_id")
    base_dir = state.get("base_dir")
    # Check base_dir up front too, so a broken state is reported before the
    # inference runs rather than as a KeyError afterwards.
    if not session_id or not base_dir:
        return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False), gr.update(interactive=False)

    try:
        # sys.executable guarantees the script runs under the same interpreter
        # (and installed packages) as this app, whatever PATH currently holds.
        subprocess.run(
            [sys.executable, "scripts/network_inference.py", "model.prediction_type=uv_map", f"video_name={session_id}"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        err = f"❌ UV map failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        return err, None, state, gr.update(interactive=True), gr.update(interactive=False)

    uv_dir = os.path.join(base_dir, "p3dmm", "uv_map")
    image = first_image_from_dir(uv_dir)
    if image is None:
        # Clean exit but nothing written: do not report success or unlock step 4.
        return (
            "❌ UV map finished but produced no output image.",
            None,
            state,
            gr.update(interactive=True),
            gr.update(interactive=False),
        )

    return "✅ Step 3 complete. Ready for Tracking.", image, state, gr.update(interactive=False), gr.update(interactive=True)
# Step 4: Tracking → final tracking image
@spaces.GPU()
def step4_track(state):
    """Run the tracker for the current session and return the first output frame.

    Args:
        state: Per-session dict; must contain ``session_id`` as written by step 1.

    Returns:
        A 4-tuple ``(status, tracking_image, state, track_btn)``. On success
        the tracking button is disabled. If no frame was produced it stays
        enabled for a retry. If the session state is missing it is disabled.
    """
    session_id = state.get("session_id")
    if not session_id:
        # Same guard as steps 2 and 3: without it the tracker would run with
        # video_name "None" and os.path.join below would raise a TypeError.
        return "❌ State lost. Please start from Step 1.", None, state, gr.update(interactive=False)

    # NOTE(review): this writes to the module-level base_conf that every session
    # shares — confirm that concurrent tracking runs cannot overlap.
    base_conf.video_name = f'{session_id}'
    tracker = Tracker(base_conf, flame_model, diff_renderer)
    tracker.run()

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)
    if image is None:
        # The tracker returned but left no frame: do not report success.
        return "❌ Tracking finished but produced no output frame.", None, state, gr.update(interactive=True)

    return "✅ Pipeline complete!", image, state, gr.update(interactive=False)
# Build Gradio UI
# One input image, a status box, four preview images and four step buttons; each
# button runs one pipeline stage and hands the per-session `state` dict on.
# NOTE(review): the nested `with` blocks below carry no indentation in this copy
# of the file, so the exact layout nesting cannot be read from here — restore it
# from the original before running.
demo = gr.Blocks()
with demo:
gr.Markdown("## Image Processing Pipeline")
gr.Markdown("Upload an image, then click the buttons in order. Uploading a new image will reset the process.")
with gr.Row():
with gr.Column():
# type="numpy": preprocess_image passes this value to Image.fromarray.
image_in = gr.Image(label="Upload Image", type="numpy", height=512)
status = gr.Textbox(label="Status", lines=2, interactive=False, value="Upload an image to start.")
# Session dict filled by preprocess_image (session_id, base_dir, image_path).
state = gr.State({})
with gr.Column():
with gr.Row():
crop_img = gr.Image(label="Preprocessed", height=256)
normals_img = gr.Image(label="Normals", height=256)
with gr.Row():
uv_img = gr.Image(label="UV Map", height=256)
track_img = gr.Image(label="Tracking", height=256)
with gr.Row():
# Only step 1 starts enabled; each step's handler enables the next button.
preprocess_btn = gr.Button("Step 1: Preprocess", interactive=True)
normals_btn = gr.Button("Step 2: Normals", interactive=False)
uv_map_btn = gr.Button("Step 3: UV Map", interactive=False)
track_btn = gr.Button("Step 4: Track", interactive=False)
# Define component list for reset (order must match the tuple returned by reset_all)
outputs_for_reset = [crop_img, normals_img, uv_img, track_img, status, state, preprocess_btn, normals_btn, uv_map_btn, track_btn]
# Pipeline execution logic: every handler returns (status, image, state, own button, next button),
# except step 4, which has no next button.
preprocess_btn.click(
fn=preprocess_image,
inputs=[image_in, state],
outputs=[status, crop_img, state, preprocess_btn, normals_btn]
)
normals_btn.click(
fn=step2_normals,
inputs=[state],
outputs=[status, normals_img, state, normals_btn, uv_map_btn]
)
uv_map_btn.click(
fn=step3_uv_map,
inputs=[state],
outputs=[status, uv_img, state, uv_map_btn, track_btn]
)
track_btn.click(
fn=step4_track,
inputs=[state],
outputs=[status, track_img, state, track_btn]
)
# Event to reset everything when a new image is uploaded
image_in.upload(fn=reset_all, inputs=None, outputs=outputs_for_reset)
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
# Runs at module import: enable the request queue, then start serving.
# NOTE(review): per the Gradio launch() docs, share=True requests a public share
# link and ssr_mode=False turns off server-side rendering — confirm both are
# intended for this deployment.
demo.queue()
demo.launch(share=True, ssr_mode=False)