File size: 7,163 Bytes
daa6779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba0984
daa6779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba0984
daa6779
 
 
 
 
 
 
 
 
 
 
 
8ba0984
 
 
daa6779
 
 
8ba0984
daa6779
 
 
 
 
8ba0984
daa6779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba0984
daa6779
 
 
8ba0984
daa6779
8ba0984
daa6779
 
 
 
 
8ba0984
 
 
daa6779
 
 
 
8ba0984
daa6779
 
6def56d
daa6779
 
 
 
 
 
 
 
 
e128474
 
 
daa6779
 
 
 
 
 
 
 
 
 
 
8ba0984
daa6779
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
import numpy as np
import cv2
import kiui
import trimesh
import torch
import rembg
from datetime import datetime
import subprocess
import gradio as gr

try:
    # The real package is available when the app is hosted on Hugging Face Spaces.
    import spaces

except ImportError:
    # Local run: provide a stand-in so that `@spaces.GPU(duration=...)`
    # still works as a decorator and simply leaves the function unchanged.
    class spaces:
        """Minimal local substitute for the ``spaces`` package."""

        class GPU:
            """No-op decorator factory used as ``@spaces.GPU(duration=...)``."""

            def __init__(self, duration=60):
                # Kept for signature parity only; nothing in this stand-in reads it.
                self.duration = duration

            def __call__(self, func):
                # Hand the decorated function back untouched.
                return func


from flow.model import Model
from flow.configs.schema import ModelConfig
from flow.utils import get_random_color, recenter_foreground
from vae.utils import postprocess_mesh

# Fetch the pretrained weights from the Hugging Face Hub (flow model first,
# then the VAE); each name is bound to what hf_hub_download returns for it.
from huggingface_hub import hf_hub_download

flow_ckpt_path, vae_ckpt_path = (
    hf_hub_download(repo_id="nvidia/PartPacker", filename=ckpt_name)
    for ckpt_name in ("flow.pt", "vae.pt")
)

# Axis permutation applied to mesh vertices before export: `v @ M.T` maps
# (x, y, z) -> (y, z, x). Presumably aligns the model's axes with the GLB
# viewer's convention -- confirm against the VAE's output frame.
TRIMESH_GLB_EXPORT = np.array(
    [
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
    ],
    dtype=np.float32,
)

# Largest int32 value; upper bound for the seed slider and for random seeds.
MAX_SEED = np.iinfo(np.int32).max

# One shared rembg session, created once at import time and reused by every request.
bg_remover = rembg.new_session()

# model config
# Settings for the flow model and the VAE it wraps. The state dict below is
# loaded with strict=True, so these must describe the same architecture as
# the downloaded checkpoint.
model_config = ModelConfig(
    vae_conf="vae.configs.part_woenc",
    vae_ckpt_path=vae_ckpt_path,
    qknorm=True,
    qknorm_type="RMSNorm",
    use_pos_embed=False,
    dino_model="dinov2_vitg14",
    hidden_dim=1536,
    flow_shift=3.0,
    logitnorm_mean=1.0,
    logitnorm_std=1.0,
    latent_size=4096,
    use_parts=True,
)

# instantiate model (inference mode, on the GPU, bfloat16 weights)
model = Model(model_config).eval().cuda().bfloat16()

# load weight
# map_location="cpu": without it torch.load restores every tensor onto the
# device it was saved from, which fails when that device does not exist here
# and otherwise keeps a second copy of the weights on the GPU for as long as
# `ckpt_dict` is alive. load_state_dict copies the values into the model's
# own (CUDA, bfloat16) parameters either way.
ckpt_dict = torch.load(flow_ckpt_path, map_location="cpu", weights_only=True)
# strict=True: any missing or unexpected key raises instead of being ignored.
model.load_state_dict(ckpt_dict, strict=True)

def _decode_parts(vae_results, target_num_faces):
    """Turn one VAE decoding result into a list of per-component meshes.

    Args:
        vae_results: Mapping returned by ``model.vae``; only
            ``vae_results["meshes"][0]`` is read, as a ``(vertices, faces)`` pair.
        target_num_faces: Forwarded unchanged to ``postprocess_mesh``
            (``process`` passes -1 when simplification is disabled).

    Returns:
        A list with the meshes produced by splitting the post-processed mesh
        (non-watertight components included).
    """
    vertices, faces = vae_results["meshes"][0]
    mesh = trimesh.Trimesh(vertices, faces)
    # Permute the axes into the orientation used for the GLB export.
    mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
    mesh = postprocess_mesh(mesh, target_num_faces)
    return list(mesh.split(only_watertight=False))


# process function
@spaces.GPU(duration=120)
def process(input_image, num_steps=30, cfg_scale=7.5, grid_res=384, seed=42, randomize_seed=True, simplify_mesh=False, target_num_faces=100000):
    """Generate a part-level 3D model from a single image and export it as GLB.

    Args:
        input_image: Image to condition on (the UI passes a PIL image); it is
            converted with ``np.array``. A 3-channel image gets its background
            removed with rembg; otherwise the last channel is used as alpha.
        num_steps: Passed to the flow model as ``num_steps``.
        cfg_scale: Passed to the flow model as ``cfg_scale``.
        grid_res: Passed to the VAE as ``resolution`` when extracting meshes.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: If true, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED)``.
        simplify_mesh: If false, ``target_num_faces`` is overridden with -1
            before being handed to ``postprocess_mesh``.
        target_num_faces: Face budget forwarded to ``postprocess_mesh`` when
            ``simplify_mesh`` is true.

    Returns:
        Tuple ``(seed, image, output_glb_path)``: the seed actually used, the
        preprocessed 518x518 RGB float image (values in [0, 1], composited on
        white), and the path of the written ``.glb`` file.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Clicking "Generate" without an image hands us None; fail with a readable
    # message instead of an IndexError further down in the preprocessing.
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # seed
    if randomize_seed:
        seed = np.random.randint(0, MAX_SEED)
    kiui.seed_everything(seed)

    # output path
    os.makedirs("output", exist_ok=True)
    # Microsecond resolution so two requests handled within the same second
    # do not overwrite each other's file.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    output_glb_path = f"output/partpacker_{timestamp}.glb"

    # input image
    input_image = np.array(input_image)  # expected to be uint8
    # A 2-D (grayscale) array has no channel axis; expand it to 3 channels so
    # the channel checks below see an H x W x 3 image.
    if input_image.ndim == 2:
        input_image = np.stack([input_image] * 3, axis=-1)

    # bg removal if there is no alpha channel
    if input_image.shape[-1] == 3:
        input_image = rembg.remove(input_image, session=bg_remover)  # [H, W, 4]
    mask = input_image[..., -1] > 0
    image = recenter_foreground(input_image, mask, border_ratio=0.1)
    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
    image = image.astype(np.float32) / 255.0
    alpha = image[..., 3:4]
    image = image[..., :3] * alpha + (1 - alpha)  # white background

    # HWC -> 1 x C x H x W float tensor on the GPU
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()
    data = {"cond_images": image_tensor}

    with torch.inference_mode():
        results = model(data, num_steps=num_steps, cfg_scale=cfg_scale)

    latent = results["latent"]

    # query mesh: the latent holds two volumes back to back along dim 1, the
    # first `latent_size` tokens and the remainder; each is decoded separately.
    latent_size = model.config.latent_size
    data_part0 = {"latent": latent[:, :latent_size, :]}
    data_part1 = {"latent": latent[:, latent_size:, :]}

    with torch.inference_mode():
        results_part0 = model.vae(data_part0, resolution=grid_res)
        results_part1 = model.vae(data_part1, resolution=grid_res)

    if not simplify_mesh:
        target_num_faces = -1

    # split connected components of both volumes, first volume first
    parts = _decode_parts(results_part0, target_num_faces)
    parts.extend(_decode_parts(results_part1, target_num_faces))

    # assign a different color to each component
    for j, part in enumerate(parts):
        part.visual.vertex_colors = get_random_color(j, use_float=True)

    # export the whole mesh as one scene so parts stay separable in the GLB
    mesh = trimesh.Scene(parts)
    mesh.export(output_glb_path)

    return seed, image, output_glb_path

# gradio UI
# Builds the page (title, description, input controls, output viewers and
# example images), wires the "Generate" button to `process`, and launches
# the app at module level.

# Used both as the browser page title and as the top-level heading.
_TITLE = '''PartPacker: Efficient Part-level 3D Object Generation via Dual Volume Packing'''

# Markdown/HTML shown under the heading: project links plus usage hints.
_DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://research.nvidia.com/labs/dir/partpacker/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/NVlabs/PartPacker"><img src='https://img.shields.io/github/stars/NVlabs/PartPacker?style=social'/></a>
</div>

* Each part is visualized with a random color, and can be separated in the GLB file.
* If the output is not satisfactory, please try different random seeds!
'''

# Top-level container with Gradio's request queue enabled.
block = gr.Blocks(title=_TITLE).queue()
with block:
    # header: title row followed by the description
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
    gr.Markdown(_DESCRIPTION)

    # main row: controls | outputs | examples (relative widths 4 : 8 : 1)
    with gr.Row():
        with gr.Column(scale=4):
            # input image (handed to `process` as a PIL image)
            input_image = gr.Image(label="Image", type='pil')
            # inference steps
            num_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=30)
            # cfg scale
            # NOTE(review): the UI default is 7.0 while `process` declares 7.5;
            # the slider value is what is actually passed on click.
            cfg_scale = gr.Slider(label="CFG scale", minimum=2, maximum=10, step=0.1, value=7.0)
            # grid resolution
            input_grid_res = gr.Slider(label="Grid resolution", minimum=256, maximum=512, step=1, value=384)
            # random seed (also an output: updated with the seed `process` used)
            seed = gr.Slider(label="Random seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # simplify mesh (the face number only takes effect when this is checked)
            simplify_mesh = gr.Checkbox(label="Simplify mesh", value=False)
            target_num_faces = gr.Slider(label="Face number", minimum=10000, maximum=1000000, step=1000, value=100000)
            # gen button
            button_gen = gr.Button("Generate")


        with gr.Column(scale=8):
            with gr.Tab("3D Model"):
                # glb file
                output_model = gr.Model3D(label="Geometry", height=512)

            with gr.Tab("Input Image"):
                # preprocessed image returned by `process`
                output_image = gr.Image(interactive=False, show_label=False)


        with gr.Column(scale=1):
            # Example images only populate the image input; results are not
            # precomputed (cache_examples=False).
            gr.Examples(
                examples=[
                    ["examples/rabbit.png"],
                    ["examples/robot.png"],
                    ["examples/teapot.png"],
                    ["examples/barrel.png"],
                    ["examples/cactus.png"],
                    ["examples/cyan_car.png"],
                    ["examples/pickup.png"],
                    ["examples/swivelchair.png"],
                    ["examples/warhammer.png"],
                ],
                inputs=[input_image],
                cache_examples=False,
            )

        # Inputs are passed to `process` positionally in this order; outputs
        # receive its return tuple (seed, image, output_glb_path) in order.
        button_gen.click(process, inputs=[input_image, num_steps, cfg_scale, input_grid_res, seed, randomize_seed, simplify_mesh, target_num_faces], outputs=[seed, output_image, output_model])

# Starts the app as soon as this module is executed.
block.launch()