Spaces: Running on Zero
init
- README.md +1 -1
- app.py +192 -0
- examples/barrel.png +3 -0
- examples/cactus.png +3 -0
- examples/cyan_car.png +3 -0
- examples/pickup.png +3 -0
- examples/rabbit.png +3 -0
- examples/robot.png +3 -0
- examples/swivelchair.png +3 -0
- examples/teapot.png +3 -0
- examples/warhammer.png +3 -0
- flow/__init__.py +0 -0
- flow/configs/__init__.py +0 -0
- flow/configs/big_parts_strict_pvae.py +33 -0
- flow/configs/schema.py +57 -0
- flow/flow_matching.py +58 -0
- flow/model.py +336 -0
- flow/modules/__init__.py +0 -0
- flow/modules/dit.py +235 -0
- flow/scripts/infer.py +180 -0
- flow/utils.py +119 -0
- requirements.txt +16 -0
- vae/__init__.py +0 -0
- vae/configs/__init__.py +0 -0
- vae/configs/part_woenc.py +30 -0
- vae/configs/schema.py +55 -0
- vae/model.py +451 -0
- vae/modules/__init__.py +0 -0
- vae/modules/attention.py +261 -0
- vae/modules/transformer.py +117 -0
- vae/scripts/infer.py +142 -0
- vae/utils.py +315 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: PartPacker
-emoji:
+emoji: 🪴
 colorFrom: blue
 colorTo: gray
 sdk: gradio
app.py
ADDED
@@ -0,0 +1,192 @@
import os
import numpy as np
import cv2
import kiui
import trimesh
import torch
import rembg
from datetime import datetime
import subprocess
import gradio as gr

try:
    # running on Hugging Face Spaces
    import spaces

except ImportError:
    # running locally, use a dummy space
    class spaces:
        class GPU:
            def __init__(self, duration=60):
                self.duration = duration
            def __call__(self, func):
                return func


from flow.model import Model
from flow.configs.schema import ModelConfig
from flow.utils import get_random_color, recenter_foreground
from vae.utils import postprocess_mesh

# download checkpoints
from huggingface_hub import hf_hub_download
flow_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="flow.pt")
vae_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="vae.pt")

TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
MAX_SEED = np.iinfo(np.int32).max
bg_remover = rembg.new_session()

# model config
model_config = ModelConfig(
    vae_conf="vae.configs.part_woenc",
    vae_ckpt_path=vae_ckpt_path,
    qknorm=True,
    qknorm_type="RMSNorm",
    use_pos_embed=False,
    dino_model="dinov2_vitg14",
    hidden_dim=1536,
    flow_shift=3.0,
    logitnorm_mean=1.0,
    logitnorm_std=1.0,
    latent_size=4096,
    use_parts=True,
)

# instantiate model
model = Model(model_config).eval().cuda().bfloat16()

# load weight
ckpt_dict = torch.load(flow_ckpt_path, weights_only=True)
model.load_state_dict(ckpt_dict, strict=True)

# process function
@spaces.GPU(duration=120)
def process(input_image, input_num_steps=30, input_cfg_scale=7.5, grid_res=384, seed=42, randomize_seed=True):

    # seed
    if randomize_seed:
        seed = np.random.randint(0, MAX_SEED)
    kiui.seed_everything(seed)

    # output path
    os.makedirs("output", exist_ok=True)
    output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"

    # input image
    input_image = np.array(input_image)  # uint8

    # bg removal if there is no alpha channel
    if input_image.shape[-1] == 3:
        input_image = rembg.remove(input_image, session=bg_remover)  # [H, W, 4]
    mask = input_image[..., -1] > 0
    image = recenter_foreground(input_image, mask, border_ratio=0.1)
    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
    image = image.astype(np.float32) / 255.0
    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])  # white background

    image_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()
    data = {"cond_images": image_tensor}

    with torch.inference_mode():
        results = model(data, num_steps=input_num_steps, cfg_scale=input_cfg_scale)

    latent = results["latent"]

    # query mesh

    data_part0 = {"latent": latent[:, : model.config.latent_size, :]}
    data_part1 = {"latent": latent[:, model.config.latent_size :, :]}

    with torch.inference_mode():
        results_part0 = model.vae(data_part0, resolution=grid_res)
        results_part1 = model.vae(data_part1, resolution=grid_res)

    vertices, faces = results_part0["meshes"][0]
    mesh_part0 = trimesh.Trimesh(vertices, faces)
    mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
    mesh_part0 = postprocess_mesh(mesh_part0, 5e4)
    parts = mesh_part0.split(only_watertight=False)

    vertices, faces = results_part1["meshes"][0]
    mesh_part1 = trimesh.Trimesh(vertices, faces)
    mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
    mesh_part1 = postprocess_mesh(mesh_part1, 5e4)
    parts.extend(mesh_part1.split(only_watertight=False))

    # split connected components and assign different colors
    for j, part in enumerate(parts):
        # each component uses a random color
        part.visual.vertex_colors = get_random_color(j, use_float=True)

    mesh = trimesh.Scene(parts)
    # export the whole mesh
    mesh.export(output_glb_path)

    return seed, image, output_glb_path

# gradio UI

_TITLE = '''PartPacker: Efficient Part-level 3D Object Generation via Dual Volume Packing'''

_DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://research.nvidia.com/labs/dir/partpacker/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://github.com/NVlabs/PartPacker"><img src='https://img.shields.io/github/stars/NVlabs/PartPacker?style=social'/></a>
</div>

* Each part is visualized with a random color, and can be separated in the GLB file.
* If the output is not satisfactory, please try different random seeds!
'''

block = gr.Blocks(title=_TITLE).queue()
with block:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=2):
            # input image
            input_image = gr.Image(label="Image", type='pil')
            # inference steps
            input_num_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=30)
            # cfg scale
            input_cfg_scale = gr.Slider(label="CFG scale", minimum=2, maximum=10, step=0.1, value=7.5)
            # grid resolution
            input_grid_res = gr.Slider(label="Grid resolution", minimum=256, maximum=512, step=1, value=384)
            # random seed
            seed = gr.Slider(label="Random seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # gen button
            button_gen = gr.Button("Generate")

        with gr.Column(scale=4):
            with gr.Tab("3D Model"):
                # glb file
                output_model = gr.Model3D(label="Geometry", height=380)

            with gr.Tab("Input Image"):
                # background removed image
                output_image = gr.Image(interactive=False, show_label=False)

        with gr.Column(scale=1):
            gr.Examples(
                examples=[
                    ["examples/barrel.png"],
                    ["examples/cactus.png"],
                    ["examples/cyan_car.png"],
                    ["examples/pickup.png"],
                    ["examples/swivelchair.png"],
                    ["examples/warhammer.png"],
                ],
                inputs=[input_image],
                cache_examples=False,
            )

    button_gen.click(process, inputs=[input_image, input_num_steps, input_cfg_scale, input_grid_res, seed, randomize_seed], outputs=[seed, output_image, output_model])

block.launch()
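
The exported GLB keeps each colored part as a separate geometry in the scene, so it can be inspected programmatically after generation. A minimal sketch using trimesh; the filename is a placeholder for whatever `process()` wrote to `output/`:

```python
# sketch: reopen an exported GLB and list its parts (filename is a placeholder)
import trimesh

scene = trimesh.load("output/partpacker_20250101_000000.glb")
for name, geom in scene.geometry.items():
    # each part is stored as its own geometry in the scene
    print(name, geom.vertices.shape, geom.faces.shape)
```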
examples/barrel.png
ADDED (binary image, Git LFS)
examples/cactus.png
ADDED (binary image, Git LFS)
examples/cyan_car.png
ADDED (binary image, Git LFS)
examples/pickup.png
ADDED (binary image, Git LFS)
examples/rabbit.png
ADDED (binary image, Git LFS)
examples/robot.png
ADDED (binary image, Git LFS)
examples/swivelchair.png
ADDED (binary image, Git LFS)
examples/teapot.png
ADDED (binary image, Git LFS)
examples/warhammer.png
ADDED (binary image, Git LFS)

flow/__init__.py
ADDED
File without changes

flow/configs/__init__.py
ADDED
File without changes
flow/configs/big_parts_strict_pvae.py
ADDED
@@ -0,0 +1,33 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

from flow.configs.schema import ModelConfig


def make_config():

    model_config = ModelConfig(
        vae_conf="vae.configs.part_woenc",
        vae_ckpt_path="pretrained/vae.pt",
        qknorm=True,
        qknorm_type="RMSNorm",
        use_pos_embed=False,
        dino_model="dinov2_vitg14",
        hidden_dim=1536,
        flow_shift=3.0,
        logitnorm_mean=1.0,
        logitnorm_std=1.0,
        latent_size=4096,
        use_parts=True,
    )

    return model_config
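
Configs like this one are addressed by dotted module path, mirroring the `importlib` lookup used in flow/scripts/infer.py and in the flow model's VAE loading. A minimal sketch, assuming the repo root is on PYTHONPATH:

```python
# sketch: resolve a config module by name and build its ModelConfig
import importlib

config = importlib.import_module("flow.configs.big_parts_strict_pvae").make_config()
print(config.flow_shift, config.latent_size)  # 3.0 4096
```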
flow/configs/schema.py
ADDED
@@ -0,0 +1,57 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

from typing import Literal, Optional

import attrs


@attrs.define(slots=False)
class ModelConfig:
    # vae
    vae_conf: str = "vae.configs.part_woenc"
    vae_ckpt_path: Optional[str] = None

    # learn & generate parts
    use_parts: bool = False
    part_embed_mode: Literal["element", "part", "part2_only"] = "part2_only"
    shuffle_parts: bool = False
    use_num_parts_cond: bool = False

    # flow matching hyper-params
    flow_shift: float = 1.0
    logitnorm_mean: float = 0.0
    logitnorm_std: float = 1.0

    # image encoder
    dino_model: Literal["dinov2_vitl14_reg", "dinov2_vitg14"] = "dinov2_vitg14"

    # backbone DiT
    hidden_dim: int = 1536
    num_heads: int = 16
    num_layers: int = 24
    qknorm: bool = True
    qknorm_type: Literal["LayerNorm", "RMSNorm"] = "RMSNorm"
    use_pos_embed: bool = False

    # latent code
    latent_size: Optional[int] = None  # if None, will load from vae
    latent_dim: Optional[int] = None

    # preload vae weights
    preload_vae: bool = True

    # preload dinov2 weights
    preload_dinov2: bool = True

    # init weights from a pretrained checkpoint
    pretrain_path: Optional[str] = None
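
Since ModelConfig is a plain attrs class, individual fields can be overridden per experiment; an unset latent_size/latent_dim is later filled in from the VAE config inside flow/model.py. A small illustrative sketch (values are illustrative, not a released config):

```python
# sketch: defaults vs. per-experiment overrides
from flow.configs.schema import ModelConfig

default_cfg = ModelConfig()                  # latent_size stays None until the VAE config fills it
custom_cfg = ModelConfig(flow_shift=3.0, use_parts=True)
print(default_cfg.latent_size, custom_cfg.flow_shift)  # None 3.0
```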
flow/flow_matching.py
ADDED
@@ -0,0 +1,58 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import numpy as np
import torch


class FlowMatchingScheduler:
    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1):
        # set timesteps
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift

        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas  # 1 --> 0
        self.timesteps = sigmas * num_train_timesteps  # num_train_timesteps --> 1

    # set device
    def to(self, device):
        self.sigmas = self.sigmas.to(device=device)
        self.timesteps = self.timesteps.to(device=device)

    # add random noise to latent during training
    def add_noise(self, latent: torch.Tensor, logit_mean: float = 1.0, logit_std: float = 1.0):
        # latent: [B, ...]
        # timesteps: [B]
        # return: [B, ...] noisy_latent, [B, ...] noise, [B] timesteps

        # logit-normal sampling
        u = torch.normal(mean=logit_mean, std=logit_std, size=(latent.shape[0],), device=self.sigmas.device)
        u = torch.nn.functional.sigmoid(u)

        step_indices = (u * self.num_train_timesteps).long()
        timesteps = self.timesteps[step_indices]

        sigmas = self.sigmas[step_indices].flatten()

        while len(sigmas.shape) < latent.ndim:
            sigmas = sigmas.unsqueeze(-1)

        noise = torch.randn_like(latent)
        noisy_latent = (1.0 - sigmas) * latent + sigmas * noise

        return noisy_latent, noise, timesteps
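
The scheduler implements the linear flow-matching interpolation noisy_latent = (1 - sigma) * latent + sigma * noise, with logit-normal timestep sampling. A quick shape check, assuming the repo is importable and using toy tensor sizes:

```python
# sketch: sample noisy latents from the scheduler (toy shapes)
import torch
from flow.flow_matching import FlowMatchingScheduler

scheduler = FlowMatchingScheduler(num_train_timesteps=1000, shift=3.0)
latent = torch.randn(2, 16, 4)                       # [B, tokens, latent_dim]
noisy, noise, t = scheduler.add_noise(latent, logit_mean=1.0, logit_std=1.0)
print(noisy.shape, noise.shape, t.shape)             # [2, 16, 4], [2, 16, 4], [2]
```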
flow/model.py
ADDED
@@ -0,0 +1,336 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import importlib

from transformers import Dinov2Model
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torchvision import transforms

from flow.configs.schema import ModelConfig
from flow.flow_matching import FlowMatchingScheduler
from flow.modules.dit import DiT
from vae.model import Model as VAE
from vae.utils import sync_timer


class Model(nn.Module):
    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        self.precision = torch.bfloat16

        # image condition model (dinov2)
        if self.config.dino_model == "dinov2_vitg14":
            self.dino = Dinov2Model.from_pretrained("facebook/dinov2-giant")
        elif self.config.dino_model == "dinov2_vitl14_reg":
            self.dino = Dinov2Model.from_pretrained("facebook/dinov2-with-registers-large")
        else:
            raise ValueError(f"DINOv2 model {self.config.dino_model} not supported")

        # hack to match our implementation
        self.dino.layernorm = torch.nn.Identity()

        self.dino.eval().to(dtype=self.precision)
        self.dino.requires_grad_(False)

        cond_dim = 1024 if self.config.dino_model == "dinov2_vitl14_reg" else 1536
        assert cond_dim == config.hidden_dim, "DINOv2 dim must match backbone dim"

        self.preprocess_cond_image = transforms.Compose(
            [
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        # vae encoder
        vae_config = importlib.import_module(config.vae_conf).make_config()
        self.vae = VAE(vae_config).eval().to(dtype=self.precision)
        self.vae.requires_grad_(False)

        # load vae
        if self.config.preload_vae:
            try:
                vae_ckpt = torch.load(self.config.vae_ckpt_path, weights_only=True)  # local path
                if "model" in vae_ckpt:
                    vae_ckpt = vae_ckpt["model"]
                self.vae.load_state_dict(vae_ckpt, strict=True)
                del vae_ckpt
                print(f"Loaded VAE from {self.config.vae_ckpt_path}")
            except Exception as e:
                print(
                    f"Failed to load VAE from {self.config.vae_ckpt_path}: {e}, make sure you resumed from a valid checkpoint!"
                )

        # load info from vae config
        if config.latent_size is None:
            config.latent_size = self.vae.config.latent_size
        if config.latent_dim is None:
            config.latent_dim = self.vae.config.latent_dim

        # dit
        self.dit = DiT(
            hidden_dim=config.hidden_dim,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            latent_size=config.latent_size,
            latent_dim=config.latent_dim,
            qknorm=config.qknorm,
            qknorm_type=config.qknorm_type,
            use_pos_embed=config.use_pos_embed,
            use_parts=config.use_parts,
            part_embed_mode=config.part_embed_mode,
        )

        # num_part condition
        if self.config.use_num_parts_cond:
            assert self.config.use_parts, "use_num_parts_cond requires use_parts"
            self.num_part_embed = nn.Embedding(5, config.hidden_dim)

        # preload from a checkpoint (NOTE: this happens BEFORE checkpointer loading latest checkpoint!)
        if self.config.pretrain_path is not None:
            try:
                ckpt = torch.load(self.config.pretrain_path)  # local path
                self.load_state_dict(ckpt["model"], strict=True)
                del ckpt
                print(f"Loaded DiT from {self.config.pretrain_path}")
            except Exception as e:
                print(
                    f"Failed to load DiT from {self.config.pretrain_path}: {e}, make sure you resumed from a valid checkpoint!"
                )

        # sampler
        self.scheduler = FlowMatchingScheduler(shift=config.flow_shift)

        n_params = 0
        for p in self.dit.parameters():
            n_params += p.numel()
        print(f"Number of parameters in DiT: {n_params/1e6:.2f}M")

    # override state_dict to exclude vae and dino, so we only save the trainable params.
    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)

        keys_to_del = []
        for k in state_dict.keys():
            if "vae" in k or "dino" in k:
                keys_to_del.append(k)

        for k in keys_to_del:
            del state_dict[k]

        return state_dict

    # override to support tolerant loading (only load matched shape)
    def load_state_dict(self, state_dict, strict=True, assign=False):
        local_state_dict = self.state_dict()
        seen_keys = {k: False for k in local_state_dict.keys()}
        for k, v in state_dict.items():
            if k in local_state_dict:
                seen_keys[k] = True
                if local_state_dict[k].shape == v.shape:
                    local_state_dict[k].copy_(v)
                else:
                    print(f"mismatching shape for key {k}: loaded {local_state_dict[k].shape} but model has {v.shape}")
            else:
                print(f"unexpected key {k} in loaded state dict")
        for k in seen_keys:
            if not seen_keys[k]:
                print(f"missing key {k} in loaded state dict")

    # this happens before checkpointer loading old models !!!
    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        super().on_train_start(memory_format=memory_format)
        device = next(self.dit.parameters()).device

        self.dit.to(dtype=self.precision)

        if self.config.use_num_parts_cond:
            self.num_part_embed.to(dtype=self.precision)

        # cast scheduler to device
        self.scheduler.to(device)

    def get_cond(self, cond_image, num_part=None):
        # image condition
        cond_image = cond_image.to(dtype=self.precision)
        with torch.no_grad():
            cond = self.dino(cond_image).last_hidden_state
            cond = F.layer_norm(cond.float(), cond.shape[-1:]).to(dtype=self.precision)  # [B, L, C]

        # num_part condition
        if self.config.use_num_parts_cond:
            if num_part is None:
                # use a default value (2-10 parts)
                num_part_coarse = torch.ones(cond.shape[0], dtype=torch.int64, device=cond.device) * 2
            else:
                # coarse range
                num_part_coarse = torch.ones(cond.shape[0], dtype=torch.int64, device=cond.device)
                num_part_coarse[num_part == 2] = 1
                num_part_coarse[(num_part > 2) & (num_part <= 10)] = 2
                num_part_coarse[(num_part > 10) & (num_part <= 100)] = 3
                num_part_coarse[num_part > 100] = 4
            num_part_embed = self.num_part_embed(num_part_coarse).unsqueeze(1)  # [B, 1, C]
            cond = torch.cat([cond, num_part_embed], dim=1)  # [B, L+1, C]

        return cond

    def training_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        output = {}
        loss = 0

        cond_images = self.preprocess_cond_image(
            data["cond_images"]
        )  # [B, N, 3, 518, 518], we may load multiple (N) cond images for the same shape
        B, N, C, H, W = cond_images.shape

        if self.config.use_num_parts_cond:
            cond_num_part = data["num_part"].repeat_interleave(N, dim=0)
        else:
            cond_num_part = None

        cond = self.get_cond(cond_images.view(-1, C, H, W), cond_num_part)  # [B*N, L, C]

        # random CFG dropout
        if self.training:
            mask = torch.rand((B * N, 1, 1), device=cond.device, dtype=cond.dtype) >= 0.1
            cond = cond * mask

        with torch.no_grad():
            # encode latent
            if self.config.use_parts:
                # encode two parts and concat latent
                part0_data = {k.replace("_part0", ""): v for k, v in data.items() if "_part0" in k}
                part1_data = {k.replace("_part1", ""): v for k, v in data.items() if "_part1" in k}
                posterior0 = self.vae.encode(part0_data)
                posterior1 = self.vae.encode(part1_data)
                if self.training and self.config.shuffle_parts:
                    if np.random.rand() < 0.5:
                        posterior0, posterior1 = posterior1, posterior0
                latent = torch.cat(
                    [
                        posterior0.mode().float().nan_to_num_(0),
                        posterior1.mode().float().nan_to_num_(0),
                    ],
                    dim=1,
                )  # [B, 2L, C]
            else:
                posterior = self.vae.encode(data)
                latent = posterior.mode().float().nan_to_num_(0)  # use mean as the latent, [B, L, C]

            # repeat latent for each cond image
            if N != 1:
                latent = latent.repeat_interleave(N, dim=0)

        # random sample timesteps and add noise
        noisy_latent, noise, timesteps = self.scheduler.add_noise(
            latent, self.config.logitnorm_mean, self.config.logitnorm_std
        )

        noisy_latent = noisy_latent.to(dtype=self.precision)
        model_pred = self.dit(noisy_latent, cond, timesteps)

        # flow-matching loss
        target = noise - latent
        loss = F.mse_loss(model_pred.float(), target.float())

        # metrics
        with torch.no_grad():
            output["scalar"] = {}  # for wandb logging
            output["scalar"]["loss_mse"] = loss.detach()

        return output, loss

    @torch.no_grad()
    def validation_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        return self.training_step(data, iteration)

    @torch.inference_mode()
    @sync_timer("flow forward")
    def forward(
        self,
        data: dict[str, torch.Tensor],
        num_steps: int = 30,
        cfg_scale: float = 7.0,
        verbose: bool = True,
        generator: torch.Generator | None = None,
    ) -> dict[str, torch.Tensor]:
        # the inference sampling
        cond_images = self.preprocess_cond_image(data["cond_images"])  # [B, 3, 518, 518]
        B = cond_images.shape[0]
        assert B == 1, "Only support batch size 1 for now."

        # num_part condition
        if self.config.use_num_parts_cond and "num_part" in data:
            cond_num_part = data["num_part"]  # [B,], int
        else:
            cond_num_part = None

        cond = self.get_cond(cond_images, cond_num_part)

        if self.config.use_parts:
            x = torch.randn(
                B,
                self.config.latent_size * 2,
                self.config.latent_dim,
                device=cond.device,
                dtype=torch.float32,
                generator=generator,
            )
        else:
            x = torch.randn(
                B,
                self.config.latent_size,
                self.config.latent_dim,
                device=cond.device,
                dtype=torch.float32,
                generator=generator,
            )

        cond_input = torch.cat([cond, torch.zeros_like(cond)], dim=0)

        # flow-matching
        sigmas = np.linspace(1, 0, num_steps + 1)
        sigmas = self.scheduler.shift * sigmas / (1 + (self.scheduler.shift - 1) * sigmas)
        sigmas_pair = list((sigmas[i], sigmas[i + 1]) for i in range(num_steps))

        for sigma, sigma_prev in tqdm.tqdm(sigmas_pair, desc="Flow Sampling", disable=not verbose):
            # classifier-free guidance
            timesteps = torch.tensor([1000 * sigma] * B * 2, device=x.device, dtype=x.dtype)
            x_input = torch.cat([x, x], dim=0)

            # predict v
            x_input = x_input.to(dtype=self.precision)
            pred = self.dit(x_input, cond_input, timesteps).float()
            cond_v, uncond_v = pred.chunk(2, dim=0)
            pred_v = uncond_v + (cond_v - uncond_v) * cfg_scale

            # scheduler step
            x = x - (sigma - sigma_prev) * pred_v

        output = {}
        output["latent"] = x

        # leave mesh extraction to vae
        return output
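
The sampling loop in forward() is a plain Euler integrator over the shifted sigma schedule, with classifier-free guidance mixing the conditional and unconditional velocity predictions. A standalone sketch of just that update, using a dummy velocity function in place of the DiT so it runs without the repo weights (shapes and the shift value are illustrative):

```python
# sketch: the Euler + CFG update used above, with a stand-in velocity predictor
import numpy as np
import torch

shift, num_steps, cfg_scale = 3.0, 4, 7.0
sigmas = np.linspace(1, 0, num_steps + 1)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)   # same time-shift as the scheduler

x = torch.randn(1, 8, 4)                               # [B, tokens, latent_dim]

def dummy_v(x_in, conditioned):                        # placeholder for self.dit(...)
    return x_in * (0.5 if conditioned else 0.3)

for i in range(num_steps):
    sigma, sigma_prev = sigmas[i], sigmas[i + 1]
    cond_v, uncond_v = dummy_v(x, True), dummy_v(x, False)
    pred_v = uncond_v + (cond_v - uncond_v) * cfg_scale   # classifier-free guidance
    x = x - (sigma - sigma_prev) * pred_v                 # Euler step toward sigma = 0
```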
flow/modules/__init__.py
ADDED
File without changes
flow/modules/dit.py
ADDED
@@ -0,0 +1,235 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from vae.modules.attention import CrossAttention, SelfAttention


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

    def forward(self, x):
        return self.net(x)


# Adapted from https://github.com/facebookresearch/DiT/blob/main/models.py#L27
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t: a 1-D Tensor of N indices, one per batch element.
                These may be fractional.
            dim: the dimension of the output.
            max_period: controls the minimum frequency of the embeddings.

        Returns:
            an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
            device=t.device
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        dtype = next(self.mlp.parameters()).dtype  # need to determine on the fly...
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_freq = t_freq.to(dtype=dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class DiTLayer(nn.Module):
    def __init__(self, dim, num_heads, qknorm=False, gradient_checkpointing=True, qknorm_type="LayerNorm"):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.gradient_checkpointing = gradient_checkpointing

        self.norm1 = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.attn1 = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.attn2 = CrossAttention(dim, num_heads, context_dim=dim, qknorm=qknorm, qknorm_type=qknorm_type)
        self.norm3 = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.ff = FeedForward(dim)
        self.adaln_linear = nn.Linear(dim, dim * 6, bias=True)

    def forward(self, x, c, t_emb):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, t_emb, use_reentrant=False)
        else:
            return self._forward(x, c, t_emb)

    def _forward(self, x, c, t_emb):
        # x: [B, N, C], hidden states
        # c: [B, M, C], condition (assume normed and projected to C)
        # t_emb: [B, C], timestep embedding of adaln
        # return: [B, N, C], updated hidden states

        B, N, C = x.shape
        t_adaln = self.adaln_linear(F.silu(t_emb)).view(B, 6, -1)  # [B, 6, C]
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_adaln.chunk(6, dim=1)

        h = self.norm1(x)
        h = h * (1 + scale_msa) + shift_msa
        x = x + gate_msa * self.attn1(h)

        h = self.norm2(x)
        x = x + self.attn2(h, c)

        h = self.norm3(x)
        h = h * (1 + scale_mlp) + shift_mlp
        x = x + gate_mlp * self.ff(h)

        return x


class DiT(nn.Module):
    def __init__(
        self,
        hidden_dim=1024,
        num_heads=16,
        latent_size=2048,
        latent_dim=8,
        num_layers=24,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
        use_pos_embed=False,
        use_parts=False,
        part_embed_mode="part2_only",
    ):
        super().__init__()

        # project in
        self.proj_in = nn.Linear(latent_dim, hidden_dim)

        # positional encoding (just use a learnable positional encoding)
        self.use_pos_embed = use_pos_embed
        if self.use_pos_embed:
            self.pos_embed = nn.Parameter(torch.randn(1, latent_size, hidden_dim) / hidden_dim**0.5)

        # part encoding (a must to distinguish parts!)
        self.use_parts = use_parts
        self.part_embed_mode = part_embed_mode
        if self.use_parts:
            if self.part_embed_mode == "element":
                self.part_embed = nn.Parameter(torch.randn(latent_size, hidden_dim) / hidden_dim**0.5)
            elif self.part_embed_mode == "part":
                self.part_embed = nn.Parameter(torch.randn(2, hidden_dim))
            elif self.part_embed_mode == "part2_only":
                # we only add this to the second part to distinguish from the first part
                self.part_embed = nn.Parameter(torch.randn(1, hidden_dim) / hidden_dim**0.5)

        # timestep encoding
        self.timestep_embed = TimestepEmbedder(hidden_dim)

        # transformer layers
        self.layers = nn.ModuleList(
            [DiTLayer(hidden_dim, num_heads, qknorm, gradient_checkpointing, qknorm_type) for _ in range(num_layers)]
        )

        # project out
        self.norm_out = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out = nn.Linear(hidden_dim, latent_dim)

        # init
        self.init_weight()

    def init_weight(self):
        # Initialize transformer layers
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.timestep_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.timestep_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for layer in self.layers:
            nn.init.constant_(layer.adaln_linear.weight, 0)
            nn.init.constant_(layer.adaln_linear.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.proj_out.weight, 0)
        nn.init.constant_(self.proj_out.bias, 0)

    def forward(self, x, c, t):
        # x: [B, N, C], hidden states
        # c: [B, M, C], condition (assume normed and projected to C)
        # t: [B,], timestep
        # return: [B, N, C], updated hidden states

        B, N, C = x.shape

        # project in
        x = self.proj_in(x)

        # positional encoding
        if self.use_pos_embed:
            x = x + self.pos_embed

        # part encoding
        if self.use_parts:
            if self.part_embed_mode == "element":
                x += self.part_embed
            elif self.part_embed_mode == "part":
                x[:, : x.shape[1] // 2, :] += self.part_embed[0]
                x[:, x.shape[1] // 2 :, :] += self.part_embed[1]
            elif self.part_embed_mode == "part2_only":
                x[:, x.shape[1] // 2 :, :] += self.part_embed[0]

        # timestep encoding
        t_emb = self.timestep_embed(t)  # [B, C]

        # transformer layers
        for layer in self.layers:
            x = layer(x, c, t_emb)

        # project out
        x = self.norm_out(x)
        x = self.proj_out(x)

        return x
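
A tiny shape smoke test for the DiT above; the dimensions are illustrative (not the released config) and the attention blocks come from the repo's vae.modules.attention, so the repo must be on PYTHONPATH:

```python
# sketch: check that DiT maps [B, N, latent_dim] -> [B, N, latent_dim]
import torch
from flow.modules.dit import DiT

dit = DiT(hidden_dim=64, num_heads=4, latent_size=32, latent_dim=8, num_layers=2).eval()
x = torch.randn(1, 32, 8)        # noisy latent tokens
c = torch.randn(1, 10, 64)       # conditioning tokens, already projected to hidden_dim
t = torch.rand(1) * 1000         # one timestep per batch element
print(dit(x, c, t).shape)        # torch.Size([1, 32, 8])
```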
flow/scripts/infer.py
ADDED
@@ -0,0 +1,180 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import argparse
import glob
import importlib
import os
from datetime import datetime

import cv2
import kiui
import numpy as np
import rembg
import torch
import trimesh

from flow.model import Model
from flow.utils import get_random_color, recenter_foreground
from vae.utils import postprocess_mesh

# PYTHONPATH=. python flow/scripts/infer.py
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    help="config file path",
    default="flow.configs.big_parts_strict_pvae",
)
parser.add_argument(
    "--ckpt_path",
    type=str,
    help="checkpoint path",
    default="pretrained/flow.pt",
)
parser.add_argument("--input", type=str, help="input directory", default="assets/images/")
parser.add_argument("--limit", type=int, help="limit number of images", default=-1)
parser.add_argument("--output_dir", type=str, help="output directory", default="output/")
parser.add_argument("--grid_res", type=int, help="grid resolution", default=384)
parser.add_argument("--num_steps", type=int, help="number of cfg steps", default=30)
parser.add_argument("--cfg_scale", type=float, help="cfg scale", default=7.0)
parser.add_argument("--num_repeats", type=int, help="number of repeats per image", default=1)
parser.add_argument("--seed", type=int, help="seed", default=42)
args = parser.parse_args()

TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)

bg_remover = rembg.new_session()


def preprocess_image(path):
    input_image = kiui.read_image(path, mode="uint8", order="RGBA")

    # bg removal if there is no alpha channel
    if input_image.shape[-1] == 3:
        input_image = rembg.remove(input_image, session=bg_remover)  # [H, W, 4]

    mask = input_image[..., -1] > 0
    image = recenter_foreground(input_image, mask, border_ratio=0.1)
    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
    image = image.astype(np.float32) / 255.0
    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])  # white background
    return image


print(f"Loading checkpoint from {args.ckpt_path}")
ckpt_dict = torch.load(args.ckpt_path, weights_only=True)

# delete all keys other than model
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]

# instantiate model
print(f"Instantiating model from {args.config}")
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()

# load weight
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)

# output folder
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, "flow_" + args.config.split(".")[-1] + "_" + timestamp)
if not os.path.exists(workspace):
    os.makedirs(workspace)
else:
    os.system(f"rm {workspace}/*")
print(f"Output directory: {workspace}")

# load test images
if os.path.isdir(args.input):
    paths = glob.glob(os.path.join(args.input, "*"))
    paths = sorted(paths)
    if args.limit > 0:
        paths = paths[: args.limit]
else:  # single file
    paths = [args.input]

for path in paths:
    name = os.path.splitext(os.path.basename(path))[0]
    print(f"Processing {name}")

    image = preprocess_image(path)

    kiui.write_image(os.path.join(workspace, name + ".jpg"), image)
    image = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()

    # run model
    data = {"cond_images": image}

    for i in range(args.num_repeats):

        kiui.seed_everything(args.seed + i)

        with torch.inference_mode():
            results = model(data, num_steps=args.num_steps, cfg_scale=args.cfg_scale)

        latent = results["latent"]
        # kiui.lo(latent)

        # query mesh
        if model.config.use_parts:
            data_part0 = {"latent": latent[:, : model.config.latent_size, :]}
            data_part1 = {"latent": latent[:, model.config.latent_size :, :]}

            with torch.inference_mode():
                results_part0 = model.vae(data_part0, resolution=args.grid_res)
                results_part1 = model.vae(data_part1, resolution=args.grid_res)

            vertices, faces = results_part0["meshes"][0]
            mesh_part0 = trimesh.Trimesh(vertices, faces)
            mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
            mesh_part0 = postprocess_mesh(mesh_part0, 5e4)
            parts = mesh_part0.split(only_watertight=False)

            vertices, faces = results_part1["meshes"][0]
            mesh_part1 = trimesh.Trimesh(vertices, faces)
            mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
            mesh_part1 = postprocess_mesh(mesh_part1, 5e4)
            parts.extend(mesh_part1.split(only_watertight=False))

            # split connected components and assign different colors
            for j, part in enumerate(parts):
                # each component uses a random color
                part.visual.vertex_colors = get_random_color(j, use_float=True)

            mesh = trimesh.Scene(parts)
            # export the whole mesh
            mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb"))

            # export each part
            for j, part in enumerate(parts):
                part.export(os.path.join(workspace, name + "_" + str(i) + "_part" + str(j) + ".glb"))

            # export dual volumes
            mesh_part0.export(os.path.join(workspace, name + "_" + str(i) + "_vol0.glb"))
            mesh_part1.export(os.path.join(workspace, name + "_" + str(i) + "_vol1.glb"))

        else:
            data = {"latent": latent}

            with torch.inference_mode():
                results = model.vae(data, resolution=args.grid_res)

            vertices, faces = results["meshes"][0]
            mesh = trimesh.Trimesh(vertices, faces)
            mesh = postprocess_mesh(mesh, 5e4)

            # kiui.lo(mesh.vertices, mesh.faces)
            mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
            mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb"))
flow/utils.py
ADDED
@@ -0,0 +1,119 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

from typing import Optional

import cv2
import numpy as np


def recenter_foreground(image, mask, border_ratio: float = 0.1):
    """recenter an image to leave some empty space at the image border.

    Args:
        image (ndarray): input image, float/uint8 [H, W, 3/4]
        mask (ndarray): alpha mask, bool [H, W]
        border_ratio (float, optional): border ratio, image will be resized to (1 - border_ratio). Defaults to 0.1.

    Returns:
        ndarray: output image, float/uint8 [H, W, 3/4]
    """

    # empty foreground: just return
    if mask.sum() == 0:
        return image

    return_int = False
    if image.dtype == np.uint8:
        image = image.astype(np.float32) / 255
        return_int = True

    H, W, C = image.shape
    size = max(H, W)

    # default to white bg if rgb, but use 0 if rgba
    if C == 3:
        result = np.ones((size, size, C), dtype=np.float32)
    else:
        result = np.zeros((size, size, C), dtype=np.float32)

    coords = np.nonzero(mask)
    x_min, x_max = coords[0].min(), coords[0].max()
    y_min, y_max = coords[1].min(), coords[1].max()
    h = x_max - x_min
    w = y_max - y_min
    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)
    h2 = int(h * scale)
    w2 = int(w * scale)
    x2_min = (size - h2) // 2
    x2_max = x2_min + h2
    y2_min = (size - w2) // 2
    y2_max = y2_min + w2
    result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(
        image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA
    )

    if return_int:
        result = (result * 255).astype(np.uint8)

    return result


def get_random_color(index: Optional[int] = None, use_float: bool = False):
    # some pleasing colors
    # matplotlib.colormaps['Set3'].colors + matplotlib.colormaps['Set2'].colors + matplotlib.colormaps['Set1'].colors
    palette = np.array(
        [
            [141, 211, 199, 255],
            [255, 255, 179, 255],
            [190, 186, 218, 255],
            [251, 128, 114, 255],
            [128, 177, 211, 255],
            [253, 180, 98, 255],
            [179, 222, 105, 255],
            [252, 205, 229, 255],
            [217, 217, 217, 255],
            [188, 128, 189, 255],
            [204, 235, 197, 255],
            [255, 237, 111, 255],
            [102, 194, 165, 255],
            [252, 141, 98, 255],
            [141, 160, 203, 255],
            [231, 138, 195, 255],
            [166, 216, 84, 255],
            [255, 217, 47, 255],
            [229, 196, 148, 255],
            [179, 179, 179, 255],
            [228, 26, 28, 255],
            [55, 126, 184, 255],
            [77, 175, 74, 255],
            [152, 78, 163, 255],
            [255, 127, 0, 255],
            [255, 255, 51, 255],
            [166, 86, 40, 255],
            [247, 129, 191, 255],
            [153, 153, 153, 255],
        ],
        dtype=np.uint8,
    )

    if index is None:
        index = np.random.randint(0, len(palette))

    if index >= len(palette):
        index = index % len(palette)

    if use_float:
        return palette[index].astype(np.float32) / 255
    else:
        return palette[index]
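
A short usage sketch for the two helpers above, on a synthetic RGBA image (numpy only; sizes are arbitrary, and the repo must be importable):

```python
# sketch: recenter a foreground and fetch a deterministic palette color
import numpy as np
from flow.utils import get_random_color, recenter_foreground

img = np.zeros((64, 48, 4), dtype=np.uint8)
img[20:40, 10:30] = 255                                        # an opaque white rectangle
mask = img[..., -1] > 0
centered = recenter_foreground(img, mask, border_ratio=0.1)    # square canvas with a 10% border
color = get_random_color(3, use_float=True)                    # RGBA in [0, 1]; index wraps around the palette
print(centered.shape, color)
```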
requirements.txt
ADDED
@@ -0,0 +1,16 @@
torch==2.5.1 # should be installed by HF
torchvision==0.20.1
numpy
trimesh
# meshiki # we don't use it for flow inference
fpsample
einops
onnxruntime
rembg
kiui
pymcubes
tqdm
opencv-python
ninja
pymeshlab
transformers
vae/__init__.py
ADDED
File without changes

vae/configs/__init__.py
ADDED
File without changes
vae/configs/part_woenc.py
ADDED
@@ -0,0 +1,30 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

from vae.configs.schema import ModelConfig


def make_config():

    model_config = ModelConfig(
        use_salient_point=True,
        latent_size=4096,
        cutoff_fps_point=(256, 512, 512, 512, 1024, 1024, 2048),
        cutoff_fps_salient_point=(0, 0, 256, 512, 512, 1024, 2048),
        cutoff_fps_prob=(0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.2),
        kl_weight=1e-3,
        salient_attn_mode="dual",
        num_enc_layers=0,
        num_dec_layers=24,
    )

    return model_config
vae/configs/schema.py
ADDED
@@ -0,0 +1,55 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

from typing import Literal, Optional, Tuple

import attrs


@attrs.define(slots=False)
class ModelConfig:
    # input
    use_salient_point: bool = True

    # random cutoff during training
    cutoff_fps_point: Tuple[int, ...] = (256, 512, 512, 512, 1024, 1024, 2048)
    cutoff_fps_salient_point: Tuple[int, ...] = (0, 0, 256, 512, 512, 1024, 2048)
    cutoff_fps_prob: Tuple[float, ...] = (0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.2)  # sum to 1.0

    # backbone transformer
    num_enc_layers: int = 0
    hidden_dim: int = 1024
    num_heads: int = 16
    num_dec_layers: int = 24
    dec_hidden_dim: int = 1024
    dec_num_heads: int = 16
    qknorm: bool = True
    qknorm_type: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"  # type of qknorm
    salient_attn_mode: Literal["dual_shared", "single", "dual"] = "dual"

    # query decoder
    fourier_version: Literal["v1", "v2", "v3"] = "v3"
    point_fourier_dim: int = 48  # must be divisible by 6 (sin/cos, x/y/z)
    query_hidden_dim: int = 1024
    query_num_heads: int = 16
    use_flash_query: bool = False

    # latent code
    latent_size: int = 4096  # == num_fps_point + num_fps_salient_point
    latent_dim: int = 64

    # loss
    use_ae: bool = False  # if true, variance will be ignored, and kl_weight is used as a L2 norm weight
    kl_weight: float = 1e-3

    # init weights from a pretrained checkpoint
    pretrain_path: Optional[str] = None
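Because ModelConfig is an attrs class, variants can be derived without editing the config files. A brief sketch using attrs.evolve; the override values below are illustrative, not a released configuration.

```python
import attrs

from vae.configs.part_woenc import make_config

base = make_config()

# illustrative override: a smaller decoder and plain LayerNorm qk-norm
small = attrs.evolve(base, num_dec_layers=12, dec_hidden_dim=512, qknorm_type="LayerNorm")

print(small.num_dec_layers, small.dec_hidden_dim)  # 12 512
```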
vae/model.py
ADDED
@@ -0,0 +1,451 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

from typing import Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vae.configs.schema import ModelConfig
from vae.modules.transformer import AttentionBlock, FlashQueryLayer
from vae.utils import (
    DiagonalGaussianDistribution,
    DummyLatent,
    calculate_iou,
    calculate_metrics,
    construct_grid_points,
    extract_mesh,
    sync_timer,
)


class Model(nn.Module):
    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config

        self.precision = torch.bfloat16  # manually handle low-precision training, always use bf16

        # point encoder
        self.proj_input = nn.Linear(3 + config.point_fourier_dim, config.hidden_dim)

        self.perceiver = AttentionBlock(
            config.hidden_dim,
            num_heads=config.num_heads,
            dim_context=config.hidden_dim,
            qknorm=config.qknorm,
            qknorm_type=config.qknorm_type,
        )

        if self.config.salient_attn_mode == "dual":
            self.perceiver_dorases = AttentionBlock(
                config.hidden_dim,
                num_heads=config.num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )

        # self-attention encoder
        self.encoder = nn.ModuleList(
            [
                AttentionBlock(
                    config.hidden_dim, config.num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
                )
                for _ in range(config.num_enc_layers)
            ]
        )

        # vae bottleneck
        self.norm_down = nn.LayerNorm(config.hidden_dim)
        self.proj_down_mean = nn.Linear(config.hidden_dim, config.latent_dim)
        if not self.config.use_ae:
            self.proj_down_std = nn.Linear(config.hidden_dim, config.latent_dim)
        self.proj_up = nn.Linear(config.latent_dim, config.dec_hidden_dim)

        # self-attention decoder
        self.decoder = nn.ModuleList(
            [
                AttentionBlock(
                    config.dec_hidden_dim, config.dec_num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
                )
                for _ in range(config.num_dec_layers)
            ]
        )

        # cross-attention query
        self.proj_query = nn.Linear(3 + config.point_fourier_dim, config.query_hidden_dim)
        if self.config.use_flash_query:
            self.norm_query_context = nn.LayerNorm(config.hidden_dim, eps=1e-6, elementwise_affine=False)
            self.attn_query = FlashQueryLayer(
                config.query_hidden_dim,
                num_heads=config.query_num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )
        else:
            self.attn_query = AttentionBlock(
                config.query_hidden_dim,
                num_heads=config.query_num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )
        self.norm_out = nn.LayerNorm(config.query_hidden_dim)
        self.proj_out = nn.Linear(config.query_hidden_dim, 1)

        # preload from a checkpoint (NOTE: this happens BEFORE checkpointer loading latest checkpoint!)
        if self.config.pretrain_path is not None:
            try:
                ckpt = torch.load(self.config.pretrain_path)  # local path
                self.load_state_dict(ckpt["model"], strict=True)
                del ckpt
                print(f"Loaded VAE from {self.config.pretrain_path}")
            except Exception as e:
                print(
                    f"Failed to load VAE from {self.config.pretrain_path}: {e}, make sure you resumed from a valid checkpoint!"
                )

        # log
        n_params = 0
        for p in self.parameters():
            n_params += p.numel()
        print(f"Number of parameters in VAE: {n_params / 1e6:.2f}M")

    # override to support tolerant loading (only load matched shape)
    def load_state_dict(self, state_dict, strict=True, assign=False):
        local_state_dict = self.state_dict()
        seen_keys = {k: False for k in local_state_dict.keys()}
        for k, v in state_dict.items():
            if k in local_state_dict:
                seen_keys[k] = True
                if local_state_dict[k].shape == v.shape:
                    local_state_dict[k].copy_(v)
                else:
                    print(f"mismatching shape for key {k}: loaded {local_state_dict[k].shape} but model has {v.shape}")
            else:
                print(f"unexpected key {k} in loaded state dict")
        for k in seen_keys:
            if not seen_keys[k]:
                print(f"missing key {k} in loaded state dict")

    def fourier_encoding(self, points: torch.Tensor):
        # points: [B, N, 3], float32 for precision
        # assert points.dtype == torch.float32, "Query points must be float32"

        F = self.config.point_fourier_dim // (2 * points.shape[-1])

        if self.config.fourier_version == "v1":  # default
            exponent = torch.arange(1, F + 1, device=points.device, dtype=torch.float32) / F  # [F], range from 0 to 1
            freq_band = 512**exponent  # [F], min frequency is 1, max frequency is 1/freq
            freq_band *= torch.pi
        elif self.config.fourier_version == "v2":
            exponent = torch.arange(F, device=points.device, dtype=torch.float32) / (F - 1)  # [F], range from 0 to 1
            freq_band = 1024**exponent  # [F]
            freq_band *= torch.pi
        elif self.config.fourier_version == "v3":  # hunyuan3d-2
            freq_band = 2 ** torch.arange(F, device=points.device, dtype=torch.float32)  # [F]

        spectrum = points.unsqueeze(-1) * freq_band  # [B,...,3,F]
        sin, cos = spectrum.sin(), spectrum.cos()  # [B,...,3,F]
        input_enc = torch.stack([sin, cos], dim=-2)  # [B,...,3,2,F]
        input_enc = input_enc.view(*points.shape[:-1], -1)  # [B,...,6F] = [B,...,dim]
        return torch.cat([input_enc, points], dim=-1).to(dtype=self.precision)  # [B,...,dim+input_dim]

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        super().on_train_start(memory_format=memory_format)
        self.to(dtype=self.precision, memory_format=memory_format)  # use bfloat16 for training

    def encode(self, data: dict[str, torch.Tensor]):
        # uniform points
        pointcloud = data["pointcloud"]  # [B, N, 3]

        # fourier embed and project
        pointcloud = self.fourier_encoding(pointcloud)  # [B, N, 3+C]
        pointcloud = self.proj_input(pointcloud)  # [B, N, hidden_dim]

        # salient points
        if self.config.use_salient_point:
            pointcloud_dorases = data["pointcloud_dorases"]  # [B, M, 3]

            # fourier embed and project (shared weights)
            pointcloud_dorases = self.fourier_encoding(pointcloud_dorases)  # [B, M, 3+C]
            pointcloud_dorases = self.proj_input(pointcloud_dorases)  # [B, M, hidden_dim]

        # gather fps point
        fps_indices = data["fps_indices"]  # [B, N']
        pointcloud_query = torch.gather(pointcloud, 1, fps_indices.unsqueeze(-1).expand(-1, -1, pointcloud.shape[-1]))

        if self.config.use_salient_point:
            fps_indices_dorases = data["fps_indices_dorases"]  # [B, M']

            if fps_indices_dorases.shape[1] > 0:
                pointcloud_query_dorases = torch.gather(
                    pointcloud_dorases,
                    1,
                    fps_indices_dorases.unsqueeze(-1).expand(-1, -1, pointcloud_dorases.shape[-1]),
                )

                # combine both fps points as the query
                pointcloud_query = torch.cat(
                    [pointcloud_query, pointcloud_query_dorases], dim=1
                )  # [B, N'+M', hidden_dim]

            # dual cross-attention
            if self.config.salient_attn_mode == "dual_shared":
                hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver(
                    pointcloud_query, pointcloud_dorases
                )  # [B, N'+M', hidden_dim]
            elif self.config.salient_attn_mode == "dual":
                hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver_dorases(
                    pointcloud_query, pointcloud_dorases
                )
            else:  # single, hunyuan3d-2 style
                hidden_states = self.perceiver(pointcloud_query, torch.cat([pointcloud, pointcloud_dorases], dim=1))
        else:
            hidden_states = self.perceiver(pointcloud_query, pointcloud)  # [B, N', hidden_dim]

        # encoder
        for block in self.encoder:
            hidden_states = block(hidden_states)

        # bottleneck
        hidden_states = self.norm_down(hidden_states)
        latent_mean = self.proj_down_mean(hidden_states).float()
        if not self.config.use_ae:
            latent_std = self.proj_down_std(hidden_states).float()
            posterior = DiagonalGaussianDistribution(latent_mean, latent_std)
        else:
            posterior = DummyLatent(latent_mean)

        return posterior

    def decode(self, latent: torch.Tensor):
        latent = latent.to(dtype=self.precision)
        hidden_states = self.proj_up(latent)

        for block in self.decoder:
            hidden_states = block(hidden_states)

        return hidden_states

    def query(self, query_points: torch.Tensor, hidden_states: torch.Tensor):
        # query_points: [B, N, 3], float32 to keep the precision

        query_points = self.fourier_encoding(query_points)  # [B, N, 3+C]
        query_points = self.proj_query(query_points)  # [B, N, hidden_dim]

        # cross attention
        query_output = self.attn_query(query_points, hidden_states)  # [B, N, hidden_dim]

        # output linear
        query_output = self.norm_out(query_output)
        pred = self.proj_out(query_output)  # [B, N, 1]

        return pred

    def training_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        output = {}

        # cut off fps point during training for progressive flow
        if self.training:
            # randomly choose from a set of cutoff candidates
            cutoff_index = np.random.choice(len(self.config.cutoff_fps_prob), p=self.config.cutoff_fps_prob)
            cutoff_fps_point = self.config.cutoff_fps_point[cutoff_index]
            cutoff_fps_salient_point = self.config.cutoff_fps_salient_point[cutoff_index]
            # prefix of FPS points are still FPS points
            data["fps_indices"] = data["fps_indices"][:, :cutoff_fps_point]
            if self.config.use_salient_point:
                data["fps_indices_dorases"] = data["fps_indices_dorases"][:, :cutoff_fps_salient_point]

        loss = 0

        # encode
        posterior = self.encode(data)
        latent_geom = posterior.sample() if self.training else posterior.mode()

        # decode
        hidden_states = self.decode(latent_geom)

        # cross-attention query
        query_points = data["query_points"]  # [B, N, 3], float32

        # the context norm can be moved out to avoid repeated computation
        if self.config.use_flash_query:
            hidden_states = self.norm_query_context(hidden_states)

        pred = self.query(query_points, hidden_states).squeeze(-1).float()  # [B, N]
        gt = data["query_gt"].float()  # [B, N], in [-1, 1]

        # main loss
        loss_mse = F.mse_loss(pred, gt, reduction="mean")
        loss += loss_mse

        loss_l1 = F.l1_loss(pred, gt, reduction="mean")
        loss += loss_l1

        # kl loss
        loss_kl = posterior.kl().mean()
        loss += self.config.kl_weight * loss_kl

        # metrics
        with torch.no_grad():
            output["scalar"] = {}  # for wandb logging
            output["scalar"]["loss_mse"] = loss_mse.detach()
            output["scalar"]["loss_l1"] = loss_l1.detach()
            output["scalar"]["loss_kl"] = loss_kl.detach()
            output["scalar"]["iou_fg"] = calculate_iou(pred, gt, target_value=1)
            output["scalar"]["iou_bg"] = calculate_iou(pred, gt, target_value=0)
            output["scalar"]["precision"], output["scalar"]["recall"], output["scalar"]["f1"] = calculate_metrics(
                pred, gt, target_value=1
            )

        return output, loss

    @torch.no_grad()
    def validation_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        return self.training_step(data, iteration)

    @torch.inference_mode()
    @sync_timer("vae forward")
    def forward(
        self,
        data: dict[str, torch.Tensor],
        mode: Literal["dense", "hierarchical"] = "hierarchical",
        max_samples_per_iter: int = 512**2,
        resolution: int = 512,
        min_resolution: int = 64,  # for hierarchical
    ) -> dict[str, torch.Tensor]:
        output = {}

        # encode
        if "latent" in data:
            latent = data["latent"]
        else:
            posterior = self.encode(data)
            output["posterior"] = posterior
            latent = posterior.mode()

        output["latent"] = latent
        B = latent.shape[0]

        # decode
        hidden_states = self.decode(latent)
        output["hidden_states"] = hidden_states  # [B, N, hidden_dim] for the last cross-attention decoder

        # the context norm can be moved out to avoid repeated computation
        if self.config.use_flash_query:
            hidden_states = self.norm_query_context(hidden_states)

        # query
        def chunked_query(grid_points):
            if grid_points.shape[0] <= max_samples_per_iter:
                return self.query(grid_points.unsqueeze(0), hidden_states).squeeze(-1)  # [B, N]
            all_pred = []
            for i in range(0, grid_points.shape[0], max_samples_per_iter):
                grid_chunk = grid_points[i : i + max_samples_per_iter]
                pred_chunk = self.query(grid_chunk.unsqueeze(0), hidden_states)
                all_pred.append(pred_chunk)
            return torch.cat(all_pred, dim=1).squeeze(-1)  # [B, N]

        if mode == "dense":
            grid_points = construct_grid_points(resolution).to(latent.device)
            grid_points = grid_points.contiguous().view(-1, 3)
            grid_vals = chunked_query(grid_points).float().view(B, resolution + 1, resolution + 1, resolution + 1)

        elif mode == "hierarchical":
            assert resolution >= min_resolution, "Resolution must be greater than or equal to min_resolution"
            assert B == 1, "Only one batch is supported for hierarchical mode"

            resolutions = []
            res = resolution
            while res >= min_resolution:
                resolutions.append(res)
                res = res // 2
            resolutions.reverse()  # e.g., [64, 128, 256, 512]

            # dense-query the coarsest resolution
            res = resolutions[0]
            grid_points = construct_grid_points(res).to(latent.device)
            grid_points = grid_points.contiguous().view(-1, 3)
            grid_vals = chunked_query(grid_points).float().view(res + 1, res + 1, res + 1)

            # sparse-query finer resolutions
            dilate_kernel_3 = torch.ones(1, 1, 3, 3, 3, dtype=torch.float32, device=latent.device)
            dilate_kernel_5 = torch.ones(1, 1, 5, 5, 5, dtype=torch.float32, device=latent.device)
            for i in range(1, len(resolutions)):
                res = resolutions[i]
                # get the boundary grid mask in the coarser grid (where the grid_vals have different signs with at least one of its neighbors)
                grid_signs = grid_vals >= 0
                mask = torch.zeros_like(grid_signs)
                mask[1:, :, :] += grid_signs[1:, :, :] != grid_signs[:-1, :, :]
                mask[:-1, :, :] += grid_signs[:-1, :, :] != grid_signs[1:, :, :]
                mask[:, 1:, :] += grid_signs[:, 1:, :] != grid_signs[:, :-1, :]
                mask[:, :-1, :] += grid_signs[:, :-1, :] != grid_signs[:, 1:, :]
                mask[:, :, 1:] += grid_signs[:, :, 1:] != grid_signs[:, :, :-1]
                mask[:, :, :-1] += grid_signs[:, :, :-1] != grid_signs[:, :, 1:]
                # empirical: also add those with abs(grid_vals) < 0.95
                mask += grid_vals.abs() < 0.95
                mask = (mask > 0).float()
                # empirical: dilate the coarse mask
                if res < 512:
                    mask = mask.unsqueeze(0).unsqueeze(0)
                    mask = F.conv3d(mask, weight=dilate_kernel_3, padding=1)
                    mask = mask.squeeze(0).squeeze(0)
                # get the coarse coordinates
                cidx_x, cidx_y, cidx_z = torch.nonzero(mask, as_tuple=True)
                # fill to the fine indices
                mask_fine = torch.zeros(res + 1, res + 1, res + 1, dtype=torch.float32, device=latent.device)
                mask_fine[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
                # empirical: dilate the fine mask
                if res < 512:
                    mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
                    mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_3, padding=1)
                    mask_fine = mask_fine.squeeze(0).squeeze(0)
                else:
                    mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
                    mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_5, padding=2)
                    mask_fine = mask_fine.squeeze(0).squeeze(0)
                # get the fine coordinates
                fidx_x, fidx_y, fidx_z = torch.nonzero(mask_fine, as_tuple=True)
                # convert to float query points
                query_points = torch.stack([fidx_x, fidx_y, fidx_z], dim=-1)  # [N, 3]
                query_points = query_points * 2 / res - 1  # [N, 3], in [-1, 1]
                # query
                pred = chunked_query(query_points).float()
                # fill to the fine indices
                grid_vals = torch.full((res + 1, res + 1, res + 1), -100.0, dtype=torch.float32, device=latent.device)
                grid_vals[fidx_x, fidx_y, fidx_z] = pred
                # print(f"[INFO] hierarchical: resolution: {res}, valid coarse points: {len(cidx_x)}, valid fine points: {len(fidx_x)}")

            grid_vals = grid_vals.unsqueeze(0)  # [1, res+1, res+1, res+1]
            grid_vals[grid_vals <= -100.0] = float("nan")  # use nans to ignore invalid regions

        # extract mesh
        meshes = []
        for b in range(B):
            vertices, faces = extract_mesh(grid_vals[b], resolution)
            meshes.append((vertices, faces))
        output["meshes"] = meshes

        return output
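A usage sketch for the two query modes of Model.forward: decoding a latent densely at a low resolution, then hierarchically at full resolution. The latent tensor here is random noise purely to illustrate the expected shape (batch 1, latent_size x latent_dim from the config); real latents come from encode() or from the flow model.

```python
import torch

from vae.configs.part_woenc import make_config
from vae.model import Model

model = Model(make_config()).eval().cuda().bfloat16()

# shape-only illustration: [B, latent_size, latent_dim]; a real latent comes from model.encode(...)
latent = torch.randn(1, 4096, 64, device="cuda")

with torch.inference_mode():
    out_coarse = model({"latent": latent}, mode="dense", resolution=128)
    out_fine = model({"latent": latent}, mode="hierarchical", resolution=512, min_resolution=64)

vertices, faces = out_fine["meshes"][0]  # numpy arrays from marching cubes
print(vertices.shape, faces.shape)
```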
vae/modules/__init__.py
ADDED
File without changes
vae/modules/attention.py
ADDED
@@ -0,0 +1,261 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import (  # , unpad_input # noqa
        index_first_axis,
        pad_input,
    )

    FLASH_ATTN_AVAILABLE = True
except Exception as e:
    print("[WARN] flash_attn not available, using torch/naive implementation")
    FLASH_ATTN_AVAILABLE = False


# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py#L98
# flashattn 2.7.0 changes the API, we are overriding it here
def unpad_input(hidden_states, attention_mask):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def attention(q, k, v, mask_q=None, mask_kv=None, dropout=0, causal=False, window_size=(-1, -1), backend="torch"):
    # q: (B, N, H, D)
    # k: (B, M, H, D)
    # v: (B, M, H, D)
    # mask_q: (B, N)
    # mask_kv: (B, M)
    # return: (B, N, H, D)

    B, N, H, D = q.shape
    M = k.shape[1]

    if causal:
        assert N == 1 or N == M, "Causal mask only supports self-attention"

    # unmasked case (usually inference)
    # will ignore window_size except flash-attn impl. Only provide the effective window!
    if mask_q is None and mask_kv is None:
        if backend == "flash-attn" and FLASH_ATTN_AVAILABLE:
            return flash_attn_func(q, k, v, dropout, causal=causal, window_size=window_size)  # [B, N, H, D]
        elif backend == "torch":  # torch implementation
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout, is_causal=causal)
            out = out.permute(0, 2, 1, 3).contiguous()
            return out
        else:  # naive implementation
            q = q.transpose(1, 2).reshape(B * H, N, D)
            k = k.transpose(1, 2).reshape(B * H, M, D)
            v = v.transpose(1, 2).reshape(B * H, M, D)
            w = torch.bmm(q, k.transpose(1, 2)) / (D**0.5)  # [B*H, N, M]
            if causal and N > 1:
                causal_mask = torch.full((N, M), float("-inf"), device=w.device, dtype=w.dtype)
                causal_mask = torch.triu(causal_mask, diagonal=1)
                w = w + causal_mask.unsqueeze(0)
            w = F.softmax(w, dim=-1)
            if dropout > 0:
                w = F.dropout(w, p=dropout)
            out = torch.bmm(w, v)  # [B*H, N, D]
            out = out.reshape(B, H, N, D).transpose(1, 2).contiguous()  # [B, N, H, D]
            return out

    # at least one of q or kv is masked (training)
    # only support flash-attn for now...
    if mask_q is None:
        mask_q = torch.ones(B, N, dtype=torch.bool, device=q.device)
    elif mask_kv is None:
        mask_kv = torch.ones(B, M, dtype=torch.bool, device=q.device)

    if FLASH_ATTN_AVAILABLE:
        # unpad (gather) input
        # mask_q: [B, N], first row has N1 1s, second row has N2 1s, ...
        # indices: [Ns,], Ns = N1 + N2 + ...
        # cu_seqlens_q: [B+1,], (0, N1, N1+N2, ...), cu=cumulative
        # max_len_q: scalar, max(N1, N2, ...)
        q, indices_q, cu_seqlens_q, max_len_q = unpad_input(q, mask_q)
        k, indices_kv, cu_seqlens_kv, max_len_kv = unpad_input(k, mask_kv)
        v = index_first_axis(v.reshape(-1, H, D), indices_kv)  # same indice as k

        # call varlen_func
        out = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_kv,
            max_seqlen_q=max_len_q,
            max_seqlen_k=max_len_kv,
            dropout_p=dropout,
            causal=causal,
            window_size=window_size,
        )

        # pad (put back) output
        out = pad_input(out, indices_q, B, N)
        return out
    else:
        raise NotImplementedError("masked attention requires flash_attn!")


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rnorm = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * rnorm).to(dtype=self.weight.dtype) * self.weight


class SelfAttention(nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_heads,
        input_dim=None,
        output_dim=None,
        dropout=0,
        causal=False,
        qknorm=False,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim if input_dim is not None else hidden_dim
        self.output_dim = output_dim if output_dim is not None else hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads
        self.causal = causal
        self.dropout = dropout
        self.qknorm = qknorm

        self.qkv_proj = nn.Linear(self.input_dim, 3 * self.hidden_dim)
        self.out_proj = nn.Linear(self.hidden_dim, self.output_dim)

        if self.qknorm:
            if qknorm_type == "RMSNorm":
                self.q_norm = RMSNorm(self.hidden_dim, eps=1e-6)
                self.k_norm = RMSNorm(self.hidden_dim, eps=1e-6)
            else:
                self.q_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
                self.k_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, mask=None):
        # x: [B, N, C]
        # mask: [B, N]
        B, N, C = x.shape
        qkv = self.qkv_proj(x)  # [B, N, C] -> [B, N, 3 * D]
        qkv = qkv.reshape(B, N, 3, -1).permute(2, 0, 1, 3)  # [3, B, N, D]
        q, k, v = qkv.chunk(3, dim=0)  # [3, B, N, D] -> 3 * [1, B, N, D]
        q = q.squeeze(0)
        k = k.squeeze(0)
        v = v.squeeze(0)
        if self.qknorm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q = q.reshape(B, N, self.num_heads, self.head_dim)
        k = k.reshape(B, N, self.num_heads, self.head_dim)
        v = v.reshape(B, N, self.num_heads, self.head_dim)
        x = attention(q, k, v, mask_q=mask, mask_kv=mask, dropout=self.dropout, causal=self.causal)  # [B, N, H, D]
        x = self.out_proj(x.reshape(B, N, -1))
        return x


class CrossAttention(nn.Module):
    def __init__(
        self,
        hidden_dim,
        num_heads,
        input_dim=None,
        context_dim=None,
        output_dim=None,
        dropout=0,
        qknorm=False,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim if input_dim is not None else hidden_dim
        self.context_dim = context_dim if context_dim is not None else hidden_dim
        self.output_dim = output_dim if output_dim is not None else hidden_dim
        self.num_heads = num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads
        self.dropout = dropout
        self.qknorm = qknorm

        self.q_proj = nn.Linear(self.input_dim, self.hidden_dim)
        self.k_proj = nn.Linear(self.context_dim, self.hidden_dim)
        self.v_proj = nn.Linear(self.context_dim, self.hidden_dim)
        self.out_proj = nn.Linear(self.hidden_dim, self.output_dim)

        if self.qknorm:
            if qknorm_type == "RMSNorm":
                self.q_norm = RMSNorm(self.hidden_dim, eps=1e-6)
                self.k_norm = RMSNorm(self.hidden_dim, eps=1e-6)
            else:
                self.q_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
                self.k_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, context, mask_q=None, mask_kv=None):
        # x: [B, N, C]
        # context: [B, M, C']
        # mask_q: [B, N]
        # mask_kv: [B, M]
        B, N, C = x.shape
        M = context.shape[1]
        q = self.q_proj(x)
        k = self.k_proj(context)
        v = self.v_proj(context)
        if self.qknorm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q = q.reshape(B, N, self.num_heads, self.head_dim)
        k = k.reshape(B, M, self.num_heads, self.head_dim)
        v = v.reshape(B, M, self.num_heads, self.head_dim)
        x = attention(q, k, v, mask_q=mask_q, mask_kv=mask_kv, dropout=self.dropout, causal=False)  # [B, N, H, D]
        x = self.out_proj(x.reshape(B, N, -1))
        return x
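A quick shape check for the attention helper above using its default torch backend (scaled_dot_product_attention); the tensors are random and on CPU, which the torch path supports.

```python
import torch

from vae.modules.attention import attention

B, N, M, H, D = 2, 16, 32, 4, 8
q = torch.randn(B, N, H, D)
k = torch.randn(B, M, H, D)
v = torch.randn(B, M, H, D)

out = attention(q, k, v, backend="torch")  # unmasked cross-attention
print(out.shape)  # torch.Size([2, 16, 4, 8])
```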
vae/modules/transformer.py
ADDED
@@ -0,0 +1,117 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from vae.modules.attention import CrossAttention, SelfAttention


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))

    def forward(self, x):
        return self.net(x)


class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context=None,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing

        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        if dim_context is not None:
            self.norm_context = nn.LayerNorm(dim_context, eps=1e-6, elementwise_affine=False)
            self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        else:
            self.attn = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)

        self.norm_ff = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.ff = FeedForward(dim)

    def forward(self, x, c=None, mask=None, mask_c=None):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c=None, mask=None, mask_c=None):
        # x: [B, N, C], hidden states
        # c: [B, M, C'], condition (assume normed and projected to C)
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states

        if c is not None:
            x = x + self.attn(self.norm_attn(x), self.norm_context(c), mask_q=mask, mask_kv=mask_c)
        else:
            x = x + self.attn(self.norm_attn(x), mask=mask)

        x = x + self.ff(self.norm_ff(x))

        return x


# special attention block for the last cross-attn query layer
# 1. simple feed-forward (mult=1, no post ln)
# 2. no residual connection
# 3. no context ln
class FlashQueryLayer(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        dim_context,
        qknorm=False,
        gradient_checkpointing=True,
        qknorm_type="LayerNorm",
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dim_context = dim_context
        self.gradient_checkpointing = gradient_checkpointing

        self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
        self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
        self.ff = FeedForward(dim, mult=1)

    def forward(self, x, c=None, mask=None, mask_c=None):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
        else:
            return self._forward(x, c, mask, mask_c)

    def _forward(self, x, c, mask=None, mask_c=None):
        # x: [B, N, C], hidden states
        # c: [B, M, C'], condition (assume normed and projected to C)
        # mask: [B, N], mask for x
        # mask_c: [B, M], mask for c
        # return: [B, N, C], updated hidden states

        x = self.attn(self.norm_attn(x), c, mask_q=mask, mask_kv=mask_c)
        x = self.ff(x)

        return x
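A minimal sketch of the two block variants above: a self-attention AttentionBlock (dim_context=None) and a cross-attention one, run in eval mode so gradient checkpointing is skipped; the shapes are arbitrary.

```python
import torch

from vae.modules.transformer import AttentionBlock

x = torch.randn(1, 128, 256)  # [B, N, C] hidden states
c = torch.randn(1, 64, 512)   # [B, M, C'] context tokens

self_block = AttentionBlock(dim=256, num_heads=8).eval()
cross_block = AttentionBlock(dim=256, num_heads=8, dim_context=512).eval()

with torch.no_grad():
    y = self_block(x)      # pre-norm self-attention + feed-forward, with residuals
    z = cross_block(x, c)  # x (queries) cross-attends to c (keys/values)

print(y.shape, z.shape)  # torch.Size([1, 128, 256]) torch.Size([1, 128, 256])
```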
vae/scripts/infer.py
ADDED
@@ -0,0 +1,142 @@
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import argparse
import glob
import importlib
import os
from datetime import datetime

import fpsample
import kiui
import meshiki
import numpy as np
import torch
import trimesh

from vae.model import Model
from vae.utils import box_normalize, postprocess_mesh, sphere_normalize, sync_timer

# PYTHONPATH=. python vae/scripts/infer.py
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="config file path", default="vae.configs.part_woenc")
parser.add_argument(
    "--ckpt_path",
    type=str,
    help="checkpoint path",
    default="pretrained/vae.pt",
)
parser.add_argument("--input", type=str, help="input directory", default="assets/meshes/")
parser.add_argument("--output_dir", type=str, help="output directory", default="output/")
parser.add_argument("--limit", type=int, help="how many samples to test", default=-1)
parser.add_argument("--num_fps_point", type=int, help="number of fps points", default=1024)
parser.add_argument("--num_fps_salient_point", type=int, help="number of fps salient points", default=1024)
parser.add_argument("--grid_res", type=int, help="grid resolution", default=512)
parser.add_argument("--seed", type=int, help="seed", default=42)
args = parser.parse_args()


TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)

kiui.seed_everything(args.seed)


@sync_timer("prepare_input_from_mesh")
def prepare_input_from_mesh(mesh_path, use_salient_point=True, num_fps_point=1024, num_fps_salient_point=1024):
    # load mesh, assume it's already processed to be watertight.

    mesh_name = mesh_path.split("/")[-1].split(".")[0]
    vertices, faces = meshiki.load_mesh(mesh_path)

    # vertices = sphere_normalize(vertices)
    vertices = box_normalize(vertices)

    mesh = meshiki.Mesh(vertices, faces)

    uniform_surface_points = mesh.uniform_point_sample(200000)
    uniform_surface_points = meshiki.fps(uniform_surface_points, 32768)  # hardcoded...
    salient_surface_points = mesh.salient_point_sample(16384, thresh_bihedral=15)

    # save points
    # trimesh.PointCloud(vertices=uniform_surface_points).export(os.path.join(workspace, mesh_name + "_uniform.ply"))
    # trimesh.PointCloud(vertices=salient_surface_points).export(os.path.join(workspace, mesh_name + "_salient.ply"))

    sample = {}

    sample["pointcloud"] = torch.from_numpy(uniform_surface_points)

    # fps subsample
    fps_indices = fpsample.bucket_fps_kdline_sampling(uniform_surface_points, num_fps_point, h=5, start_idx=0)
    sample["fps_indices"] = torch.from_numpy(fps_indices).long()  # [num_fps_point,]

    if use_salient_point:
        sample["pointcloud_dorases"] = torch.from_numpy(salient_surface_points)  # [N', 3]

        # fps subsample
        fps_indices_dorases = fpsample.bucket_fps_kdline_sampling(
            salient_surface_points, num_fps_salient_point, h=5, start_idx=0
        )
        sample["fps_indices_dorases"] = torch.from_numpy(fps_indices_dorases).long()  # [num_fps_point,]

    return sample


print(f"Loading checkpoint from {args.ckpt_path}")
ckpt_dict = torch.load(args.ckpt_path, weights_only=True)

# delete all keys other than model
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]

# instantiate model
print(f"Instantiating model from {args.config}")
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()

# load weight
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, "vae_" + args.config.split(".")[-1] + "_" + timestamp)
if not os.path.exists(workspace):
    os.makedirs(workspace)
else:
    os.system(f"rm {workspace}/*")
print(f"Output directory: {workspace}")

# load dataset
mesh_list = glob.glob(os.path.join(args.input, "*"))
mesh_list = mesh_list[: args.limit] if args.limit > 0 else mesh_list

for i, mesh_path in enumerate(mesh_list):
    print(f"Processing {i}/{len(mesh_list)}: {mesh_path}")

    mesh_name = mesh_path.split("/")[-1].split(".")[0]

    sample = prepare_input_from_mesh(
        mesh_path, num_fps_point=args.num_fps_point, num_fps_salient_point=args.num_fps_salient_point
    )
    for k in sample:
        sample[k] = sample[k].unsqueeze(0).cuda()

    # call vae
    with torch.inference_mode():
        output = model(sample, resolution=args.grid_res)

    latent = output["latent"]
    vertices, faces = output["meshes"][0]

    mesh = trimesh.Trimesh(vertices, faces)
    mesh = postprocess_mesh(mesh, 5e5)

    mesh.export(f"{workspace}/{mesh_name}.glb")
vae/utils.py
ADDED
@@ -0,0 +1,315 @@
1 |
+
"""
|
2 |
+
-----------------------------------------------------------------------------
|
3 |
+
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
|
5 |
+
NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
+
and proprietary rights in and to this software, related documentation
|
7 |
+
and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
+
distribution of this software and related documentation without an express
|
9 |
+
license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
+
-----------------------------------------------------------------------------
|
11 |
+
"""
|
12 |
+
|
13 |
+
import os
|
14 |
+
from functools import wraps
|
15 |
+
from typing import Literal
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import trimesh
|
20 |
+
from kiui.mesh_utils import clean_mesh, decimate_mesh
|
21 |
+
|
22 |
+
|
23 |
+
# Adapted from https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/utils.py#L38
|
24 |
+
class sync_timer:
|
25 |
+
"""
|
26 |
+
Synchronized timer to count the inference time of `nn.Module.forward` or else.
|
27 |
+
set env var TIMER=1 to enable logging!
|
28 |
+
|
29 |
+
Example as context manager:
|
30 |
+
```python
|
31 |
+
with timer('name'):
|
32 |
+
run()
|
33 |
+
```
|
34 |
+
|
35 |
+
Example as decorator:
|
36 |
+
```python
|
37 |
+
@timer('name')
|
38 |
+
def run():
|
39 |
+
pass
|
40 |
+
```
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, name=None, flag_env="TIMER"):
|
44 |
+
self.name = name
|
45 |
+
self.flag_env = flag_env
|
46 |
+
|
47 |
+
def __enter__(self):
|
48 |
+
if os.environ.get(self.flag_env, "0") == "1":
|
49 |
+
self.start = torch.cuda.Event(enable_timing=True)
|
50 |
+
self.end = torch.cuda.Event(enable_timing=True)
|
51 |
+
self.start.record()
|
52 |
+
return lambda: self.time
|
53 |
+
|
54 |
+
def __exit__(self, exc_type, exc_value, exc_tb):
|
55 |
+
if os.environ.get(self.flag_env, "0") == "1":
|
56 |
+
self.end.record()
|
57 |
+
torch.cuda.synchronize()
|
58 |
+
self.time = self.start.elapsed_time(self.end)
|
59 |
+
if self.name is not None:
|
60 |
+
print(f"{self.name} takes {self.time} ms")
|
61 |
+
|
62 |
+
def __call__(self, func):
|
63 |
+
@wraps(func)
|
64 |
+
def wrapper(*args, **kwargs):
|
65 |
+
with self:
|
66 |
+
result = func(*args, **kwargs)
|
67 |
+
return result
|
68 |
+
|
69 |
+
return wrapper
|
70 |
+
|
71 |
+
|
72 |
+
@torch.no_grad()
|
73 |
+
def calculate_iou(pred: torch.Tensor, gt: torch.Tensor, target_value: int, thresh: float = 0) -> torch.Tensor:
|
74 |
+
"""Calculate the Intersection over Union (IoU) between two volumes.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
pred (torch.Tensor): [*] continuous value between 0 and 1
|
78 |
+
gt (torch.Tensor): [*] discrete value of 0 or 1
|
79 |
+
target_value (int): The value to be considered as the target class
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
torch.Tensor: IoU value
|
83 |
+
"""
|
84 |
+
# Ensure volumes have the same shape
|
85 |
+
assert pred.shape == gt.shape, "Volumes must have the same shape"
|
86 |
+
|
87 |
+
# binarize
|
88 |
+
pred_binary = pred > thresh
|
89 |
+
gt = gt > thresh
|
90 |
+
|
91 |
+
# Convert the volumes to boolean tensors for logical operations
|
92 |
+
intersection = torch.logical_and(pred_binary == target_value, gt == target_value).sum().float()
|
93 |
+
union = torch.logical_or(pred_binary == target_value, gt == target_value).sum().float()
|
94 |
+
|
95 |
+
# Compute IoU
|
96 |
+
iou = intersection / union if union != 0 else torch.tensor(0.0)
|
97 |
+
return iou
|
98 |
+
|
99 |
+
|
100 |
+
@torch.no_grad()
|
101 |
+
def calculate_metrics(
|
102 |
+
pred: torch.Tensor, gt: torch.Tensor, target_value: int = 1, thresh: float = 0.5
|
103 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
104 |
+
"""Calculate Precision, Recall, and F1 between two volumes.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
pred (torch.Tensor): [*] continuous value between 0 and 1
|
108 |
+
gt (torch.Tensor): [*] discrete value of 0 or 1
|
109 |
+
target_value (int): The value to be considered as the target class
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
tuple: Precision, Recall, F1 values
|
113 |
+
"""
|
114 |
+
assert pred.shape == gt.shape, f"Pred {pred.shape} and gt {gt.shape} must have the same shape"
|
115 |
+
|
116 |
+
# Binarize prediction
|
117 |
+
pred_binary = pred > thresh
|
118 |
+
gt = gt > thresh
|
119 |
+
|
120 |
+
# True Positive (TP): pred == target_value and gt == target_value
|
121 |
+
true_positive = torch.logical_and(pred_binary == target_value, gt == target_value).sum().float()
|
122 |
+
|
123 |
+
# False Positive (FP): pred == target_value and gt != target_value
|
124 |
+
false_positive = torch.logical_and(pred_binary == target_value, gt != target_value).sum().float()
|
125 |
+
|
126 |
+
# False Negative (FN): pred != target_value and gt == target_value
|
127 |
+
false_negative = torch.logical_and(pred_binary != target_value, gt == target_value).sum().float()
|
128 |
+
|
129 |
+
# Precision: TP / (TP + FP), best to detect False Positives
|
130 |
+
precision = (
|
131 |
+
true_positive / (true_positive + false_positive) if (true_positive + false_positive) != 0 else torch.tensor(0.0)
|
132 |
+
)
|
133 |
+
|
134 |
+
# Recall: TP / (TP + FN), best to detect False Negatives
|
135 |
+
recall = (
|
136 |
+
true_positive / (true_positive + false_negative) if (true_positive + false_negative) != 0 else torch.tensor(0.0)
|
137 |
+
)
|
138 |
+
|
139 |
+
# f1: 2 / (1 / precision + 1 / recall)
|
140 |
+
f1 = 2 / (1 / precision + 1 / recall) if (precision != 0 and recall != 0) else torch.tensor(0.0)
|
141 |
+
|
142 |
+
return precision, recall, f1
|
143 |
+
|
144 |
+
# Adapted from https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/distributions/distributions.py#L24
class DiagonalGaussianDistribution:
    """VAE latent"""

    def __init__(self, mean, logvar, deterministic=False):
        # mean, logvar: [B, L, D] x 2
        self.mean, self.logvar = mean, logvar
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean, device=self.mean.device, dtype=self.mean.dtype)

    def sample(self, weight: float = 1.0):
        sample = weight * torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype)
        x = self.mean + self.std * sample
        return x

    def kl(self, other=None, dims=[1, 2]):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=dims,
                )

    def nll(self, sample, dims=[1, 2]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.mean(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean

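
# Usage sketch (illustrative, assuming an encoder that predicts per-token mean and logvar):
#
#   posterior = DiagonalGaussianDistribution(mean, logvar)  # mean, logvar: [B, L, D]
#   z = posterior.sample()                                  # reparameterized latent for training
#   kl_loss = posterior.kl().mean()                         # KL against a standard Gaussian prior
#   z_det = posterior.mode()                                 # deterministic latent for inference
#
# DummyLatent below mimics the same interface for a non-variational bottleneck, where kl()
# degenerates to a plain L2 penalty on the latent.
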
class DummyLatent:
    def __init__(self, mean):
        self.mean = mean

    def sample(self, weight=0):
        # simply perturb the mean
        if weight > 0:
            noise = torch.randn_like(self.mean) * weight
        else:
            noise = 0
        return self.mean + noise

    def mode(self):
        return self.mean

    def kl(self):
        # just an l2 penalty
        return 0.5 * torch.mean(torch.pow(self.mean, 2))


def construct_grid_points(
    resolution: int,
    indexing: str = "ij",
):
    """Generate dense grid points in [-1, 1]^3.

    Args:
        resolution (int): resolution of the grid
        indexing (str, optional): indexing of the grid. Defaults to "ij".

    Returns:
        torch.Tensor: grid points (resolution + 1, resolution + 1, resolution + 1, 3), inside bbox.
    """
    x = np.linspace(-1, 1, resolution + 1, dtype=np.float32)
    y = np.linspace(-1, 1, resolution + 1, dtype=np.float32)
    z = np.linspace(-1, 1, resolution + 1, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyzs = np.stack((xs, ys, zs), axis=-1)
    xyzs = torch.from_numpy(xyzs).float()
    return xyzs

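
# Usage sketch (illustrative): the grid is usually flattened and queried by the occupancy decoder,
# then the responses are reshaped into a dense volume for mesh extraction; query_decoder is a
# placeholder for whatever model predicts occupancy/TSDF at query points:
#
#   points = construct_grid_points(resolution).view(-1, 3).to(device)   # [(res+1)^3, 3]
#   grid_vals = query_decoder(points).view(resolution + 1, resolution + 1, resolution + 1)
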
_diso_session = None  # lazy session for reuse


@sync_timer("extract_mesh")
def extract_mesh(
    grid_vals: torch.Tensor,
    resolution: int,
    isosurface_level: float = 0,
    backend: Literal["mcubes", "diso"] = "mcubes",
):
    """Extract mesh from grid occupancy.

    Args:
        grid_vals (torch.Tensor): [resolution + 1, resolution + 1, resolution + 1], assumed to be a TSDF in [-1, 1] (inner is positive)
        resolution (int): Grid resolution.
        isosurface_level (float, optional): Iso-surface level. Defaults to 0.
        backend (Literal["mcubes", "diso"], optional): Backend for mesh extraction. Defaults to "mcubes"; "diso" runs on GPU and is faster.

    Returns:
        vertices (np.ndarray): [N, 3], float32, in [-1, 1]
        faces (np.ndarray): [M, 3], int32
    """

    grid_vals = grid_vals.view(resolution + 1, resolution + 1, resolution + 1)

    if backend == "mcubes":
        try:
            import mcubes
        except ImportError:
            os.system("pip install pymcubes")
            import mcubes
        grid_vals = grid_vals.float().cpu().numpy()
        verts, faces = mcubes.marching_cubes(grid_vals, isosurface_level)
        verts = 2 * verts / resolution - 1.0  # normalize to [-1, 1]
    elif backend == "diso":
        try:
            import diso
        except ImportError:
            os.system("pip install diso")
            import diso
        global _diso_session
        if _diso_session is None:
            _diso_session = diso.DiffDMC(dtype=torch.float32).cuda()

        grid_vals = -grid_vals.float().cuda()  # diso assumes inner is NEGATIVE!
        verts, faces = _diso_session(grid_vals, deform=None, normalize=True)  # verts in [0, 1]
        verts = verts.cpu().numpy() * 2 - 1.0  # normalize to [-1, 1]
        faces = faces.cpu().numpy()

    return verts, faces

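
# Usage sketch (illustrative; the resolution value is a placeholder): the extracted arrays are
# typically wrapped into a trimesh object and cleaned up with postprocess_mesh below:
#
#   verts, faces = extract_mesh(grid_vals, resolution=512, backend="mcubes")
#   mesh = trimesh.Trimesh(verts, faces)
#   mesh = postprocess_mesh(mesh, decimate_target=100000)
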
@sync_timer("postprocess_mesh")
def postprocess_mesh(mesh: trimesh.Trimesh, decimate_target=100000):
    vertices = mesh.vertices
    triangles = mesh.faces

    if vertices.shape[0] > 0 and triangles.shape[0] > 0:
        vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)
        if triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
            if vertices.shape[0] > 0 and triangles.shape[0] > 0:
                vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)

    mesh.vertices = vertices
    mesh.faces = triangles

    return mesh

def sphere_normalize(vertices):
    bmin = vertices.min(axis=0)
    bmax = vertices.max(axis=0)
    bcenter = (bmax + bmin) / 2
    radius = np.linalg.norm(vertices - bcenter, axis=-1).max()
    vertices = (vertices - bcenter) / radius  # to [-1, 1]
    return vertices

def box_normalize(vertices, bound=0.95):
    bmin = vertices.min(axis=0)
    bmax = vertices.max(axis=0)
    bcenter = (bmax + bmin) / 2
    vertices = bound * (vertices - bcenter) / (bmax - bmin).max()
    return vertices
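
# Usage sketch (illustrative): both normalizers recenter the vertices; sphere_normalize scales by
# the maximum radius so the shape fits a unit sphere, while box_normalize scales by the longest
# bounding-box edge so the shape fits inside [-bound, bound]^3:
#
#   mesh.vertices = box_normalize(mesh.vertices, bound=0.95)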