import spaces
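# `spaces` provides the @spaces.GPU decorator used below; on Hugging Face
# ZeroGPU Spaces it attaches a GPU to the decorated function for the duration
# of each call.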

import os
# Set this before torch is imported so Dynamo is disabled for torch,
# pytorch3d, and everything else loaded below.
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import torch._dynamo
torch._dynamo.disable()

from gradio_litmodel3d import LitModel3D
import subprocess
import tempfile
import uuid
import glob
import shutil
import time
import gradio as gr
import sys
from PIL import Image
import importlib, site

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()

# Set environment variables
os.environ["PIXEL3DMM_CODE_BASE"] = os.getcwd()
os.environ["PIXEL3DMM_PREPROCESSED_DATA"] = f"{os.getcwd()}/preprocess_results"
os.environ["PIXEL3DMM_TRACKING_OUTPUT"] = f"{os.getcwd()}/tracking_results"
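# Downstream pixel3dmm scripts and the env_paths module resolve their I/O
# locations from these variables, so they must be set before the package is
# imported.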

def sh(cmd):
    subprocess.check_call(cmd, shell=True)

sh("pip install -e .")

# tell Python to re-scan site-packages now that the egg-link exists
site.addsitedir(site.getsitepackages()[0])
importlib.invalidate_caches()

from pixel3dmm import env_paths

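# Build the remaining vendored dependencies: the `facer` face-analysis package
# and PIPNet's FaceBoxesV2 utilities (its make.sh compiles the NMS extension).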
sh("cd src/pixel3dmm/preprocessing/facer && pip install -e . && cd ../../../..")
sh("cd src/pixel3dmm/preprocessing/PIPNet/FaceBoxesV2/utils && sh make.sh && cd ../../../../../..")

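# Install a full CUDA toolkit at runtime so nvcc is available; likely needed
# for packages that JIT-compile CUDA extensions (e.g. nvdiffrast, used by the
# tracking renderer below).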
def install_cuda_toolkit():
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    if "LD_LIBRARY_PATH" in os.environ:
        os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
            os.environ["CUDA_HOME"],
            os.environ["LD_LIBRARY_PATH"],
        )
    else:
        # Avoid a trailing ":" (an empty path entry) when the variable is unset.
        os.environ["LD_LIBRARY_PATH"] = "%s/lib" % os.environ["CUDA_HOME"]
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
    print("==> finished installation")
    
install_cuda_toolkit()

from omegaconf import OmegaConf
from pixel3dmm.network_inference import normals_n_uvs

DEVICE = "cuda"

# Module-level cache for heavy objects (models, renderer) so repeated steps
# reuse already-loaded weights instead of reloading them on every call.
_model_cache = {}

def first_file_from_dir(directory, ext):
    files = glob.glob(os.path.join(directory, f"*.{ext}"))
    return sorted(files)[0] if files else None
    
# Utility to select first image from a folder
def first_image_from_dir(directory):
    patterns = ["*.jpg", "*.png", "*.jpeg"]
    files = []
    for p in patterns:
        files.extend(glob.glob(os.path.join(directory, p)))
    if not files:
        return None
    return sorted(files)[0]

# Function to reset the UI and state
# (return order matches the outputs wired to image_in.upload below)
def reset_all():
    return (
        None,  # crop_img
        None,  # normals_img
        None,  # uv_img
        None,  # track_img
        None,  # mesh_file
        "Time to Generate!",  # status
        gr.update(interactive=True),  # run_btn
    )

# Step 1: Preprocess the input image (Save and Crop)
@spaces.GPU()
def preprocess_image(image_array, session_id):
    if image_array is None:
        return "❌ Please upload an image first.", None, gr.update(interactive=True), gr.update(interactive=True)

    base_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id)
    os.makedirs(base_dir, exist_ok=True)

    img = Image.fromarray(image_array)
    saved_image_path = os.path.join(base_dir, f"{session_id}.png")
    img.save(saved_image_path)
    
    try:
        subprocess.run([
            "python", "scripts/run_preprocessing.py", "--video_or_images_path", saved_image_path
        ], check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        err = f"❌ Preprocess failed (exit {e.returncode}).\n\n{e.stdout}\n{e.stderr}"
        shutil.rmtree(base_dir, ignore_errors=True)
        return err, None, gr.update(interactive=True), gr.update(interactive=True)

    crop_dir = os.path.join(base_dir, "cropped")
    image = first_image_from_dir(crop_dir)
    return "✅ Step 1 complete. Ready for Normals.", image, gr.update(interactive=True), gr.update(interactive=True)

# Step 2: Normals inference → normals image
@spaces.GPU()
def step2_normals(session_id):
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
    
    base_conf = OmegaConf.load("configs/base.yaml")

    if "normals_model" not in _model_cache:
        model = p3dmm_system.load_from_checkpoint(env_paths.CKPT_N_PRED, strict=False)
        model = model.eval().to(DEVICE)
        _model_cache["normals_model"] = model

    base_conf.video_name = session_id
    normals_n_uvs(base_conf, _model_cache["normals_model"])

    normals_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "normals")
    image = first_image_from_dir(normals_dir)

    return "✅ Step 2 complete. Ready for UV Map.", image, gr.update(interactive=True), gr.update(interactive=True)

# Step 3: UV map inference → uv map image
@spaces.GPU()
def step3_uv_map(session_id):
    from pixel3dmm.lightning.p3dmm_system import system as p3dmm_system
    
    base_conf = OmegaConf.load("configs/base.yaml")

    if "uv_model" not in _model_cache:
        model = p3dmm_system.load_from_checkpoint(env_paths.CKPT_UV_PRED, strict=False)
        model = model.eval().to(DEVICE)
        _model_cache["uv_model"] = model

    base_conf.video_name = session_id
    base_conf.model.prediction_type = "uv_map"
    normals_n_uvs(base_conf, _model_cache["uv_model"])

    uv_dir = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], session_id, "p3dmm", "uv_map")
    image = first_image_from_dir(uv_dir)

    return "✅ Step 3 complete. Ready for Tracking.", image, gr.update(interactive=True), gr.update(interactive=True)

# Step 4: Tracking → final tracking image
@spaces.GPU()
def step4_track(session_id):
    from pixel3dmm.tracking.flame.FLAME import FLAME
    from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
    from pixel3dmm.tracking.tracker import Tracker
    
    tracking_conf = OmegaConf.load("configs/tracking.yaml")

    # Lazy init + caching of the FLAME model and differentiable renderer on GPU
    if "flame_model" not in _model_cache:
        flame = FLAME(tracking_conf)  # CPU instantiation
        flame = flame.to(DEVICE)      # CUDA init happens here
        _model_cache["flame_model"] = flame

        _mesh_file = env_paths.head_template

        _model_cache["diff_renderer"] = NVDRenderer(
            image_size=tracking_conf.size,
            obj_filename=_mesh_file,
            no_sh=False,
            white_bg=True
        ).to(DEVICE)
   
    flame_model = _model_cache["flame_model"]
    diff_renderer = _model_cache["diff_renderer"]
    tracking_conf.video_name = session_id
    tracker = Tracker(tracking_conf, flame_model, diff_renderer)
    tracker.run()

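    # The tracker writes per-frame visualizations and the exported mesh under
    # PIXEL3DMM_TRACKING_OUTPUT/<session_id>/ ("frames" and "mesh" below).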
    tracking_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "frames")
    image = first_image_from_dir(tracking_dir)

    return "✅ Pipeline complete!", image, gr.update(interactive=True)

# New: run all steps sequentially
@spaces.GPU(duration=120)
def generate_results_and_mesh(image, session_id=None):
    """
    Process an input image through the 3D face reconstruction pipeline and
    return the intermediate outputs and mesh file.

    This function runs a multi-step workflow to go from a raw input image to a
    reconstructed 3D mesh:
      1. **Preprocessing**: crops the face region and computes masks.
      2. **Normals Estimation**: computes surface normal maps.
      3. **UV Mapping**: generates UV coordinate maps for texturing.
      4. **Tracking**: performs final FLAME fitting/tracking to produce the mesh.
      5. **Mesh Discovery**: locates the resulting `.glb` file in the tracking
         output directory.

    Args:
        image (PIL.Image.Image or ndarray): Input image to reconstruct.
        session_id (str, optional): Unique identifier for this session's output
            directories. A random one is generated if not provided.

    Returns:
        tuple:
            - final_status (str): Newline-separated status messages from each pipeline step.
            - crop_img (Image or None): Cropped and preprocessed image.
            - normals_img (Image or None): Estimated surface normals visualization.
            - uv_img (Image or None): UV-map visualization.
            - track_img (Image or None): Tracking/registration result.
            - mesh_file (str or None): Path to the generated 3D mesh (`.glb`), if found.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    # Step 1
    status1, crop_img, _, _ = preprocess_image(image, session_id)
    if "❌" in status1:
        return status1, None, None, None, None, None
    # Step 2
    status2, normals_img, _, _ = step2_normals(session_id)
    # Step 3
    status3, uv_img, _, _ = step3_uv_map(session_id)
    # Step 4
    status4, track_img, _ = step4_track(session_id)
    # Locate mesh (.glb)
    mesh_dir = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], session_id, "mesh")
    mesh_file = first_file_from_dir(mesh_dir, "glb")

    final_status = "\n".join([status1, status2, status3, status4])
    return final_status, crop_img, normals_img, uv_img, track_img, mesh_file

# Cleanup on unload
def cleanup(request: gr.Request):
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["PIXEL3DMM_PREPROCESSED_DATA"], sid)
        d2 = os.path.join(os.environ["PIXEL3DMM_TRACKING_OUTPUT"], sid)
        shutil.rmtree(d1, ignore_errors=True)
        shutil.rmtree(d2, ignore_errors=True)

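# Gradio's per-tab session hash keys each user's working directories, so
# cleanup() can remove them when the tab is closed.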
def start_session(request: gr.Request):
    return request.session_hash

    
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

# Build Gradio UI
with gr.Blocks(css=css) as demo:
    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Pixel3dmm [Image Mode]</h1>
            <p style="font-size:16px;">Versatile Screen-Space Priors for Single-Image 3D Face Reconstruction.</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://github.com/SimonGiebenhain/pixel3dmm">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
        </div>
        """
    )

    with gr.Column(elem_id="col-container"):
        
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(label="Upload Image", type="numpy", height=512)
                run_btn = gr.Button("Reconstruct Face", variant="primary")
    
                status = gr.Textbox(label="Status", lines=6, interactive=False, value="Upload an image to start.")

            with gr.Column():
                with gr.Tabs():
                    with gr.Tab("Results"):
                        with gr.Row():
                            crop_img = gr.Image(label="Preprocessed", height=256)
                            normals_img = gr.Image(label="Normals", height=256)
                        with gr.Row():
                            uv_img = gr.Image(label="UV Map", height=256)
                            track_img = gr.Image(label="Tracking", height=256)
                    with gr.Tab("3D Model"):
                        with gr.Column():
                            mesh_file = gr.Model3D(label="3D Model Preview", height=512)
        
                examples = gr.Examples(
                    examples=[
                        ["example_images/jennifer_lawrence.png"],
                        ["example_images/jim_carrey.png"],
                        ["example_images/margaret_qualley.png"],
                    ],
                    inputs=[image_in],
                    outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file],
                    fn=generate_results_and_mesh,
                    cache_examples=True
                )
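                # cache_examples=True runs the full pipeline once per example
                # at startup and serves the cached results afterwards.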


        run_btn.click(
            fn=generate_results_and_mesh,
            inputs=[image_in, session_state],
            outputs=[status, crop_img, normals_img, uv_img, track_img, mesh_file]
        )
        image_in.upload(fn=reset_all, inputs=None, outputs=[crop_img, normals_img, uv_img, track_img, mesh_file, status, run_btn])

    demo.unload(cleanup)

demo.queue(default_concurrency_limit=1,        # ≤ 1 worker per event
           max_size=20)                        # optional: allow 20 waiting jobs
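# A single worker serializes GPU jobs: every step shares one GPU and one
# global model cache, so concurrent runs could interfere with each other.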

demo.launch()