# pixel3dmm / app.py
import spaces
import os
# Force Dynamo off BEFORE torch (and anything that imports it, e.g. pytorch3d) is loaded.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
import torch._dynamo
torch._dynamo.disable()
import subprocess
import uuid
import glob
import shutil
import gradio as gr
from PIL import Image
import importlib, site
# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
site.addsitedir(sitedir)
# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()
# Set environment variables
os.environ["PIXEL3DMM_CODE_BASE"] = f"{os.getcwd()}"
os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = f"{os.getcwd()}/proprocess_results"
os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"
def sh(cmd):
    subprocess.check_call(cmd, shell=True)

# Editable installs: the pixel3dmm package itself, plus the bundled facer package.
sh("pip install -e .")
sh("cd src/pixel3dmm/preprocessing/facer && pip install -e .")
# Build the native extension shipped with PIPNet's FaceBoxesV2. Each sh() call runs
# in its own shell, so no trailing `cd` back is needed.
sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh")
# tell Python to re-scan site-packages now that the egg-link exists
import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
from pixel3dmm import env_paths
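# Install a full CUDA toolkit (nvcc) so GPU extensions, e.g. nvdiffrast used by the
# tracking renderer below, can be JIT-compiled at runtime.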
def install_cuda_toolkit():
CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
os.environ["CUDA_HOME"] = "/usr/local/cuda"
os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
os.environ["CUDA_HOME"],
"" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
)
    # Pin the arch list explicitly (9.0 = Hopper); otherwise torch's arch detection
    # can fail with: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
print("==> finished installation")
install_cuda_toolkit()
from omegaconf import OmegaConf
from pixel3dmm.network_inference import normals_n_uvs
from pixel3dmm.run_facer_segmentation import segment
DEVICE = "cuda"
# Lazily populated cache for heavy objects (models, renderer). It is module-level,
# so it survives across requests within the same worker process.
_model_cache = {}
def first_file_from_dir(directory, ext):
files = glob.glob(os.path.join(directory, f"*.{ext}"))
    return sorted(files)[0] if files else None
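# Illustrative usage (hypothetical paths): first_file_from_dir("/tmp/out/mesh", "glb")
# returns the alphabetically first *.glb in that directory, or None if there is none.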
# Utility to select first image from a folder
def first_image_from_dir(directory):
patterns = ["*.jpg", "*.png", "*.jpeg"]
files = []
for p in patterns:
files.extend(glob.glob(os.path.join(directory, p)))
if not files:
return None
return sorted(files)[0]
# Function to reset the UI and state
def reset_all():
    """
    Reset all UI components to their initial state when a new image is uploaded.
    This function is triggered when the user uploads a new image to clear all previous
    results and prepare the interface for a new reconstruction session.
    Returns:
        tuple: None for every image/mesh output, a fresh status message, and a
        re-enabled run button, matching the outputs wired to image_in.upload.
    """
    return (
        None,                         # crop_img
        None,                         # normals_img
        None,                         # uv_img
        None,                         # track_img
        None,                         # mesh_file
        "Time to Generate!",          # status
        gr.update(interactive=True),  # run_btn
    )
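# Each step function below returns (status, image, *gr.update button updates).
# The update objects are ignored (unpacked into `_`) when generate_results_and_mesh
# drives all four steps in sequence.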
# Step 1: Preprocess the input image (Save and Crop)
@spaces.GPU()
def preprocess_image(image_array, session_id):
    if image_array is None:
        return "❌ Please upload an image first.", None, gr.update(interactive=True), gr.update(interactive=True)
base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
os.makedirs(base_dir, exist_ok=True)
img = Image.fromarray(image_array)
    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
img.save(saved_image_path)
import facer
if "face_detector" not in _model_cache:
device = 'cuda'
# This call downloads/loads the RetinaFace Mobilenet weights
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
# This call downloads/loads the FARL parsing model (celeba mask model)
        face_parser = facer.face_parser('farl/celebm/448', device=device)
_model_cache['face_detector'] = face_detector
_model_cache['face_parser'] = face_parser
subprocess.run([
"python", "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path
], check=True, capture_output=True, text=True)
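    # check=True raises CalledProcessError if the preprocessing script fails;
    # capture_output/text keep its stdout/stderr on the exception for debugging.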
segment(f'{session_id}', _model_cache['face_detector'], _model_cache['face_parser'])
crop_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "cropped")
image = first_image_from_dir(crop_dir)
return "✅ Step 1 complete. Ready for Normals.", image, gr.update(interactive=True), gr.update(interactive=True)
# Step 2: Normals inference → normals image
@spaces.GPU()
def step2_normals(session_id):
from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
base_conf = OmegaConf.load("configs/base.yaml")
if "normals_model" not in _model_cache:
print("----caching normal models----")
model = p3dmm_system.load_from_checkpoint(f"{env_paths.CKPT_N_PRED}", strict=False)
model = model.eval().to(DEVICE)
_model_cache["normals_model"] = model
base_conf.video_name = f'{session_id}'
normals_n_uvs(base_conf, _model_cache["normals_model"])
normals_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "normals")
image = first_image_from_dir(normals_dir)
return "✅ Step 2 complete. Ready for UV Map.", image, gr.update(interactive=True), gr.update(interactive=True)
# Step 3: UV map inference → uv map image
@spaces.GPU()
def step3_uv_map(session_id):
from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
base_conf = OmegaConf.load("configs/base.yaml")
if "uv_model" not in _model_cache:
model = p3dmm_system.load_from_checkpoint(f"{env_paths.CKPT_UV_PRED}", strict=False)
model = model.eval().to(DEVICE)
_model_cache["uv_model"] = model
base_conf.video_name = f'{session_id}'
base_conf.model.prediction_type = "uv_map"
normals_n_uvs(base_conf, _model_cache["uv_model"])
uv_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "uv_map")
image = first_image_from_dir(uv_dir)
return "✅ Step 3 complete. Ready for Tracking.", image, gr.update(interactive=True), gr.update(interactive=True)
# Step 4: Tracking → final tracking image
@spaces.GPU()
def step4_track(session_id):
    from pixel3dmm.tracking.flame.FLAME import FLAME
    from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
    from pixel3dmm.tracking.tracker import Tracker
tracking_conf = OmegaConf.load("configs/tracking.yaml")
# Lazy init + caching of FLAME model on GPU
if "flame_model" not in _model_cache:
flame = FLAME(tracking_conf) # CPU instantiation
flame = flame.to(DEVICE) # CUDA init happens here
_model_cache["flame_model"] = flame
        _mesh_file = env_paths.head_template
_model_cache["diff_renderer"] = NVDRenderer(
image_size=tracking_conf.size,
obj_filename=_mesh_file,
no_sh=False,
white_bg=True
).to(DEVICE)
flame_model = _model_cache["flame_model"]
diff_renderer = _model_cache["diff_renderer"]
tracking_conf.video_name = f'{session_id}'
tracker = Tracker(tracking_conf, flame_model, diff_renderer)
tracker.run()
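    # Tracker.run() writes per-frame overlays to <PIXEL3DMM_TRACKING_OUTPUT>/<session_id>/frames
    # and the exported mesh to .../<session_id>/mesh (read back in generate_results_and_mesh).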
tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
image = first_image_from_dir(tracking_dir)
return "✅ Pipeline complete!", image, gr.update(interactive=True)
# New: run all steps sequentially
@spaces.GPU(duration=120)
def generate_results_and_mesh(image, session_id=None):
"""
Process an input image through a 3D reconstruction pipeline and return the intermediate outputs and mesh file.
This function is triggered when the user clicks the "Reconstruct Face" button or selects an example.
    It runs a multi-step workflow to go from a raw input image to a reconstructed 3D mesh:
    1. **Preprocessing**: crops and masks the image to isolate the face.
    2. **Normals Estimation**: computes surface normal maps.
    3. **UV Mapping**: generates UV coordinate maps for texturing.
    4. **Tracking**: performs final alignment/tracking to prepare for mesh export.
    5. **Mesh Discovery**: locates the resulting `.glb` file in the tracking output directory.
Args:
image (PIL.Image.Image or ndarray): Input image to reconstruct.
        session_id (str, optional): Unique identifier for this session's output
            directories; a fresh random id is generated when omitted.
Returns:
tuple:
            - final_status (str): Newline-separated status messages from each pipeline step.
            - crop_img (Image or None): Cropped and preprocessed image.
            - normals_img (Image or None): Estimated surface normals visualization.
            - uv_img (Image or None): UV-map visualization.
            - track_img (Image or None): Tracking/registration result.
            - mesh_file (str or None): Path to the generated 3D mesh (`.glb`), if found.
"""
if session_id is None:
session_id = uuid.uuid4().hex
# Step 1
status1, crop_img, _, _ = preprocess_image(image, session_id)
if "❌" in status1:
return status1, None, None, None, None, None
# Step 2
status2, normals_img, _, _ = step2_normals(session_id)
# Step 3
status3, uv_img, _, _ = step3_uv_map(session_id)
# Step 4
status4, track_img, _ = step4_track(session_id)
    # Locate the exported mesh (.glb) for the Model3D viewer
    mesh_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "mesh")
    mesh_file = first_file_from_dir(mesh_dir, "glb")
final_status = "\n".join([status1, status2, status3, status4])
return final_status, crop_img, normals_img, uv_img, track_img, mesh_file
# Cleanup on unload
def cleanup(request: gr.Request):
"""
Clean up session-specific directories and temporary files when the user session ends.
This function is triggered when the Gradio demo is unloaded (e.g., when the user
closes the browser tab or navigates away). It removes all temporary files and
directories created during the user's session to free up storage space.
Args:
request (gr.Request): Gradio request object containing session information
"""
sid = request.session_hash
if sid:
d1 = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], sid)
d2 = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], sid)
shutil.rmtree(d1, ignore_errors=True)
shutil.rmtree(d2, ignore_errors=True)
def start_session(request: gr.Request):
"""
Initialize a new user session and return the session identifier.
This function is triggered when the Gradio demo loads and creates a unique
session hash that will be used to organize outputs and temporary files
for this specific user session.
Args:
request (gr.Request): Gradio request object containing session information
Returns:
str: Unique session hash identifier
"""
return request.session_hash
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
# Build Gradio UI
with gr.Blocks(css=css) as demo:
session_state = gr.State()
demo.load(start_session, outputs=[session_state])
gr.HTML(
"""
<div style="text-align: center;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>Pixel3dmm [Image Mode]</strong> – Versatile Screen-Space Priors for Single-Image 3D Face Reconstruction.
</p>
<a href="https://github.com/SimonGiebenhain/pixel3dmm" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
</a>
</div>
"""
)
with gr.Column(elem_id="col-container"):
with gr.Row():
with gr.Column():
image_in = gr.Image(label="Upload Image", type="numpy", height=512)
run_btn = gr.Button("Reconstruct Face", variant="primary")
with gr.Row():
crop_img = gr.Image(label="Preprocessed", height=128)
normals_img = gr.Image(label="Normals", height=128)
uv_img = gr.Image(label="UV Map", height=128)
track_img = gr.Image(label="Tracking", height=128)
with gr.Column():
mesh_file = gr.Model3D(label="3D Model Preview", height=512)
examples = gr.Examples(
examples=[
["example_images/jennifer_lawrence.png"],
["example_images/brendan_fraser.png"],
["example_images/jim_carrey.png"],
],
inputs=[image_in],
outputs=[gr.State(), crop_img, normals_img, uv_img, track_img, mesh_file],
fn=generate_results_and_mesh,
cache_examples=True
)
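    # The first return value (status string) lands in a throwaway gr.State so cached
    # example runs don't overwrite the status box; cache_examples=True makes Gradio
    # execute the full pipeline once per example at startup and reuse the results.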
        status = gr.Textbox(label="Status", lines=5, interactive=False, value="Upload an image to start.")
run_btn.click(
fn=generate_results_and_mesh,
inputs=[image_in, session_state],
outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file]
)
image_in.upload(fn=reset_all, inputs=None, outputs=[crop_img, normals_img, uv_img, track_img, mesh_file, status, run_btn])
demo.unload(cleanup)
demo.queue()
demo.launch(share=True, mcp_server=True)