# Hugging Face Space (page status banner: "Running on Zero").
# Force Dynamo off at import-time of torch, pytorch3d, etc.
# The variable therefore has to be in the environment *before* the first
# torch import below; setting it afterwards is too late for this process.
import os

os.environ["TORCHDYNAMO_DISABLE"] = "1"

# `spaces` stays ahead of the torch import, as in the original ordering.
import spaces
import torch._dynamo

# NOTE(review): called without a function argument this appears to only build
# an unused decorator/context object; the environment variable above is what
# is relied on to keep Dynamo off -- confirm against the installed torch.
torch._dynamo.disable()

import glob
import importlib
import shutil
import site
import subprocess
import sys
import tempfile
import time
import uuid

import gradio as gr
from PIL import Image
# Walk every site-packages directory again so .pth/.egg-link files are processed
for _site_dir in site.getsitepackages():
    site.addsitedir(_site_dir)

# Drop the import finder caches so importlib will pick up new modules
importlib.invalidate_caches()

# Environment variables rooted at the current working directory.
# "proprocess_results" is a runtime path and is intentionally left as spelled.
_cwd = os.getcwd()
os.environ.update(
    {
        "PIXEL3DMM_CODE_BASE": f"{_cwd}",
        "PIXEL3DMM_PREPROCESSED_DATA": f"{_cwd}/proprocess_results",
        "PIXEL3DMM_TRACKING_OUTPUT": f"{_cwd}/tracking_results",
    }
)
def sh(cmd):
    """Run ``cmd`` through the shell; a non-zero exit raises CalledProcessError."""
    subprocess.run(cmd, shell=True, check=True)
# Editable installs of this repo and of the bundled facer package, followed by
# the FaceBoxesV2 build script. Each command line runs in its own shell.
for _setup_cmd in (
    "pip install -e .",
    "cd src/pixel3dmm/preprocessing/facer && pip install -e . && cd ../../../..",
    "cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh && cd ../../../../../..",
):
    sh(_setup_cmd)

# tell Python to re-scan site-packages now that the egg-link exists
import importlib
import site

site.addsitedir(site.getsitepackages()[0])
importlib.invalidate_caches()

# Imported only after the installs above have run
from pixel3dmm import env_paths
def install_cuda_toolkit():
    """
    Download and run the CUDA 12.1 toolkit installer, then export CUDA paths.

    The download, chmod and installer commands run in order. The exit status
    of each one is checked: on the first failure a message is printed and the
    remaining install commands are skipped (so a failed download is never
    executed). The CUDA-related environment variables are exported in every
    case, exactly as before; the "finished installation" line is printed only
    when all three commands succeeded.
    """
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    CUDA_TOOLKIT_FILE = f"/tmp/{os.path.basename(CUDA_TOOLKIT_URL)}"

    install_steps = (
        ["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE],
        ["chmod", "+x", CUDA_TOOLKIT_FILE],
        [CUDA_TOOLKIT_FILE, "--silent", "--toolkit"],
    )
    installed = True
    for step in install_steps:
        returncode = subprocess.call(step)
        if returncode != 0:
            # Report and stop instead of silently carrying on with a broken file.
            print(f"==> CUDA toolkit step failed (exit {returncode}): {' '.join(step)}")
            installed = False
            break

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    os.environ["PATH"] = f"{cuda_home}/bin:{os.environ['PATH']}"
    os.environ["LD_LIBRARY_PATH"] = f"{cuda_home}/lib:{os.environ.get('LD_LIBRARY_PATH', '')}"
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"

    if installed:
        print("==> finished installation")
# Run the toolkit install (and its environment exports) before the inference
# modules below are imported.
install_cuda_toolkit()
from omegaconf import OmegaConf
from pixel3dmm.network_inference import normals_n_uvs
from pixel3dmm.run_facer_segmentation import segment
# Device string passed to ``.to(...)`` by the model-loading steps
DEVICE = "cuda"
# 2. Empty cache for our heavy objects
# Filled lazily by the pipeline steps and shared for the life of the process.
_model_cache = {}
def first_file_from_dir(directory, ext):
    """Return the alphabetically first ``*.<ext>`` path in ``directory``, or None if there is none."""
    pattern = os.path.join(directory, f"*.{ext}")
    matches = glob.glob(pattern)
    if not matches:
        return None
    # The smallest path is the same element a full sort would put first.
    return min(matches)
# Utility to select first image from a folder | |
def first_image_from_dir(directory):
    """Return the alphabetically first .jpg/.png/.jpeg path in ``directory``, or None if there is none."""
    candidates = [
        path
        for pattern in ("*.jpg", "*.png", "*.jpeg")
        for path in glob.glob(os.path.join(directory, pattern))
    ]
    return min(candidates) if candidates else None
# Function to reset the UI and state | |
def reset_all():
    """
    Reset the result widgets to their initial state when a new image is uploaded.

    The returned values line up one-to-one, in order, with the outputs wired to
    the image upload event: crop_img, normals_img, uv_img, track_img, mesh_file,
    status, run_btn.

    Returns:
        tuple: five ``None`` values that clear the four preview images and the
        3D model preview, the initial status text, and an update that
        re-enables the run button.
    """
    return (
        None,  # crop_img
        None,  # normals_img
        None,  # uv_img
        None,  # track_img
        None,  # mesh_file
        "Time to Generate!",  # status
        gr.update(interactive=True),  # run_btn
    )
# Step 1: Preprocess the input image (Save and Crop) | |
def preprocess_image(image_array, session_id):
    """
    Step 1: save the uploaded image, then run preprocessing and face segmentation.

    The image is written to ``<PIXEL3DMM_PREPROCESSED_DATA>/<session_id>/<session_id>.png``,
    ``scripts/run_preprocessing.py`` is run on it in a subprocess, and
    ``segment`` is called with the cached face detector and face parser.

    Args:
        image_array: Image as an array accepted by ``PIL.Image.fromarray``, or None.
        session_id (str): Name of the per-session output directory.

    Returns:
        tuple: ``(status, image_path_or_None, update, update)``. The same
        4-tuple shape is returned when no image was supplied, so callers can
        always unpack four values.

    Raises:
        subprocess.CalledProcessError: if the preprocessing script exits non-zero.
    """
    if image_array is None:
        return (
            "❌ Please upload an image first.",
            None,
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)
    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    Image.fromarray(image_array).save(saved_image_path)

    # Imported lazily, as in the original, so it only loads when step 1 runs.
    import facer

    if "face_detector" not in _model_cache:
        # This call downloads/loads the RetinaFace Mobilenet weights
        face_detector = facer.face_detector("retinaface/mobilenet", device=DEVICE)
        # This call downloads/loads the FARL parsing model (celeba mask model)
        face_parser = facer.face_parser("farl/celebm/448", device=DEVICE)
        # Store both only once both exist, so the cache is never half-filled.
        _model_cache["face_detector"] = face_detector
        _model_cache["face_parser"] = face_parser

    subprocess.run(
        ["python", "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path],
        check=True,
        capture_output=True,
        text=True,
    )
    segment(f"{session_id}", _model_cache["face_detector"], _model_cache["face_parser"])

    crop_dir = os.path.join(base_dir, "cropped")
    image = first_image_from_dir(crop_dir)
    return "✅ Step 1 complete. Ready for Normals.", image, gr.update(interactive=True), gr.update(interactive=True)
# Step 2: Normals inference → normals image | |
def step2_normals(session_id):
    """
    Step 2: run normals inference for ``session_id``.

    Loads ``configs/base.yaml``, loads and caches the normals checkpoint on
    first use, runs ``normals_n_uvs`` and picks the first image found under
    ``<PIXEL3DMM_PREPROCESSED_DATA>/<session_id>/p3dmm/normals``.

    Returns:
        tuple: ``(status, image_path_or_None, update, update)``.
    """
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system

    conf = OmegaConf.load("configs/base.yaml")

    try:
        normals_model = _model_cache["normals_model"]
    except KeyError:
        # First call in this process: load the checkpoint and keep it around.
        print("----caching normal models----")
        loaded = p3dmm_system.load_from_checkpoint(str(env_paths.CKPT_N_PRED), strict=False)
        normals_model = loaded.eval().to(DEVICE)
        _model_cache["normals_model"] = normals_model

    conf.video_name = str(session_id)
    normals_n_uvs(conf, normals_model)

    data_root = os.environ["PIXEL3DMM_PREPROCESSED_DATA"]
    first_image = first_image_from_dir(os.path.join(data_root, session_id, "p3dmm", "normals"))
    return (
        "✅ Step 2 complete. Ready for UV Map.",
        first_image,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
# Step 3: UV map inference → uv map image | |
def step3_uv_map(session_id):
    """
    Step 3: run UV-map inference for ``session_id``.

    Loads ``configs/base.yaml``, loads and caches the UV checkpoint on first
    use, switches the config's prediction type to ``"uv_map"``, runs
    ``normals_n_uvs`` and picks the first image found under
    ``<PIXEL3DMM_PREPROCESSED_DATA>/<session_id>/p3dmm/uv_map``.

    Returns:
        tuple: ``(status, image_path_or_None, update, update)``.
    """
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system

    conf = OmegaConf.load("configs/base.yaml")

    try:
        uv_model = _model_cache["uv_model"]
    except KeyError:
        # First call in this process: load the checkpoint and keep it around.
        loaded = p3dmm_system.load_from_checkpoint(str(env_paths.CKPT_UV_PRED), strict=False)
        uv_model = loaded.eval().to(DEVICE)
        _model_cache["uv_model"] = uv_model

    conf.video_name = str(session_id)
    conf.model.prediction_type = "uv_map"
    normals_n_uvs(conf, uv_model)

    data_root = os.environ["PIXEL3DMM_PREPROCESSED_DATA"]
    first_image = first_image_from_dir(os.path.join(data_root, session_id, "p3dmm", "uv_map"))
    return (
        "✅ Step 3 complete. Ready for Tracking.",
        first_image,
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
# Step 4: Tracking → final tracking image | |
def step4_track(session_id):
    """
    Step 4: run the tracker for ``session_id``.

    Loads ``configs/tracking.yaml``, builds and caches the FLAME model and the
    renderer on first use, runs ``Tracker.run()`` and picks the first image
    found under ``<PIXEL3DMM_TRACKING_OUTPUT>/<session_id>/frames``.

    Returns:
        tuple: ``(status, image_path_or_None, update)``.
    """
    from pixel3dmm.tracking.flame.FLAME import FLAME
    from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
    from pixel3dmm.tracking.tracker import Tracker

    tracking_conf = OmegaConf.load("configs/tracking.yaml")

    # Lazy init + caching of the FLAME model and renderer on the GPU.
    # Both keys are checked and both objects are built before either is
    # stored, so a failure while building the renderer cannot leave a
    # half-filled cache that breaks the lookups below on the next call.
    if "flame_model" not in _model_cache or "diff_renderer" not in _model_cache:
        flame = FLAME(tracking_conf)  # CPU instantiation
        flame = flame.to(DEVICE)  # CUDA init happens here
        renderer = NVDRenderer(
            image_size=tracking_conf.size,
            obj_filename=env_paths.head_template,
            no_sh=False,
            white_bg=True,
        ).to(DEVICE)
        _model_cache["flame_model"] = flame
        _model_cache["diff_renderer"] = renderer

    flame_model = _model_cache["flame_model"]
    diff_renderer = _model_cache["diff_renderer"]

    tracking_conf.video_name = f"{session_id}"
    tracker = Tracker(tracking_conf, flame_model, diff_renderer)
    tracker.run()

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)
    return "✅ Pipeline complete!", image, gr.update(interactive=True)
# New: run all steps sequentially | |
def generate_results_and_mesh(image, session_id=None):
    """
    Process an input image through a 3D reconstruction pipeline and return the intermediate outputs and mesh file.

    This function is triggered when the user clicks the "Reconstruct Face" button or selects an example.
    It runs a multi-step workflow to go from a raw input image to a reconstructed 3D mesh:
    1. **Preprocessing**: saves the image, then runs the preprocessing script and face segmentation.
    2. **Normals Estimation**: computes surface normal maps.
    3. **UV Mapping**: generates UV coordinate maps.
    4. **Tracking**: runs the tracker.
    5. **Mesh Discovery**: locates the first `.glb` file in the tracking output's `mesh` directory.

    Args:
        image (ndarray or None): Input image to reconstruct.
        session_id (str): Unique identifier for this session's output directories.
            A random one is generated when omitted.

    Returns:
        tuple:
            - final_status (str): Newline-separated status messages from each pipeline step,
              or a single error message when no image was supplied.
            - crop_img (str or None): Path to the first cropped image.
            - normals_img (str or None): Path to the first normals image.
            - uv_img (str or None): Path to the first UV-map image.
            - track_img (str or None): Path to the first tracking frame.
            - mesh_file (str or None): Path to the generated 3D mesh (`.glb`), if found.
    """
    if image is None:
        # Answer here rather than relying on the shape of step 1's error return.
        return "❌ Please upload an image first.", None, None, None, None, None

    if session_id is None:
        session_id = uuid.uuid4().hex

    # Step 1
    status1, crop_img, _, _ = preprocess_image(image, session_id)
    if "❌" in status1:
        return status1, None, None, None, None, None
    # Step 2
    status2, normals_img, _, _ = step2_normals(session_id)
    # Step 3
    status3, uv_img, _, _ = step3_uv_map(session_id)
    # Step 4
    status4, track_img, _ = step4_track(session_id)

    # Locate mesh (.glb)
    mesh_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "mesh")
    mesh_file = first_file_from_dir(mesh_dir, "glb")

    final_status = "\n".join([status1, status2, status3, status4])
    return final_status, crop_img, normals_img, uv_img, track_img, mesh_file
# Cleanup on unload | |
def cleanup(request: gr.Request):
    """
    Remove the per-session output directories.

    Registered as the demo's unload handler. When the request carries a
    session hash, the directory of that name is deleted under both the
    preprocessed-data root and the tracking-output root; deletion errors
    are ignored.

    Args:
        request (gr.Request): Gradio request object containing session information
    """
    session = request.session_hash
    if not session:
        return

    # Resolve both paths before deleting anything.
    targets = [
        os.path.join(os.environ[root_var], session)
        for root_var in ("PIXEL3DMM_PREPROCESSED_DATA", "PIXEL3DMM_TRACKING_OUTPUT")
    ]
    for target in targets:
        shutil.rmtree(target, ignore_errors=True)
def start_session(request: gr.Request):
    """
    Return the session hash of the incoming request.

    Registered as the demo's load handler; the returned value is stored in
    the session state and later used to name this session's output
    directories.

    Args:
        request (gr.Request): Gradio request object containing session information

    Returns:
        str: Session hash taken from the request
    """
    session_hash = request.session_hash
    return session_hash
# Page-level CSS: centres the main column and caps its width.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""

# Title banner shown at the top of the page.
_HEADER_HTML = """
<div style="text-align: center;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>Pixel3dmm [Image Mode]</strong> – Versatile Screen-Space Priors for Single-Image 3D Face Reconstruction.
</p>
<a href="https://github.com/SimonGiebenhain/pixel3dmm" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
</a>
</div>
"""

# Example inputs offered below the 3D preview (one image path per row).
_EXAMPLE_ROWS = [
    ["example_images/jennifer_lawrence.png"],
    ["example_images/brendan_fraser.png"],
    ["example_images/jim_carrey.png"],
]

# Build Gradio UI
# NOTE(review): SOURCE carried no indentation, so the nesting of the layout
# containers below is reconstructed -- confirm against the deployed app.
with gr.Blocks(css=css) as demo:
    # Per-browser-session identifier, filled in when the page loads.
    session_state = gr.State()
    demo.load(fn=start_session, outputs=[session_state])

    gr.HTML(_HEADER_HTML)

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Upload Image", type="numpy", height=512)
                run_btn = gr.Button("Reconstruct Face", variant="primary")
                with gr.Row():
                    crop_img = gr.Image(label="Preprocessed", height=128)
                    normals_img = gr.Image(label="Normals", height=128)
                    uv_img = gr.Image(label="UV Map", height=128)
                    track_img = gr.Image(label="Tracking", height=128)
            with gr.Column():
                mesh_file = gr.Model3D(label="3D Model Preview", height=512)
                # The status textbox does not exist yet at this point, so the
                # status string of an example run goes into a throwaway State.
                examples = gr.Examples(
                    examples=_EXAMPLE_ROWS,
                    inputs=[image_in],
                    outputs=[gr.State(), crop_img, normals_img, uv_img, track_img, mesh_file],
                    fn=generate_results_and_mesh,
                    cache_examples=True,
                )
                status = gr.Textbox(
                    label="Status",
                    lines=5,
                    interactive=True,
                    value="Upload an image to start.",
                )

    run_btn.click(
        fn=generate_results_and_mesh,
        inputs=[image_in, session_state],
        outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file],
    )
    # reset_all has to return one value per entry of this outputs list (7 here).
    image_in.upload(
        fn=reset_all,
        inputs=None,
        outputs=[crop_img, normals_img, uv_img, track_img, mesh_file, status, run_btn],
    )
    demo.unload(cleanup)

demo.queue()
demo.launch(share=True, mcp_server=True)