File size: 6,235 Bytes
db6a3b7
3057b36
7d475c1
db6a3b7
 
cd41f5f
690b53e
db6a3b7
9880f3d
7d475c1
db6a3b7
 
9880f3d
db6a3b7
 
9880f3d
db6a3b7
 
258ea5a
bd46f72
cd41f5f
d7b1815
 
258ea5a
cd41f5f
 
 
 
258ea5a
cd41f5f
 
 
 
 
258ea5a
db894f7
a481d7a
342cabd
16dfcc8
 
 
 
3057b36
cd41f5f
 
 
 
 
 
 
 
258ea5a
db6a3b7
258ea5a
 
db6a3b7
258ea5a
db6a3b7
cd41f5f
db894f7
cd41f5f
db894f7
bd46f72
 
 
 
 
 
 
 
 
 
 
258ea5a
 
7d475c1
15fe7bc
 
258ea5a
cd41f5f
7d475c1
258ea5a
 
 
 
 
 
 
 
 
 
c260ece
258ea5a
c260ece
258ea5a
c260ece
cd41f5f
7d475c1
 
258ea5a
 
 
7d475c1
a481d7a
db6a3b7
 
342cabd
a481d7a
bd46f72
342cabd
 
258ea5a
bd46f72
342cabd
 
258ea5a
bd46f72
342cabd
 
 
bd46f72
9173005
db6a3b7
258ea5a
 
 
a481d7a
258ea5a
db6a3b7
 
 
 
 
 
 
2e7f188
cd41f5f
db6a3b7
 
 
 
258ea5a
cd41f5f
 
a481d7a
db6a3b7
 
 
cd41f5f
db6a3b7
 
 
cd41f5f
 
 
 
db6a3b7
258ea5a
 
4241cf4
258ea5a
4241cf4
db6a3b7
 
a481d7a
db6a3b7
 
 
c666caf
 
258ea5a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D

import os
import shutil
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
import uuid
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils


# Largest value of a signed 32-bit integer; used as the upper bound of the
# seed slider and of randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Scratch directory located next to this file; per-session subdirectories are
# created underneath it. It is created at import time if it does not exist.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)


def start_session(req: gr.Request):
    """Create the scratch directory for the session that issued `req`.

    The directory lives under TMP_DIR and is named after the request's
    session hash. Creating it is a no-op if it already exists.
    """
    session_name = str(req.session_hash)
    user_dir = os.path.join(TMP_DIR, session_name)
    print(f'Creating user directory: {user_dir}')
    os.makedirs(user_dir, exist_ok=True)
    
def end_session(req: gr.Request):
    """Remove the scratch directory of the session that issued `req`.

    The directory is the one under TMP_DIR named after the request's session
    hash. A directory that does not exist (never created, or already removed)
    is not an error: cleanup is simply skipped.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    print(f'Removing user directory: {user_dir}')
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing to clean up; do not let session teardown fail because of it.
        pass

def preprocess_image(image: Image.Image) -> Image.Image:
    """Run the pipeline's image preprocessing on `image` and return the result.

    Args:
        image: The image supplied by the user (or by an example).

    Returns:
        The single image produced by `pipeline.preprocess_image`.
        NOTE(review): presumably the background-removed / cropped image —
        confirm against the pipeline implementation.
    """
    return pipeline.preprocess_image(image)

def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use for generation.

    When `randomize_seed` is set, a fresh value is drawn from
    ``[0, MAX_SEED)``; otherwise the caller-supplied `seed` is returned
    unchanged.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)

@spaces.GPU
def image_to_3d(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    req: gr.Request,
) -> Tuple[str, str, str]:
    """
    Convert an image to a 3D model and save both video preview and full-quality GLB.

    Args:
        image: Input image. It is passed to the pipeline with
            ``preprocess_image=False``, so it is expected to be preprocessed already.
        seed: Random seed forwarded to the pipeline.
        ss_guidance_strength: ``cfg_strength`` of the sparse structure sampler.
        ss_sampling_steps: ``steps`` of the sparse structure sampler.
        slat_guidance_strength: ``cfg_strength`` of the structured latent sampler.
        slat_sampling_steps: ``steps`` of the structured latent sampler.
        req: Gradio request; its session hash selects the output directory under TMP_DIR.

    Returns:
        Tuple[str, str, str]: (video_path, glb_path, download_path). The last two
        entries are the same GLB file path (one for the viewer, one for the
        download button).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # The session directory is normally created by the page-load handler, but
    # make sure it exists so writing the outputs cannot fail if that handler
    # did not run (or the directory was already cleaned up).
    os.makedirs(user_dir, exist_ok=True)

    outputs = pipeline.run(
        image,
        seed=seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "cfg_strength": ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": slat_sampling_steps,
            "cfg_strength": slat_guidance_strength,
        },
    )

    # Generate and save video preview: color render of the gaussian next to
    # the normal render of the mesh, joined frame by frame along axis 1.
    video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
    video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
    video = [
        np.concatenate([color_frame, normal_frame], axis=1)
        for color_frame, normal_frame in zip(video, video_geo)
    ]
    # A fresh id per run keeps repeated generations in one session from
    # overwriting each other's files.
    trial_id = str(uuid.uuid4())
    video_path = os.path.join(user_dir, f"{trial_id}.mp4")
    imageio.mimsave(video_path, video, fps=15)

    # Save full-quality GLB directly from the generated mesh
    glb = postprocessing_utils.to_glb(
        outputs['gaussian'][0],
        outputs['mesh'][0],
        simplify=0.0,  # No simplification
        texture_size=2048,  # Maximum texture resolution
        verbose=False
    )
    glb_path = os.path.join(user_dir, f"{trial_id}_full.glb")
    glb.export(glb_path)

    torch.cuda.empty_cache()
    return video_path, glb_path, glb_path

# Build the Gradio UI: one column with the input image, generation settings
# and the "Generate" button, a second column with the video preview, the 3D
# viewer and the download button, followed by the example gallery and the
# event wiring.
# NOTE(review): delete_cache=(600, 600) is presumably (frequency, age) in
# seconds per the Gradio Blocks documentation — confirm.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS](https://trellis3d.github.io/)
    * Upload an image and click "Generate" to create a high-quality 3D model
    * Once generation is complete, you can download the full-quality GLB file
    * The model will be in maximum quality with no reduction applied
    """)
    
    with gr.Row():
        with gr.Column():
            # Input image is handed to the callbacks as an RGBA PIL image.
            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
            
            # Sampler settings, collapsed by default.
            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 500, label="Sampling Steps", value=12, step=1)

            generate_btn = gr.Button("Generate")

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset Preview", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="3D Model Preview", exposure=20.0, height=300)
            # Starts disabled; it is enabled at the end of the generate chain below.
            download_glb = gr.DownloadButton(label="Download Full-Quality GLB", interactive=False)
            
    # Example images
    # One example per file found in assets/example_image (path relative to the
    # current working directory); clicking an example runs preprocess_image on
    # it and writes the result back into the image input.
    with gr.Row():
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                for image in os.listdir("assets/example_image")
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )

    # Event handlers
    # Create the per-session directory on page load and remove it on unload.
    demo.load(start_session)
    demo.unload(end_session)
    
    # Replace an uploaded image with its preprocessed version in place.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Generate chain: resolve the seed (written back to the seed slider), run
    # the image-to-3D generation, then make the download button interactive.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
        outputs=[video_output, model_output, download_glb],
    ).then(
        lambda: gr.Button(interactive=True),
        outputs=[download_glb],
    )

# Launch the Gradio app
if __name__ == "__main__":
    # `pipeline` is a module-level global read by the callbacks defined above.
    pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    pipeline.cuda()
    # Warm-up call on a blank image so the first real request does not pay the
    # loading cost. This is best-effort: a failure is reported but must not
    # prevent the app from starting. Only Exception is caught so that
    # KeyboardInterrupt / SystemExit still stop the process.
    try:
        pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))    # Preload rembg
    except Exception as exc:
        print(f'Preloading rembg failed, continuing without warm-up: {exc!r}')
    demo.launch()