pixel3dmm / app.py
alexnasa's picture
Update app.py
a73bdee verified
raw
history blame
13.6 kB
# Hugging Face ZeroGPU helpers; provides the @spaces.GPU decorator used on the
# pipeline step functions below.
import spaces
import torch._dynamo
# NOTE(review): torch._dynamo.disable() is normally used as a decorator or
# context manager; called bare like this its return value is discarded, so it
# may not switch Dynamo off globally. Confirm against the installed PyTorch
# version (a global switch may live in torch._dynamo.config).
torch._dynamo.disable()
import os
# Request TorchDynamo be disabled through the environment. This runs *after*
# `import torch._dynamo` above; child processes started later (e.g.
# scripts/run_preprocessing.py via subprocess.run) inherit the variable.
# NOTE(review): confirm the installed torch honours TORCHDYNAMO_DISABLE, and
# whether the already-imported torch in this process still reads it.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
import subprocess
import tempfile
import uuid
import glob
import shutil
import time
import gradio as gr
import sys
from PIL import Image
# importlib/site drive the site-packages re-scan that follows; sys is imported
# a second time here, which is harmless.
import importlib, site, sys
# Make .pth / .egg-link entries that appeared after interpreter start-up
# importable: re-process every site-packages directory, then drop the import
# system's cached directory listings so new modules are found.
for _site_dir in site.getsitepackages():
    site.addsitedir(_site_dir)
importlib.invalidate_caches()

# Locations the pixel3dmm package reads from the environment: the code root,
# the preprocessing output root and the tracking output root, all anchored at
# the current working directory.
_cwd = os.getcwd()
os.environ["PIXEL3DMM_CODE_BASE"] = f"{_cwd}"
os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = f"{_cwd}/proprocess_results"
os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{_cwd}/tracking_results"
def sh(cmd):
    """Run the shell command string *cmd*, waiting for it to finish.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    subprocess.run(cmd, shell=True, check=True)
# Install pixel3dmm itself in editable mode. Use the *running* interpreter's
# pip (`python -m pip`) so the package lands in this interpreter's
# site-packages rather than wherever the first `pip` on PATH points.
import shlex
_PIP = f"{shlex.quote(sys.executable)} -m pip"
sh(f"{_PIP} install -e .")
# Editable installs add .pth / .egg-link files to site-packages. Re-scan every
# site-packages directory (not only the first) so the new entry is picked up
# without restarting, then clear the import system's stale caches.
import importlib
import site
for _site_dir in site.getsitepackages():
    site.addsitedir(_site_dir)
importlib.invalidate_caches()
from pixel3dmm import env_paths
# Vendored face-parsing library and the FaceBoxesV2 native helpers. Each sh()
# call runs in its own shell, so no trailing `cd` back is needed.
sh(f"cd src/pixel3dmm/preprocessing/facer && {_PIP} install -e .")
sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh")
def _prepend_env_path(name, *dirs):
    """Prepend *dirs* to the os.pathsep-separated environment variable *name*.

    An unset or empty existing value is dropped rather than kept as an empty
    trailing element (an empty element in PATH / LD_LIBRARY_PATH means
    "current directory").
    """
    existing = os.environ.get(name, "")
    parts = [*dirs, existing] if existing else list(dirs)
    os.environ[name] = os.pathsep.join(parts)


def install_cuda_toolkit():
    """Download and silently install the CUDA 12.1 toolkit, then export its environment.

    Side effects:
      * downloads the NVIDIA runfile to /tmp, runs it with ``--silent --toolkit``
        and removes it afterwards;
      * sets CUDA_HOME, prepends the toolkit to PATH and LD_LIBRARY_PATH, and
        sets TORCH_CUDA_ARCH_LIST in ``os.environ`` (inherited by child processes).

    Failures are best-effort, as before: a failing step is reported and the
    remaining install steps are skipped, but no exception is raised.
    """
    cuda_toolkit_url = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    cuda_toolkit_file = "/tmp/%s" % os.path.basename(cuda_toolkit_url)
    steps = (
        ["wget", "-q", cuda_toolkit_url, "-O", cuda_toolkit_file],
        ["chmod", "+x", cuda_toolkit_file],
        [cuda_toolkit_file, "--silent", "--toolkit"],
    )
    installed = True
    for cmd in steps:
        try:
            returncode = subprocess.call(cmd)
        except OSError as err:  # e.g. wget missing, or a broken/empty runfile
            returncode = f"error ({err})"
        if returncode != 0:
            # Running the installer after a failed download would only fail
            # more confusingly, so stop at the first failing step.
            print(f"==> warning: {os.path.basename(cmd[0])} failed with status {returncode}; "
                  "skipping remaining CUDA install steps")
            installed = False
            break
    # The local installer is large; don't leave it on disk once used.
    try:
        os.remove(cuda_toolkit_file)
    except OSError:
        pass

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    _prepend_env_path("PATH", f"{cuda_home}/bin")
    # `lib` kept first for backward compatibility; `lib64` is where NVIDIA's
    # Linux install guide places the toolkit libraries.
    _prepend_env_path("LD_LIBRARY_PATH", f"{cuda_home}/lib", f"{cuda_home}/lib64")
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
    if installed:
        print("==> finished installation")
    else:
        print("==> CUDA toolkit installation incomplete")
# Install the CUDA toolkit and export CUDA_HOME / PATH / LD_LIBRARY_PATH before
# importing the pixel3dmm inference modules below.
# NOTE(review): presumably those imports (or code they pull in) need the
# toolkit, e.g. to build CUDA extensions — confirm; keep this call first.
install_cuda_toolkit()
from omegaconf import OmegaConf
from pixel3dmm.network_inference import normals_n_uvs
from pixel3dmm.run_facer_segmentation import segment
# Target device for the networks loaded by the step functions below.
DEVICE = "cuda"
# Process-wide cache of heavy objects (face detector/parser, normals and UV
# networks, FLAME model, renderer), filled lazily by the step functions so
# later calls in the same process reuse them instead of reloading.
_model_cache = {}
def first_file_from_dir(directory, ext):
    """Return the alphabetically first ``*.<ext>`` path in *directory*.

    Returns None when nothing matches, including when *directory* does not exist.
    """
    pattern = os.path.join(directory, f"*.{ext}")
    return next(iter(sorted(glob.glob(pattern))), None)
def first_image_from_dir(directory):
    """Return the alphabetically first .jpg/.png/.jpeg path in *directory*, or None.

    Only lowercase extensions are matched (glob is case-sensitive on POSIX).
    """
    candidates = [
        path
        for pattern in ("*.jpg", "*.png", "*.jpeg")
        for path in glob.glob(os.path.join(directory, pattern))
    ]
    return min(candidates) if candidates else None
# Function to reset the UI and state
def reset_all():
    """Clear every result view and re-enable the run button after a new upload.

    The tuple order matches the outputs wired to ``image_in.upload``:
    ``[crop_img, normals_img, uv_img, track_img, mesh_file, status, run_btn]``
    (Gradio requires one returned value per listed output, in order).
    """
    return (
        None,  # crop_img
        None,  # normals_img
        None,  # uv_img
        None,  # track_img
        None,  # mesh_file
        "Time to Generate!",  # status
        gr.update(interactive=True),  # run_btn
    )
# Step 1: Preprocess the input image (Save and Crop)
@spaces.GPU()
def preprocess_image(image_array, session_id):
    """Save the upload, run the preprocessing script and face segmentation.

    Args:
        image_array: The uploaded image as a numpy array (``gr.Image(type="numpy")``)
            or a ``PIL.Image.Image``; ``None`` when nothing was uploaded.
        session_id: Folder name under ``$PIXEL3DMM_PREPROCESSED_DATA`` for this run.

    Returns:
        Always a 4-tuple ``(status, crop_image_path_or_None, gr.update, gr.update)``
        so callers can unpack every path uniformly. Failure statuses start
        with "❌".
    """
    def _result(status, image=None):
        # Fresh gr.update objects per output; the two buttons stay enabled.
        return status, image, gr.update(interactive=True), gr.update(interactive=True)

    if image_array is None:
        return _result("❌ Please upload an image first.")

    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)
    img = image_array if isinstance(image_array, Image.Image) else Image.fromarray(image_array)
    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    img.save(saved_image_path)

    import facer
    if "face_detector" not in _model_cache:
        # This call downloads/loads the RetinaFace Mobilenet weights
        face_detector = facer.face_detector("retinaface/mobilenet", device=DEVICE)
        # This call downloads/loads the FARL parsing model (celeba mask model)
        face_parser = facer.face_parser("farl/celebm/448", device=DEVICE)
        _model_cache["face_detector"] = face_detector
        _model_cache["face_parser"] = face_parser
        _model_cache["facer_module"] = facer.hwc2bchw

    try:
        # sys.executable: run the script with this interpreter (and thus the
        # packages installed above), not whatever `python` is first on PATH.
        subprocess.run(
            [sys.executable, "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path],
            check=True, capture_output=True, text=True,
        )
    except subprocess.CalledProcessError as err:
        # Output is captured, so without this the reason would be lost.
        tail = (err.stderr or "").strip().splitlines()[-10:]
        return _result("❌ Preprocessing failed:\n" + "\n".join(tail))

    segment(f"{session_id}", _model_cache["face_detector"], _model_cache["face_parser"], _model_cache["facer_module"])

    crop_dir = os.path.join(base_dir, "cropped")
    image = first_image_from_dir(crop_dir)
    if image is None:
        return _result("❌ Preprocessing produced no cropped face image.")
    return _result("✅ Step 1 complete. Ready for Normals.", image)
# Step 2: Normals inference → normals image
@spaces.GPU()
def step2_normals(session_id):
    """Predict surface normals for an already-preprocessed session.

    Loads the normals checkpoint (``env_paths.CKPT_N_PRED``) once into
    ``_model_cache["normals_model"]``. Then runs ``normals_n_uvs`` with
    ``configs/base.yaml`` for *session_id* and picks the first image under
    ``<preprocessed>/<session_id>/p3dmm/normals``.

    Returns:
        ``(status, image_path_or_None, gr.update, gr.update)``
    """
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system

    config = OmegaConf.load("configs/base.yaml")
    if "normals_model" not in _model_cache:
        checkpoint = f"{env_paths.CKPT_N_PRED}"
        _model_cache["normals_model"] = (
            p3dmm_system.load_from_checkpoint(checkpoint, strict=False).eval().to(DEVICE)
        )

    config.video_name = f"{session_id}"
    normals_n_uvs(config, _model_cache["normals_model"])

    out_dir = os.path.join(
        os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "normals"
    )
    return (
        "✅ Step 2 complete. Ready for UV Map.",
        first_image_from_dir(out_dir),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
# Step 3: UV map inference → uv map image
@spaces.GPU()
def step3_uv_map(session_id):
    """Predict the UV-coordinate map for an already-preprocessed session.

    Loads the UV checkpoint (``env_paths.CKPT_UV_PRED``) once into
    ``_model_cache["uv_model"]``. Then runs ``normals_n_uvs`` with
    ``configs/base.yaml`` switched to ``prediction_type = "uv_map"`` and picks
    the first image under ``<preprocessed>/<session_id>/p3dmm/uv_map``.

    Returns:
        ``(status, image_path_or_None, gr.update, gr.update)``
    """
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system

    config = OmegaConf.load("configs/base.yaml")
    if "uv_model" not in _model_cache:
        checkpoint = f"{env_paths.CKPT_UV_PRED}"
        _model_cache["uv_model"] = (
            p3dmm_system.load_from_checkpoint(checkpoint, strict=False).eval().to(DEVICE)
        )

    config.video_name = f"{session_id}"
    config.model.prediction_type = "uv_map"
    normals_n_uvs(config, _model_cache["uv_model"])

    out_dir = os.path.join(
        os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "uv_map"
    )
    return (
        "✅ Step 3 complete. Ready for Tracking.",
        first_image_from_dir(out_dir),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
# Step 4: Tracking → final tracking image
@spaces.GPU()
def step4_track(session_id):
    """Run FLAME tracking for *session_id* and return the first rendered frame.

    On first use, builds the FLAME model and the nvdiffrast renderer (with
    ``env_paths.head_template`` as the mesh) on ``DEVICE`` and caches both.
    Then runs ``Tracker`` with ``configs/tracking.yaml`` for this session.

    Returns:
        ``(status, image_path_or_None, gr.update)`` where the image is the
        first .jpg/.png/.jpeg under ``$PIXEL3DMM_TRACKING_OUTPUT/<session_id>/frames``.
    """
    from pixel3dmm.tracking.flame.FLAME import FLAME
    from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
    from pixel3dmm.tracking.tracker import Tracker

    tracking_conf = OmegaConf.load("configs/tracking.yaml")

    # Lazy init + caching of FLAME model and renderer on GPU. Both are built
    # before either is stored, so a failure while building the renderer cannot
    # leave "flame_model" cached without "diff_renderer" (which would make
    # every later call fail with KeyError).
    if "flame_model" not in _model_cache or "diff_renderer" not in _model_cache:
        flame = FLAME(tracking_conf).to(DEVICE)  # CUDA init happens in .to()
        renderer = NVDRenderer(
            image_size=tracking_conf.size,
            obj_filename=env_paths.head_template,
            no_sh=False,
            white_bg=True,
        ).to(DEVICE)
        _model_cache["flame_model"] = flame
        _model_cache["diff_renderer"] = renderer

    tracking_conf.video_name = f"{session_id}"
    tracker = Tracker(tracking_conf, _model_cache["flame_model"], _model_cache["diff_renderer"])
    tracker.run()

    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)
    return "✅ Pipeline complete!", image, gr.update(interactive=True)
# New: run all steps sequentially
@spaces.GPU(duration=120)
def generate_results_and_mesh(image, session_id=None):
    """
    Process an input image through a 3D reconstruction pipeline and return the intermediate outputs and mesh file.

    This function runs a multi‐step workflow to go from a raw input image to a reconstructed 3D mesh:
    1. **Preprocessing**: saves the image, runs the preprocessing script and face segmentation.
    2. **Normals Estimation**: computes surface normal maps.
    3. **UV Mapping**: generates UV coordinate maps for texturing.
    4. **Tracking**: runs FLAME tracking and renders the result.
    5. **Mesh Discovery**: locates the first `.glb` file in
       `$PIXEL3DMM_TRACKING_OUTPUT/<session_id>/mesh`.

    Args:
        image (ndarray or PIL.Image.Image): Input image to reconstruct.
        session_id (str or None): Unique identifier for this session's output
            directories. When missing or empty (e.g. when called from
            ``gr.Examples``), a random hex id is generated.

    Returns:
        tuple:
            - final_status (str): Newline‐separated status messages from each pipeline step,
              or only the step-1 "❌" message if preprocessing failed.
            - crop_img (str or None): Path of the cropped/preprocessed image.
            - normals_img (str or None): Path of the surface-normals visualization.
            - uv_img (str or None): Path of the UV‐map visualization.
            - track_img (str or None): Path of the tracking/registration render.
            - mesh_file (str or None): Path to the generated `.glb` mesh, if found.
    """
    if not session_id:
        session_id = uuid.uuid4().hex

    # Step 1. Index the result instead of tuple-unpacking it, so an early
    # error return is handled whatever its length.
    step1 = preprocess_image(image, session_id)
    status1 = step1[0]
    if "❌" in status1:
        return status1, None, None, None, None, None
    crop_img = step1[1]

    # Step 2
    status2, normals_img, _, _ = step2_normals(session_id)
    # Step 3
    status3, uv_img, _, _ = step3_uv_map(session_id)
    # Step 4
    status4, track_img, _ = step4_track(session_id)

    # Locate mesh (.glb)
    mesh_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "mesh")
    mesh_file = first_file_from_dir(mesh_dir, "glb")

    final_status = "\n".join([status1, status2, status3, status4])
    return final_status, crop_img, normals_img, uv_img, track_img, mesh_file
# Cleanup on unload
def cleanup(request: gr.Request):
    """Delete this client's preprocessing and tracking folders.

    The folder name is Gradio's ``request.session_hash``; nothing happens when
    it is empty. Missing folders are ignored.
    """
    session_hash = request.session_hash
    if not session_hash:
        return
    session_dirs = [
        os.path.join(os.environ[env_var], session_hash)
        for env_var in ("PIXEL3DMM_PREPROCESSED_DATA", "PIXEL3DMM_TRACKING_OUTPUT")
    ]
    for session_dir in session_dirs:
        shutil.rmtree(session_dir, ignore_errors=True)
def start_session(request: gr.Request):
    """Return Gradio's per-client session hash, used as this client's session id."""
    session_hash = request.session_hash
    return session_hash
# Custom CSS for the Blocks app: centres the element with
# elem_id="col-container" and caps its width at 1024px.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
# Build Gradio UI. Components appear in the order they are created inside
# each layout context, so statement order here defines the page layout.
with gr.Blocks(css=css) as demo:
    # Per-client session id; filled by start_session when the page loads.
    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])
    gr.HTML(
        """
<div style="text-align: center;">
<h1>Pixel3dmm [Image Mode]</h1>
<p style="font-size:16px;">Versatile Screen-Space Priors for Single-Image 3D Face Reconstruction.</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/SimonGiebenhain/pixel3dmm">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
</div>
"""
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            # Left column: input image, run button and status text.
            with gr.Column():
                image_in = gr.Image(label="Upload Image", type="numpy", height=512)
                run_btn = gr.Button("Reconstruct Face", variant="primary")
                # NOTE(review): status is only ever written by the app;
                # interactive=True lets users type into it — confirm intent.
                status = gr.Textbox(label="Status", lines=6, interactive=True, value="Upload an image to start.")
            # Right column: intermediate images and the 3D preview.
            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("Results"):
                        with gr.Row():
                            crop_img = gr.Image(label="Preprocessed", height=256)
                            normals_img = gr.Image(label="Normals", height=256)
                        with gr.Row():
                            uv_img = gr.Image(label="UV Map", height=256)
                            track_img = gr.Image(label="Tracking", height=256)
                    with gr.Tab("3D Model"):
                        with gr.Column():
                            mesh_file = gr.Model3D(label="3D Model Preview", height=512)
        # Example inputs. Only image_in is passed, so generate_results_and_mesh
        # generates its own session id for them. cache_examples=True makes
        # Gradio run the pipeline on these examples and cache the outputs.
        examples = gr.Examples(
            examples=[
                ["example_images/jennifer_lawrence.png"],
                ["example_images/jim_carrey.png"],
                ["example_images/margaret_qualley.png"],
            ],
            inputs=[image_in],
            outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file],
            fn=generate_results_and_mesh,
            cache_examples=True
        )
    # Full pipeline for the uploaded image, keyed by this client's session id.
    run_btn.click(
        fn=generate_results_and_mesh,
        inputs=[image_in, session_state],
        outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file]
    )
    # Clear previous results whenever a new image is uploaded.
    # NOTE(review): reset_all must return exactly these seven values in this
    # order — confirm the two stay in sync.
    image_in.upload(fn=reset_all, inputs=None, outputs=[crop_img, normals_img, uv_img, track_img, mesh_file, status, run_btn])
    # Remove this client's output folders when the page is closed.
    demo.unload(cleanup)
demo.queue(default_concurrency_limit=1, # ≤ 1 worker per event
           max_size=20) # optional: allow 20 waiting jobs
# NOTE(review): share links are presumably not used when hosted on HF Spaces;
# confirm share=True is only meant for local runs.
demo.launch(share=True)