ashawkey committed
Commit daa6779 · 1 Parent(s): 40a41ad
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: PartPacker
-emoji: 📈
+emoji: 🪴
 colorFrom: blue
 colorTo: gray
 sdk: gradio
app.py ADDED
@@ -0,0 +1,192 @@
+import os
+import numpy as np
+import cv2
+import kiui
+import trimesh
+import torch
+import rembg
+from datetime import datetime
+import subprocess
+import gradio as gr
+
+try:
+    # running on Hugging Face Spaces
+    import spaces
+
+except ImportError:
+    # running locally, use a dummy space
+    class spaces:
+        class GPU:
+            def __init__(self, duration=60):
+                self.duration = duration
+            def __call__(self, func):
+                return func
+
+
+from flow.model import Model
+from flow.configs.schema import ModelConfig
+from flow.utils import get_random_color, recenter_foreground
+from vae.utils import postprocess_mesh
+
+# download checkpoints
+from huggingface_hub import hf_hub_download
+flow_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="flow.pt")
+vae_ckpt_path = hf_hub_download(repo_id="nvidia/PartPacker", filename="vae.pt")
+
+TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
+MAX_SEED = np.iinfo(np.int32).max
+bg_remover = rembg.new_session()
+
+# model config
+model_config = ModelConfig(
+    vae_conf="vae.configs.part_woenc",
+    vae_ckpt_path=vae_ckpt_path,
+    qknorm=True,
+    qknorm_type="RMSNorm",
+    use_pos_embed=False,
+    dino_model="dinov2_vitg14",
+    hidden_dim=1536,
+    flow_shift=3.0,
+    logitnorm_mean=1.0,
+    logitnorm_std=1.0,
+    latent_size=4096,
+    use_parts=True,
+)
+
+# instantiate model
+model = Model(model_config).eval().cuda().bfloat16()
+
+# load weight
+ckpt_dict = torch.load(flow_ckpt_path, weights_only=True)
+model.load_state_dict(ckpt_dict, strict=True)
+
+# process function
+@spaces.GPU(duration=120)
+def process(input_image, input_num_steps=30, input_cfg_scale=7.5, grid_res=384, seed=42, randomize_seed=True):
+
+    # seed
+    if randomize_seed:
+        seed = np.random.randint(0, MAX_SEED)
+    kiui.seed_everything(seed)
+
+    # output path
+    os.makedirs("output", exist_ok=True)
+    output_glb_path = f"output/partpacker_{datetime.now().strftime('%Y%m%d_%H%M%S')}.glb"
+
+    # input image
+    input_image = np.array(input_image)  # uint8
+
+    # bg removal if there is no alpha channel
+    if input_image.shape[-1] == 3:
+        input_image = rembg.remove(input_image, session=bg_remover)  # [H, W, 4]
+    mask = input_image[..., -1] > 0
+    image = recenter_foreground(input_image, mask, border_ratio=0.1)
+    image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
+    image = image.astype(np.float32) / 255.0
+    image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])  # white background
+
+    image_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()
+    data = {"cond_images": image_tensor}
+
+    with torch.inference_mode():
+        results = model(data, num_steps=input_num_steps, cfg_scale=input_cfg_scale)
+
+    latent = results["latent"]
+
+    # query mesh
+
+    data_part0 = {"latent": latent[:, : model.config.latent_size, :]}
+    data_part1 = {"latent": latent[:, model.config.latent_size :, :]}
+
+    with torch.inference_mode():
+        results_part0 = model.vae(data_part0, resolution=grid_res)
+        results_part1 = model.vae(data_part1, resolution=grid_res)
+
+    vertices, faces = results_part0["meshes"][0]
+    mesh_part0 = trimesh.Trimesh(vertices, faces)
+    mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
+    mesh_part0 = postprocess_mesh(mesh_part0, 5e4)
+    parts = mesh_part0.split(only_watertight=False)
+
+    vertices, faces = results_part1["meshes"][0]
+    mesh_part1 = trimesh.Trimesh(vertices, faces)
+    mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
+    mesh_part1 = postprocess_mesh(mesh_part1, 5e4)
+    parts.extend(mesh_part1.split(only_watertight=False))
+
+    # split connected components and assign different colors
+    for j, part in enumerate(parts):
+        # each component uses a random color
+        part.visual.vertex_colors = get_random_color(j, use_float=True)
+
+    mesh = trimesh.Scene(parts)
+    # export the whole mesh
+    mesh.export(output_glb_path)
+
+    return seed, image, output_glb_path
+
+# gradio UI
+
+_TITLE = '''PartPacker: Efficient Part-level 3D Object Generation via Dual Volume Packing'''
+
+_DESCRIPTION = '''
+<div>
+<a style="display:inline-block" href="https://research.nvidia.com/labs/dir/partpacker/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
+<a style="display:inline-block; margin-left: .5em" href="https://github.com/NVlabs/PartPacker"><img src='https://img.shields.io/github/stars/NVlabs/PartPacker?style=social'/></a>
+</div>
+
+* Each part is visualized with a random color, and can be separated in the GLB file.
+* If the output is not satisfactory, please try different random seeds!
+'''
+
+block = gr.Blocks(title=_TITLE).queue()
+with block:
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown('# ' + _TITLE)
+            gr.Markdown(_DESCRIPTION)
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            # input image
+            input_image = gr.Image(label="Image", type='pil')
+            # inference steps
+            input_num_steps = gr.Slider(label="Inference steps", minimum=1, maximum=100, step=1, value=30)
+            # cfg scale
+            input_cfg_scale = gr.Slider(label="CFG scale", minimum=2, maximum=10, step=0.1, value=7.5)
+            # grid resolution
+            input_grid_res = gr.Slider(label="Grid resolution", minimum=256, maximum=512, step=1, value=384)
+            # random seed
+            seed = gr.Slider(label="Random seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            # gen button
+            button_gen = gr.Button("Generate")
+
+
+        with gr.Column(scale=4):
+            with gr.Tab("3D Model"):
+                # glb file
+                output_model = gr.Model3D(label="Geometry", height=380)
+
+            with gr.Tab("Input Image"):
+                # background removed image
+                output_image = gr.Image(interactive=False, show_label=False)
+
+
+        with gr.Column(scale=1):
+            gr.Examples(
+                examples=[
+                    ["examples/barrel.png"],
+                    ["examples/cactus.png"],
+                    ["examples/cyan_car.png"],
+                    ["examples/pickup.png"],
+                    ["examples/swivelchair.png"],
+                    ["examples/warhammer.png"],
+                ],
+                inputs=[input_image],
+                cache_examples=False,
+            )
+
+    button_gen.click(process, inputs=[input_image, input_num_steps, input_cfg_scale, input_grid_res, seed, randomize_seed], outputs=[seed, output_image, output_model])
+
+block.launch()
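
Note (not part of the commit): a minimal sketch of inspecting the exported GLB with trimesh, assuming a file already written by process() above; the path below is hypothetical. Each randomly colored part is stored as a separate geometry in the scene, which is what makes the parts separable in downstream tools.

import trimesh

scene = trimesh.load("output/partpacker_example.glb")  # hypothetical output path
for name, geom in scene.geometry.items():
    print(name, geom.vertices.shape, geom.faces.shape)  # one entry per extracted part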
examples/barrel.png ADDED

Git LFS Details

  • SHA256: 6efc5d01a6460ffe2aaf3f644f26d278ba7d4801476d17e98a212c588079c978
  • Pointer size: 131 Bytes
  • Size of remote file: 314 kB
examples/cactus.png ADDED

Git LFS Details

  • SHA256: b63b5a14ec5df6cf05ce537d8ba7eec8e67a9260d0f521e9b68417f86f7942ad
  • Pointer size: 131 Bytes
  • Size of remote file: 186 kB
examples/cyan_car.png ADDED

Git LFS Details

  • SHA256: 61dc2c1b2e940a9d2ecded4d7c60fe0249c8ca905a55029293e0f062b559f795
  • Pointer size: 130 Bytes
  • Size of remote file: 69.6 kB
examples/pickup.png ADDED

Git LFS Details

  • SHA256: 9f89940c0bf2dbaf6a48346b8b9861f56a1b002ad5452c14e0019149132873bd
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
examples/rabbit.png ADDED

Git LFS Details

  • SHA256: 7c06b1e3364b3417dd4e92eaa6a5978d01a9ab9b1d9b99d7b15b2393d7952fbd
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
examples/robot.png ADDED

Git LFS Details

  • SHA256: 0ebaf8657cdb7d233ee2661926cff32656b54809bfc20fb7ac4d5ed7aa71fb15
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
examples/swivelchair.png ADDED

Git LFS Details

  • SHA256: af28e5c01b48ca27bb8baf68d7522eaa109f8998588f02152f935c9e3e5e2b57
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
examples/teapot.png ADDED

Git LFS Details

  • SHA256: 967ec369e4bb45e835f7ac057d9406a46178ba6f2d6a5958788f805deb2fd5ec
  • Pointer size: 131 Bytes
  • Size of remote file: 204 kB
examples/warhammer.png ADDED

Git LFS Details

  • SHA256: bc63bda34774288092d069808b7cb28c9544dd253cfbcfb33a98b22c9ec19537
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
flow/__init__.py ADDED
File without changes
flow/configs/__init__.py ADDED
File without changes
flow/configs/big_parts_strict_pvae.py ADDED
@@ -0,0 +1,33 @@
+"""
+-----------------------------------------------------------------------------
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+NVIDIA CORPORATION and its licensors retain all intellectual property
+and proprietary rights in and to this software, related documentation
+and any modifications thereto. Any use, reproduction, disclosure or
+distribution of this software and related documentation without an express
+license agreement from NVIDIA CORPORATION is strictly prohibited.
+-----------------------------------------------------------------------------
+"""
+
+from flow.configs.schema import ModelConfig
+
+
+def make_config():
+
+    model_config = ModelConfig(
+        vae_conf="vae.configs.part_woenc",
+        vae_ckpt_path="pretrained/vae.pt",
+        qknorm=True,
+        qknorm_type="RMSNorm",
+        use_pos_embed=False,
+        dino_model="dinov2_vitg14",
+        hidden_dim=1536,
+        flow_shift=3.0,
+        logitnorm_mean=1.0,
+        logitnorm_std=1.0,
+        latent_size=4096,
+        use_parts=True,
+    )
+
+    return model_config
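
Usage sketch (assumes the repository root is on PYTHONPATH): this config module is resolved dynamically by its dotted name, mirroring flow/scripts/infer.py below.

import importlib

model_config = importlib.import_module("flow.configs.big_parts_strict_pvae").make_config()
print(model_config.latent_size, model_config.flow_shift)  # 4096, 3.0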
flow/configs/schema.py ADDED
@@ -0,0 +1,57 @@
+"""
+-----------------------------------------------------------------------------
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+NVIDIA CORPORATION and its licensors retain all intellectual property
+and proprietary rights in and to this software, related documentation
+and any modifications thereto. Any use, reproduction, disclosure or
+distribution of this software and related documentation without an express
+license agreement from NVIDIA CORPORATION is strictly prohibited.
+-----------------------------------------------------------------------------
+"""
+
+from typing import Literal, Optional
+
+import attrs
+
+
+@attrs.define(slots=False)
+class ModelConfig:
+    # vae
+    vae_conf: str = "vae.configs.part_woenc"
+    vae_ckpt_path: Optional[str] = None
+
+    # learn & generate parts
+    use_parts: bool = False
+    part_embed_mode: Literal["element", "part", "part2_only"] = "part2_only"
+    shuffle_parts: bool = False
+    use_num_parts_cond: bool = False
+
+    # flow matching hyper-params
+    flow_shift: float = 1.0
+    logitnorm_mean: float = 0.0
+    logitnorm_std: float = 1.0
+
+    # image encoder
+    dino_model: Literal["dinov2_vitl14_reg", "dinov2_vitg14"] = "dinov2_vitg14"
+
+    # backbone DiT
+    hidden_dim: int = 1536
+    num_heads: int = 16
+    num_layers: int = 24
+    qknorm: bool = True
+    qknorm_type: Literal["LayerNorm", "RMSNorm"] = "RMSNorm"
+    use_pos_embed: bool = False
+
+    # latent code
+    latent_size: Optional[int] = None  # if None, will load from vae
+    latent_dim: Optional[int] = None
+
+    # preload vae weights
+    preload_vae: bool = True
+
+    # preload dinov2 weights
+    preload_dinov2: bool = True
+
+    # init weights from a pretrained checkpoint
+    pretrain_path: Optional[str] = None
flow/flow_matching.py ADDED
@@ -0,0 +1,58 @@
+"""
+-----------------------------------------------------------------------------
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+NVIDIA CORPORATION and its licensors retain all intellectual property
+and proprietary rights in and to this software, related documentation
+and any modifications thereto. Any use, reproduction, disclosure or
+distribution of this software and related documentation without an express
+license agreement from NVIDIA CORPORATION is strictly prohibited.
+-----------------------------------------------------------------------------
+"""
+
+import numpy as np
+import torch
+
+
+class FlowMatchingScheduler:
+    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1):
+        # set timesteps
+        self.num_train_timesteps = num_train_timesteps
+        self.shift = shift
+
+        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+        sigmas = timesteps / num_train_timesteps
+        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+        self.sigmas = sigmas  # 1 --> 0
+        self.timesteps = sigmas * num_train_timesteps  # num_train_timesteps --> 1
+
+    # set device
+    def to(self, device):
+        self.sigmas = self.sigmas.to(device=device)
+        self.timesteps = self.timesteps.to(device=device)
+
+    # add random noise to latent during training
+    def add_noise(self, latent: torch.Tensor, logit_mean: float = 1.0, logit_std: float = 1.0):
+        # latent: [B, ...]
+        # timesteps: [B]
+        # return: [B, ...] noisy_latent, [B, ...] noise, [B] timesteps
+
+        # logit-normal sampling
+        u = torch.normal(mean=logit_mean, std=logit_std, size=(latent.shape[0],), device=self.sigmas.device)
+        u = torch.nn.functional.sigmoid(u)
+
+        step_indices = (u * self.num_train_timesteps).long()
+        timesteps = self.timesteps[step_indices]
+
+        sigmas = self.sigmas[step_indices].flatten()
+
+        while len(sigmas.shape) < latent.ndim:
+            sigmas = sigmas.unsqueeze(-1)
+
+        noise = torch.randn_like(latent)
+        noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
+
+        return noisy_latent, noise, timesteps
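
For reference (illustrative, not part of the commit): the same shifted sigma schedule is rebuilt at inference time in Model.forward (flow/model.py) and consumed with a simple Euler update. A small sketch with the demo's settings (shift=3.0, 30 steps):

import numpy as np

shift, num_steps = 3.0, 30
sigmas = np.linspace(1, 0, num_steps + 1)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # identical to the shift applied in __init__ above
# per sampling step, with v predicted by the DiT under classifier-free guidance:
#   x <- x - (sigma_i - sigma_{i+1}) * v
print(sigmas[0], sigmas[-1])  # 1.0 ... 0.0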
flow/model.py ADDED
@@ -0,0 +1,336 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ import importlib
14
+
15
+ from transformers import Dinov2Model
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import tqdm
21
+ from torchvision import transforms
22
+
23
+ from flow.configs.schema import ModelConfig
24
+ from flow.flow_matching import FlowMatchingScheduler
25
+ from flow.modules.dit import DiT
26
+ from vae.model import Model as VAE
27
+ from vae.utils import sync_timer
28
+
29
+
30
+ class Model(nn.Module):
31
+ def __init__(self, config: ModelConfig) -> None:
32
+ super().__init__()
33
+ self.config = config
34
+ self.precision = torch.bfloat16
35
+
36
+ # image condition model (dinov2)
37
+ if self.config.dino_model == "dinov2_vitg14":
38
+ self.dino = Dinov2Model.from_pretrained("facebook/dinov2-giant")
39
+ elif self.config.dino_model == "dinov2_vitl14_reg":
40
+ self.dino = Dinov2Model.from_pretrained("facebook/dinov2-with-registers-large")
41
+ else:
42
+ raise ValueError(f"DINOv2 model {self.config.dino_model} not supported")
43
+
44
+ # hack to match our implementation
45
+ self.dino.layernorm = torch.nn.Identity()
46
+
47
+ self.dino.eval().to(dtype=self.precision)
48
+ self.dino.requires_grad_(False)
49
+
50
+ cond_dim = 1024 if self.config.dino_model == "dinov2_vitl14_reg" else 1536
51
+ assert cond_dim == config.hidden_dim, "DINOv2 dim must match backbone dim"
52
+
53
+ self.preprocess_cond_image = transforms.Compose(
54
+ [
55
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
56
+ ]
57
+ )
58
+
59
+ # vae encoder
60
+ vae_config = importlib.import_module(config.vae_conf).make_config()
61
+ self.vae = VAE(vae_config).eval().to(dtype=self.precision)
62
+ self.vae.requires_grad_(False)
63
+
64
+ # load vae
65
+ if self.config.preload_vae:
66
+ try:
67
+ vae_ckpt = torch.load(self.config.vae_ckpt_path, weights_only=True) # local path
68
+ if "model" in vae_ckpt:
69
+ vae_ckpt = vae_ckpt["model"]
70
+ self.vae.load_state_dict(vae_ckpt, strict=True)
71
+ del vae_ckpt
72
+ print(f"Loaded VAE from {self.config.vae_ckpt_path}")
73
+ except Exception as e:
74
+ print(
75
+ f"Failed to load VAE from {self.config.vae_ckpt_path}: {e}, make sure you resumed from a valid checkpoint!"
76
+ )
77
+
78
+ # load info from vae config
79
+ if config.latent_size is None:
80
+ config.latent_size = self.vae.config.latent_size
81
+ if config.latent_dim is None:
82
+ config.latent_dim = self.vae.config.latent_dim
83
+
84
+ # dit
85
+ self.dit = DiT(
86
+ hidden_dim=config.hidden_dim,
87
+ num_heads=config.num_heads,
88
+ num_layers=config.num_layers,
89
+ latent_size=config.latent_size,
90
+ latent_dim=config.latent_dim,
91
+ qknorm=config.qknorm,
92
+ qknorm_type=config.qknorm_type,
93
+ use_pos_embed=config.use_pos_embed,
94
+ use_parts=config.use_parts,
95
+ part_embed_mode=config.part_embed_mode,
96
+ )
97
+
98
+ # num_part condition
99
+ if self.config.use_num_parts_cond:
100
+ assert self.config.use_parts, "use_num_parts_cond requires use_parts"
101
+ self.num_part_embed = nn.Embedding(5, config.hidden_dim)
102
+
103
+ # preload from a checkpoint (NOTE: this happens BEFORE checkpointer loading latest checkpoint!)
104
+ if self.config.pretrain_path is not None:
105
+ try:
106
+ ckpt = torch.load(self.config.pretrain_path) # local path
107
+ self.load_state_dict(ckpt["model"], strict=True)
108
+ del ckpt
109
+ print(f"Loaded DiT from {self.config.pretrain_path}")
110
+ except Exception as e:
111
+ print(
112
+ f"Failed to load DiT from {self.config.pretrain_path}: {e}, make sure you resumed from a valid checkpoint!"
113
+ )
114
+
115
+ # sampler
116
+ self.scheduler = FlowMatchingScheduler(shift=config.flow_shift)
117
+
118
+ n_params = 0
119
+ for p in self.dit.parameters():
120
+ n_params += p.numel()
121
+ print(f"Number of parameters in DiT: {n_params/1e6:.2f}M")
122
+
123
+ # override state_dict to exclude vae and dino, so we only save the trainable params.
124
+ def state_dict(self, *args, **kwargs):
125
+ state_dict = super().state_dict(*args, **kwargs)
126
+
127
+ keys_to_del = []
128
+ for k in state_dict.keys():
129
+ if "vae" in k or "dino" in k:
130
+ keys_to_del.append(k)
131
+
132
+ for k in keys_to_del:
133
+ del state_dict[k]
134
+
135
+ return state_dict
136
+
137
+ # override to support tolerant loading (only load matched shape)
138
+ def load_state_dict(self, state_dict, strict=True, assign=False):
139
+ local_state_dict = self.state_dict()
140
+ seen_keys = {k: False for k in local_state_dict.keys()}
141
+ for k, v in state_dict.items():
142
+ if k in local_state_dict:
143
+ seen_keys[k] = True
144
+ if local_state_dict[k].shape == v.shape:
145
+ local_state_dict[k].copy_(v)
146
+ else:
147
+ print(f"mismatching shape for key {k}: loaded {local_state_dict[k].shape} but model has {v.shape}")
148
+ else:
149
+ print(f"unexpected key {k} in loaded state dict")
150
+ for k in seen_keys:
151
+ if not seen_keys[k]:
152
+ print(f"missing key {k} in loaded state dict")
153
+
154
+ # this happens before checkpointer loading old models !!!
155
+ def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
156
+ super().on_train_start(memory_format=memory_format)
157
+ device = next(self.dit.parameters()).device
158
+
159
+ self.dit.to(dtype=self.precision)
160
+
161
+ if self.config.use_num_parts_cond:
162
+ self.num_part_embed.to(dtype=self.precision)
163
+
164
+ # cast scheduler to device
165
+ self.scheduler.to(device)
166
+
167
+ def get_cond(self, cond_image, num_part=None):
168
+ # image condition
169
+ cond_image = cond_image.to(dtype=self.precision)
170
+ with torch.no_grad():
171
+ cond = self.dino(cond_image).last_hidden_state
172
+ cond = F.layer_norm(cond.float(), cond.shape[-1:]).to(dtype=self.precision) # [B, L, C]
173
+
174
+ # num_part condition
175
+ if self.config.use_num_parts_cond:
176
+ if num_part is None:
177
+ # use a default value (2-10 parts)
178
+ num_part_coarse = torch.ones(cond.shape[0], dtype=torch.int64, device=cond.device) * 2
179
+ else:
180
+ # coarse range
181
+ num_part_coarse = torch.ones(cond.shape[0], dtype=torch.int64, device=cond.device)
182
+ num_part_coarse[num_part == 2] = 1
183
+ num_part_coarse[(num_part > 2) & (num_part <= 10)] = 2
184
+ num_part_coarse[(num_part > 10) & (num_part <= 100)] = 3
185
+ num_part_coarse[num_part > 100] = 4
186
+ num_part_embed = self.num_part_embed(num_part_coarse).unsqueeze(1) # [B, 1, C]
187
+ cond = torch.cat([cond, num_part_embed], dim=1) # [B, L+1, C]
188
+
189
+ return cond
190
+
191
+ def training_step(
192
+ self,
193
+ data: dict[str, torch.Tensor],
194
+ iteration: int,
195
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
196
+ output = {}
197
+ loss = 0
198
+
199
+ cond_images = self.preprocess_cond_image(
200
+ data["cond_images"]
201
+ ) # [B, N, 3, 518, 518], we may load multiple (N) cond images for the same shape
202
+ B, N, C, H, W = cond_images.shape
203
+
204
+ if self.config.use_num_parts_cond:
205
+ cond_num_part = data["num_part"].repeat_interleave(N, dim=0)
206
+ else:
207
+ cond_num_part = None
208
+
209
+ cond = self.get_cond(cond_images.view(-1, C, H, W), cond_num_part) # [B*N, L, C]
210
+
211
+ # random CFG dropout
212
+ if self.training:
213
+ mask = torch.rand((B * N, 1, 1), device=cond.device, dtype=cond.dtype) >= 0.1
214
+ cond = cond * mask
215
+
216
+ with torch.no_grad():
217
+ # encode latent
218
+ if self.config.use_parts:
219
+ # encode two parts and concat latent
220
+ part0_data = {k.replace("_part0", ""): v for k, v in data.items() if "_part0" in k}
221
+ part1_data = {k.replace("_part1", ""): v for k, v in data.items() if "_part1" in k}
222
+ posterior0 = self.vae.encode(part0_data)
223
+ posterior1 = self.vae.encode(part1_data)
224
+ if self.training and self.config.shuffle_parts:
225
+ if np.random.rand() < 0.5:
226
+ posterior0, posterior1 = posterior1, posterior0
227
+ latent = torch.cat(
228
+ [
229
+ posterior0.mode().float().nan_to_num_(0),
230
+ posterior1.mode().float().nan_to_num_(0),
231
+ ],
232
+ dim=1,
233
+ ) # [B, 2L, C]
234
+ else:
235
+ posterior = self.vae.encode(data)
236
+ latent = posterior.mode().float().nan_to_num_(0) # use mean as the latent, [B, L, C]
237
+
238
+ # repeat latent for each cond image
239
+ if N != 1:
240
+ latent = latent.repeat_interleave(N, dim=0)
241
+
242
+ # random sample timesteps and add noise
243
+ noisy_latent, noise, timesteps = self.scheduler.add_noise(
244
+ latent, self.config.logitnorm_mean, self.config.logitnorm_std
245
+ )
246
+
247
+ noisy_latent = noisy_latent.to(dtype=self.precision)
248
+ model_pred = self.dit(noisy_latent, cond, timesteps)
249
+
250
+ # flow-matching loss
251
+ target = noise - latent
252
+ loss = F.mse_loss(model_pred.float(), target.float())
253
+
254
+ # metrics
255
+ with torch.no_grad():
256
+ output["scalar"] = {} # for wandb logging
257
+ output["scalar"]["loss_mse"] = loss.detach()
258
+
259
+ return output, loss
260
+
261
+ @torch.no_grad()
262
+ def validation_step(
263
+ self,
264
+ data: dict[str, torch.Tensor],
265
+ iteration: int,
266
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
267
+ return self.training_step(data, iteration)
268
+
269
+ @torch.inference_mode()
270
+ @sync_timer("flow forward")
271
+ def forward(
272
+ self,
273
+ data: dict[str, torch.Tensor],
274
+ num_steps: int = 30,
275
+ cfg_scale: float = 7.0,
276
+ verbose: bool = True,
277
+ generator: torch.Generator | None = None,
278
+ ) -> dict[str, torch.Tensor]:
279
+ # the inference sampling
280
+ cond_images = self.preprocess_cond_image(data["cond_images"]) # [B, 3, 518, 518]
281
+ B = cond_images.shape[0]
282
+ assert B == 1, "Only support batch size 1 for now."
283
+
284
+ # num_part condition
285
+ if self.config.use_num_parts_cond and "num_part" in data:
286
+ cond_num_part = data["num_part"] # [B,], int
287
+ else:
288
+ cond_num_part = None
289
+
290
+ cond = self.get_cond(cond_images, cond_num_part)
291
+
292
+ if self.config.use_parts:
293
+ x = torch.randn(
294
+ B,
295
+ self.config.latent_size * 2,
296
+ self.config.latent_dim,
297
+ device=cond.device,
298
+ dtype=torch.float32,
299
+ generator=generator,
300
+ )
301
+ else:
302
+ x = torch.randn(
303
+ B,
304
+ self.config.latent_size,
305
+ self.config.latent_dim,
306
+ device=cond.device,
307
+ dtype=torch.float32,
308
+ generator=generator,
309
+ )
310
+
311
+ cond_input = torch.cat([cond, torch.zeros_like(cond)], dim=0)
312
+
313
+ # flow-matching
314
+ sigmas = np.linspace(1, 0, num_steps + 1)
315
+ sigmas = self.scheduler.shift * sigmas / (1 + (self.scheduler.shift - 1) * sigmas)
316
+ sigmas_pair = list((sigmas[i], sigmas[i + 1]) for i in range(num_steps))
317
+
318
+ for sigma, sigma_prev in tqdm.tqdm(sigmas_pair, desc="Flow Sampling", disable=not verbose):
319
+ # classifier-free guidance
320
+ timesteps = torch.tensor([1000 * sigma] * B * 2, device=x.device, dtype=x.dtype)
321
+ x_input = torch.cat([x, x], dim=0)
322
+
323
+ # predict v
324
+ x_input = x_input.to(dtype=self.precision)
325
+ pred = self.dit(x_input, cond_input, timesteps).float()
326
+ cond_v, uncond_v = pred.chunk(2, dim=0)
327
+ pred_v = uncond_v + (cond_v - uncond_v) * cfg_scale
328
+
329
+ # scheduler step
330
+ x = x - (sigma - sigma_prev) * pred_v
331
+
332
+ output = {}
333
+ output["latent"] = x
334
+
335
+ # leave mesh extraction to vae
336
+ return output
flow/modules/__init__.py ADDED
File without changes
flow/modules/dit.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.checkpoint import checkpoint
18
+
19
+ from vae.modules.attention import CrossAttention, SelfAttention
20
+
21
+
22
+ class FeedForward(nn.Module):
23
+ def __init__(self, dim, mult=4):
24
+ super().__init__()
25
+ self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))
26
+
27
+ def forward(self, x):
28
+ return self.net(x)
29
+
30
+
31
+ # Adapted from https://github.com/facebookresearch/DiT/blob/main/models.py#L27
32
+ class TimestepEmbedder(nn.Module):
33
+ """
34
+ Embeds scalar timesteps into vector representations.
35
+ """
36
+
37
+ def __init__(self, hidden_size, frequency_embedding_size=256):
38
+ super().__init__()
39
+ self.mlp = nn.Sequential(
40
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
41
+ nn.SiLU(),
42
+ nn.Linear(hidden_size, hidden_size, bias=True),
43
+ )
44
+ self.frequency_embedding_size = frequency_embedding_size
45
+
46
+ @staticmethod
47
+ def timestep_embedding(t, dim, max_period=10000):
48
+ """
49
+ Create sinusoidal timestep embeddings.
50
+
51
+ Args:
52
+ t: a 1-D Tensor of N indices, one per batch element.
53
+ These may be fractional.
54
+ dim: the dimension of the output.
55
+ max_period: controls the minimum frequency of the embeddings.
56
+
57
+ Returns:
58
+ an (N, D) Tensor of positional embeddings.
59
+ """
60
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
61
+ half = dim // 2
62
+ freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
63
+ device=t.device
64
+ )
65
+ args = t[:, None].float() * freqs[None]
66
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
67
+ if dim % 2:
68
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
69
+ return embedding
70
+
71
+ def forward(self, t):
72
+ dtype = next(self.mlp.parameters()).dtype # need to determine on the fly...
73
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
74
+ t_freq = t_freq.to(dtype=dtype)
75
+ t_emb = self.mlp(t_freq)
76
+ return t_emb
77
+
78
+
79
+ class DiTLayer(nn.Module):
80
+ def __init__(self, dim, num_heads, qknorm=False, gradient_checkpointing=True, qknorm_type="LayerNorm"):
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.num_heads = num_heads
84
+ self.gradient_checkpointing = gradient_checkpointing
85
+
86
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
87
+ self.attn1 = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)
88
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
89
+ self.attn2 = CrossAttention(dim, num_heads, context_dim=dim, qknorm=qknorm, qknorm_type=qknorm_type)
90
+ self.norm3 = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
91
+ self.ff = FeedForward(dim)
92
+ self.adaln_linear = nn.Linear(dim, dim * 6, bias=True)
93
+
94
+ def forward(self, x, c, t_emb):
95
+ if self.training and self.gradient_checkpointing:
96
+ return checkpoint(self._forward, x, c, t_emb, use_reentrant=False)
97
+ else:
98
+ return self._forward(x, c, t_emb)
99
+
100
+ def _forward(self, x, c, t_emb):
101
+ # x: [B, N, C], hidden states
102
+ # c: [B, M, C], condition (assume normed and projected to C)
103
+ # t_emb: [B, C], timestep embedding of adaln
104
+ # return: [B, N, C], updated hidden states
105
+
106
+ B, N, C = x.shape
107
+ t_adaln = self.adaln_linear(F.silu(t_emb)).view(B, 6, -1) # [B, 6, C]
108
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_adaln.chunk(6, dim=1)
109
+
110
+ h = self.norm1(x)
111
+ h = h * (1 + scale_msa) + shift_msa
112
+ x = x + gate_msa * self.attn1(h)
113
+
114
+ h = self.norm2(x)
115
+ x = x + self.attn2(h, c)
116
+
117
+ h = self.norm3(x)
118
+ h = h * (1 + scale_mlp) + shift_mlp
119
+ x = x + gate_mlp * self.ff(h)
120
+
121
+ return x
122
+
123
+
124
+ class DiT(nn.Module):
125
+ def __init__(
126
+ self,
127
+ hidden_dim=1024,
128
+ num_heads=16,
129
+ latent_size=2048,
130
+ latent_dim=8,
131
+ num_layers=24,
132
+ qknorm=False,
133
+ gradient_checkpointing=True,
134
+ qknorm_type="LayerNorm",
135
+ use_pos_embed=False,
136
+ use_parts=False,
137
+ part_embed_mode="part2_only",
138
+ ):
139
+ super().__init__()
140
+
141
+ # project in
142
+ self.proj_in = nn.Linear(latent_dim, hidden_dim)
143
+
144
+ # positional encoding (just use a learnable positional encoding)
145
+ self.use_pos_embed = use_pos_embed
146
+ if self.use_pos_embed:
147
+ self.pos_embed = nn.Parameter(torch.randn(1, latent_size, hidden_dim) / hidden_dim**0.5)
148
+
149
+ # part encoding (a must to distinguish parts!)
150
+ self.use_parts = use_parts
151
+ self.part_embed_mode = part_embed_mode
152
+ if self.use_parts:
153
+ if self.part_embed_mode == "element":
154
+ self.part_embed = nn.Parameter(torch.randn(latent_size, hidden_dim) / hidden_dim**0.5)
155
+ elif self.part_embed_mode == "part":
156
+ self.part_embed = nn.Parameter(torch.randn(2, hidden_dim))
157
+ elif self.part_embed_mode == "part2_only":
158
+ # we only add this to the second part to distinguish from the first part
159
+ self.part_embed = nn.Parameter(torch.randn(1, hidden_dim) / hidden_dim**0.5)
160
+
161
+ # timestep encoding
162
+ self.timestep_embed = TimestepEmbedder(hidden_dim)
163
+
164
+ # transformer layers
165
+ self.layers = nn.ModuleList(
166
+ [DiTLayer(hidden_dim, num_heads, qknorm, gradient_checkpointing, qknorm_type) for _ in range(num_layers)]
167
+ )
168
+
169
+ # project out
170
+ self.norm_out = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
171
+ self.proj_out = nn.Linear(hidden_dim, latent_dim)
172
+
173
+ # init
174
+ self.init_weight()
175
+
176
+ def init_weight(self):
177
+ # Initialize transformer layers
178
+ def _basic_init(module):
179
+ if isinstance(module, nn.Linear):
180
+ torch.nn.init.xavier_uniform_(module.weight)
181
+ if module.bias is not None:
182
+ nn.init.constant_(module.bias, 0)
183
+
184
+ self.apply(_basic_init)
185
+
186
+ # Initialize timestep embedding MLP:
187
+ nn.init.normal_(self.timestep_embed.mlp[0].weight, std=0.02)
188
+ nn.init.normal_(self.timestep_embed.mlp[2].weight, std=0.02)
189
+
190
+ # Zero-out adaLN modulation layers in DiT blocks:
191
+ for layer in self.layers:
192
+ nn.init.constant_(layer.adaln_linear.weight, 0)
193
+ nn.init.constant_(layer.adaln_linear.bias, 0)
194
+
195
+ # Zero-out output layers:
196
+ nn.init.constant_(self.proj_out.weight, 0)
197
+ nn.init.constant_(self.proj_out.bias, 0)
198
+
199
+ def forward(self, x, c, t):
200
+ # x: [B, N, C], hidden states
201
+ # c: [B, M, C], condition (assume normed and projected to C)
202
+ # t: [B,], timestep
203
+ # return: [B, N, C], updated hidden states
204
+
205
+ B, N, C = x.shape
206
+
207
+ # project in
208
+ x = self.proj_in(x)
209
+
210
+ # positional encoding
211
+ if self.use_pos_embed:
212
+ x = x + self.pos_embed
213
+
214
+ # part encoding
215
+ if self.use_parts:
216
+ if self.part_embed_mode == "element":
217
+ x += self.part_embed
218
+ elif self.part_embed_mode == "part":
219
+ x[:, : x.shape[1] // 2, :] += self.part_embed[0]
220
+ x[:, x.shape[1] // 2 :, :] += self.part_embed[1]
221
+ elif self.part_embed_mode == "part2_only":
222
+ x[:, x.shape[1] // 2 :, :] += self.part_embed[0]
223
+
224
+ # timestep encoding
225
+ t_emb = self.timestep_embed(t) # [B, C]
226
+
227
+ # transformer layers
228
+ for layer in self.layers:
229
+ x = layer(x, c, t_emb)
230
+
231
+ # project out
232
+ x = self.norm_out(x)
233
+ x = self.proj_out(x)
234
+
235
+ return x
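
A quick shape sanity check for the DiT above (illustrative only; tiny dimensions, eval mode so gradient checkpointing is skipped, and assuming the SelfAttention/CrossAttention modules imported from vae.modules.attention behave as standard attention blocks):

import torch
from flow.modules.dit import DiT

dit = DiT(hidden_dim=64, num_heads=4, latent_size=16, latent_dim=8, num_layers=2, use_parts=True).eval()
x = torch.randn(1, 32, 8)   # two packed parts of 16 latent tokens each
c = torch.randn(1, 10, 64)  # conditioning tokens (e.g., DINOv2 features at hidden_dim)
t = torch.tensor([500.0])   # timestep in [0, 1000]
with torch.no_grad():
    out = dit(x, c, t)      # -> [1, 32, 8]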
flow/scripts/infer.py ADDED
@@ -0,0 +1,180 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ import argparse
14
+ import glob
15
+ import importlib
16
+ import os
17
+ from datetime import datetime
18
+
19
+ import cv2
20
+ import kiui
21
+ import numpy as np
22
+ import rembg
23
+ import torch
24
+ import trimesh
25
+
26
+ from flow.model import Model
27
+ from flow.utils import get_random_color, recenter_foreground
28
+ from vae.utils import postprocess_mesh
29
+
30
+ # PYTHONPATH=. python flow/scripts/infer.py
31
+ parser = argparse.ArgumentParser()
32
+ parser.add_argument(
33
+ "--config",
34
+ type=str,
35
+ help="config file path",
36
+ default="flow.configs.big_parts_strict_pvae",
37
+ )
38
+ parser.add_argument(
39
+ "--ckpt_path",
40
+ type=str,
41
+ help="checkpoint path",
42
+ default="pretrained/flow.pt",
43
+ )
44
+ parser.add_argument("--input", type=str, help="input directory", default="assets/images/")
45
+ parser.add_argument("--limit", type=int, help="limit number of images", default=-1)
46
+ parser.add_argument("--output_dir", type=str, help="output directory", default="output/")
47
+ parser.add_argument("--grid_res", type=int, help="grid resolution", default=384)
48
+ parser.add_argument("--num_steps", type=int, help="number of cfg steps", default=30)
49
+ parser.add_argument("--cfg_scale", type=float, help="cfg scale", default=7.0)
50
+ parser.add_argument("--num_repeats", type=int, help="number of repeats per image", default=1)
51
+ parser.add_argument("--seed", type=int, help="seed", default=42)
52
+ args = parser.parse_args()
53
+
54
+ TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
55
+
56
+ bg_remover = rembg.new_session()
57
+
58
+
59
+ def preprocess_image(path):
60
+ input_image = kiui.read_image(path, mode="uint8", order="RGBA")
61
+
62
+ # bg removal if there is no alpha channel
63
+ if input_image.shape[-1] == 3:
64
+ input_image = rembg.remove(input_image, session=bg_remover) # [H, W, 4]
65
+
66
+ mask = input_image[..., -1] > 0
67
+ image = recenter_foreground(input_image, mask, border_ratio=0.1)
68
+ image = cv2.resize(image, (518, 518), interpolation=cv2.INTER_LINEAR)
69
+ image = image.astype(np.float32) / 255.0
70
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) # white background
71
+ return image
72
+
73
+
74
+ print(f"Loading checkpoint from {args.ckpt_path}")
75
+ ckpt_dict = torch.load(args.ckpt_path, weights_only=True)
76
+
77
+ # delete all keys other than model
78
+ if "model" in ckpt_dict:
79
+ ckpt_dict = ckpt_dict["model"]
80
+
81
+ # instantiate model
82
+ print(f"Instantiating model from {args.config}")
83
+ model_config = importlib.import_module(args.config).make_config()
84
+ model = Model(model_config).eval().cuda().bfloat16()
85
+
86
+ # load weight
87
+ print(f"Loading weights from {args.ckpt_path}")
88
+ model.load_state_dict(ckpt_dict, strict=True)
89
+
90
+ # output folder
91
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
92
+ workspace = os.path.join(args.output_dir, "flow_" + args.config.split(".")[-1] + "_" + timestamp)
93
+ if not os.path.exists(workspace):
94
+ os.makedirs(workspace)
95
+ else:
96
+ os.system(f"rm {workspace}/*")
97
+ print(f"Output directory: {workspace}")
98
+
99
+ # load test images
100
+ if os.path.isdir(args.input):
101
+ paths = glob.glob(os.path.join(args.input, "*"))
102
+ paths = sorted(paths)
103
+ if args.limit > 0:
104
+ paths = paths[: args.limit]
105
+ else: # single file
106
+ paths = [args.input]
107
+
108
+ for path in paths:
109
+ name = os.path.splitext(os.path.basename(path))[0]
110
+ print(f"Processing {name}")
111
+
112
+ image = preprocess_image(path)
113
+
114
+ kiui.write_image(os.path.join(workspace, name + ".jpg"), image)
115
+ image = torch.from_numpy(image).permute(2, 0, 1).contiguous().unsqueeze(0).float().cuda()
116
+
117
+ # run model
118
+ data = {"cond_images": image}
119
+
120
+ for i in range(args.num_repeats):
121
+
122
+ kiui.seed_everything(args.seed + i)
123
+
124
+ with torch.inference_mode():
125
+ results = model(data, num_steps=args.num_steps, cfg_scale=args.cfg_scale)
126
+
127
+ latent = results["latent"]
128
+ # kiui.lo(latent)
129
+
130
+ # query mesh
131
+ if model.config.use_parts:
132
+ data_part0 = {"latent": latent[:, : model.config.latent_size, :]}
133
+ data_part1 = {"latent": latent[:, model.config.latent_size :, :]}
134
+
135
+ with torch.inference_mode():
136
+ results_part0 = model.vae(data_part0, resolution=args.grid_res)
137
+ results_part1 = model.vae(data_part1, resolution=args.grid_res)
138
+
139
+ vertices, faces = results_part0["meshes"][0]
140
+ mesh_part0 = trimesh.Trimesh(vertices, faces)
141
+ mesh_part0.vertices = mesh_part0.vertices @ TRIMESH_GLB_EXPORT.T
142
+ mesh_part0 = postprocess_mesh(mesh_part0, 5e4)
143
+ parts = mesh_part0.split(only_watertight=False)
144
+
145
+ vertices, faces = results_part1["meshes"][0]
146
+ mesh_part1 = trimesh.Trimesh(vertices, faces)
147
+ mesh_part1.vertices = mesh_part1.vertices @ TRIMESH_GLB_EXPORT.T
148
+ mesh_part1 = postprocess_mesh(mesh_part1, 5e4)
149
+ parts.extend(mesh_part1.split(only_watertight=False))
150
+
151
+ # split connected components and assign different colors
152
+ for j, part in enumerate(parts):
153
+ # each component uses a random color
154
+ part.visual.vertex_colors = get_random_color(j, use_float=True)
155
+
156
+ mesh = trimesh.Scene(parts)
157
+ # export the whole mesh
158
+ mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb"))
159
+
160
+ # export each part
161
+ for j, part in enumerate(parts):
162
+ part.export(os.path.join(workspace, name + "_" + str(i) + "_part" + str(j) + ".glb"))
163
+
164
+ # export dual volumes
165
+ mesh_part0.export(os.path.join(workspace, name + "_" + str(i) + "_vol0.glb"))
166
+ mesh_part1.export(os.path.join(workspace, name + "_" + str(i) + "_vol1.glb"))
167
+
168
+ else:
169
+ data = {"latent": latent}
170
+
171
+ with torch.inference_mode():
172
+ results = model.vae(data, resolution=args.grid_res)
173
+
174
+ vertices, faces = results["meshes"][0]
175
+ mesh = trimesh.Trimesh(vertices, faces)
176
+ mesh = postprocess_mesh(mesh, 5e4)
177
+
178
+ # kiui.lo(mesh.vertices, mesh.faces)
179
+ mesh.vertices = mesh.vertices @ TRIMESH_GLB_EXPORT.T
180
+ mesh.export(os.path.join(workspace, name + "_" + str(i) + ".glb"))
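
Usage note (illustrative, not part of the commit): per the in-file comment, the script is launched from the repository root with PYTHONPATH set. A minimal programmatic equivalent, assuming a CUDA GPU and the pretrained checkpoints in place; examples/barrel.png is one of the images added in this commit.

import os
import subprocess

# equivalent of: PYTHONPATH=. python flow/scripts/infer.py --input examples/barrel.png --num_repeats 2
env = dict(os.environ, PYTHONPATH=".")
subprocess.run(
    ["python", "flow/scripts/infer.py", "--input", "examples/barrel.png", "--num_repeats", "2"],
    env=env,
    check=True,
)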
flow/utils.py ADDED
@@ -0,0 +1,119 @@
+"""
+-----------------------------------------------------------------------------
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+NVIDIA CORPORATION and its licensors retain all intellectual property
+and proprietary rights in and to this software, related documentation
+and any modifications thereto. Any use, reproduction, disclosure or
+distribution of this software and related documentation without an express
+license agreement from NVIDIA CORPORATION is strictly prohibited.
+-----------------------------------------------------------------------------
+"""
+
+from typing import Optional
+
+import cv2
+import numpy as np
+
+
+def recenter_foreground(image, mask, border_ratio: float = 0.1):
+    """recenter an image to leave some empty space at the image border.
+
+    Args:
+        image (ndarray): input image, float/uint8 [H, W, 3/4]
+        mask (ndarray): alpha mask, bool [H, W]
+        border_ratio (float, optional): border ratio, image will be resized to (1 - border_ratio). Defaults to 0.1.
+
+    Returns:
+        ndarray: output image, float/uint8 [H, W, 3/4]
+    """
+
+    # empty foreground: just return
+    if mask.sum() == 0:
+        return image
+
+    return_int = False
+    if image.dtype == np.uint8:
+        image = image.astype(np.float32) / 255
+        return_int = True
+
+    H, W, C = image.shape
+    size = max(H, W)
+
+    # default to white bg if rgb, but use 0 if rgba
+    if C == 3:
+        result = np.ones((size, size, C), dtype=np.float32)
+    else:
+        result = np.zeros((size, size, C), dtype=np.float32)
+
+    coords = np.nonzero(mask)
+    x_min, x_max = coords[0].min(), coords[0].max()
+    y_min, y_max = coords[1].min(), coords[1].max()
+    h = x_max - x_min
+    w = y_max - y_min
+    desired_size = int(size * (1 - border_ratio))
+    scale = desired_size / max(h, w)
+    h2 = int(h * scale)
+    w2 = int(w * scale)
+    x2_min = (size - h2) // 2
+    x2_max = x2_min + h2
+    y2_min = (size - w2) // 2
+    y2_max = y2_min + w2
+    result[x2_min:x2_max, y2_min:y2_max] = cv2.resize(
+        image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA
+    )
+
+    if return_int:
+        result = (result * 255).astype(np.uint8)
+
+    return result
+
+
+def get_random_color(index: Optional[int] = None, use_float: bool = False):
+    # some pleasing colors
+    # matplotlib.colormaps['Set3'].colors + matplotlib.colormaps['Set2'].colors + matplotlib.colormaps['Set1'].colors
+    palette = np.array(
+        [
+            [141, 211, 199, 255],
+            [255, 255, 179, 255],
+            [190, 186, 218, 255],
+            [251, 128, 114, 255],
+            [128, 177, 211, 255],
+            [253, 180, 98, 255],
+            [179, 222, 105, 255],
+            [252, 205, 229, 255],
+            [217, 217, 217, 255],
+            [188, 128, 189, 255],
+            [204, 235, 197, 255],
+            [255, 237, 111, 255],
+            [102, 194, 165, 255],
+            [252, 141, 98, 255],
+            [141, 160, 203, 255],
+            [231, 138, 195, 255],
+            [166, 216, 84, 255],
+            [255, 217, 47, 255],
+            [229, 196, 148, 255],
+            [179, 179, 179, 255],
+            [228, 26, 28, 255],
+            [55, 126, 184, 255],
+            [77, 175, 74, 255],
+            [152, 78, 163, 255],
+            [255, 127, 0, 255],
+            [255, 255, 51, 255],
+            [166, 86, 40, 255],
+            [247, 129, 191, 255],
+            [153, 153, 153, 255],
+        ],
+        dtype=np.uint8,
+    )
+
+    if index is None:
+        index = np.random.randint(0, len(palette))
+
+    if index >= len(palette):
+        index = index % len(palette)
+
+    if use_float:
+        return palette[index].astype(np.float32) / 255
+    else:
+        return palette[index]
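
A minimal, self-contained sketch of how these helpers are used by app.py (synthetic RGBA input; not part of the commit):

import numpy as np
from flow.utils import get_random_color, recenter_foreground

img = np.zeros((512, 512, 4), dtype=np.uint8)
img[128:384, 128:384] = 255                              # synthetic square foreground with full alpha
mask = img[..., -1] > 0
out = recenter_foreground(img, mask, border_ratio=0.1)   # still (512, 512, 4), uint8
color = get_random_color(0, use_float=True)              # RGBA in [0, 1] from the fixed palette
print(out.shape, color)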
requirements.txt ADDED
@@ -0,0 +1,16 @@
+torch==2.5.1 # should be installed by HF
+torchvision==0.20.1
+numpy
+trimesh
+# meshiki # we don't use it for flow inference
+fpsample
+einops
+onnxruntime
+rembg
+kiui
+pymcubes
+tqdm
+opencv-python
+ninja
+pymeshlab
+transformers
vae/__init__.py ADDED
File without changes
vae/configs/__init__.py ADDED
File without changes
vae/configs/part_woenc.py ADDED
@@ -0,0 +1,30 @@
+"""
+-----------------------------------------------------------------------------
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+NVIDIA CORPORATION and its licensors retain all intellectual property
+and proprietary rights in and to this software, related documentation
+and any modifications thereto. Any use, reproduction, disclosure or
+distribution of this software and related documentation without an express
+license agreement from NVIDIA CORPORATION is strictly prohibited.
+-----------------------------------------------------------------------------
+"""
+
+from vae.configs.schema import ModelConfig
+
+
+def make_config():
+
+    model_config = ModelConfig(
+        use_salient_point=True,
+        latent_size=4096,
+        cutoff_fps_point=(256, 512, 512, 512, 1024, 1024, 2048),
+        cutoff_fps_salient_point=(0, 0, 256, 512, 512, 1024, 2048),
+        cutoff_fps_prob=(0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.2),
+        kl_weight=1e-3,
+        salient_attn_mode="dual",
+        num_enc_layers=0,
+        num_dec_layers=24,
+    )
+
+    return model_config
vae/configs/schema.py ADDED
@@ -0,0 +1,55 @@
+"""
+-----------------------------------------------------------------------------
+Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+NVIDIA CORPORATION and its licensors retain all intellectual property
+and proprietary rights in and to this software, related documentation
+and any modifications thereto. Any use, reproduction, disclosure or
+distribution of this software and related documentation without an express
+license agreement from NVIDIA CORPORATION is strictly prohibited.
+-----------------------------------------------------------------------------
+"""
+
+from typing import Literal, Optional, Tuple
+
+import attrs
+
+
+@attrs.define(slots=False)
+class ModelConfig:
+    # input
+    use_salient_point: bool = True
+
+    # random cutoff during training
+    cutoff_fps_point: Tuple[int, ...] = (256, 512, 512, 512, 1024, 1024, 2048)
+    cutoff_fps_salient_point: Tuple[int, ...] = (0, 0, 256, 512, 512, 1024, 2048)
+    cutoff_fps_prob: Tuple[float, ...] = (0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 0.2)  # sum to 1.0
+
+    # backbone transformer
+    num_enc_layers: int = 0
+    hidden_dim: int = 1024
+    num_heads: int = 16
+    num_dec_layers: int = 24
+    dec_hidden_dim: int = 1024
+    dec_num_heads: int = 16
+    qknorm: bool = True
+    qknorm_type: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"  # type of qknorm
+    salient_attn_mode: Literal["dual_shared", "single", "dual"] = "dual"
+
+    # query decoder
+    fourier_version: Literal["v1", "v2", "v3"] = "v3"
+    point_fourier_dim: int = 48  # must be divisible by 6 (sin/cos, x/y/z)
+    query_hidden_dim: int = 1024
+    query_num_heads: int = 16
+    use_flash_query: bool = False
+
+    # latent code
+    latent_size: int = 4096  # == num_fps_point + num_fps_salient_point
+    latent_dim: int = 64
+
+    # loss
+    use_ae: bool = False  # if true, variance will be ignored, and kl_weight is used as a L2 norm weight
+    kl_weight: float = 1e-3
+
+    # init weights from a pretrained checkpoint
+    pretrain_path: Optional[str] = None
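
Worked example of the latent_size comment above (illustrative): the largest entries of the two cutoff tuples are 2048 uniform FPS points and 2048 salient FPS points, which together account for the 4096 latent tokens.

from vae.configs.schema import ModelConfig

cfg = ModelConfig()
# 2048 + 2048 == 4096
assert cfg.cutoff_fps_point[-1] + cfg.cutoff_fps_salient_point[-1] == cfg.latent_size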
vae/model.py ADDED
@@ -0,0 +1,451 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ from typing import Literal
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from vae.configs.schema import ModelConfig
21
+ from vae.modules.transformer import AttentionBlock, FlashQueryLayer
22
+ from vae.utils import (
23
+ DiagonalGaussianDistribution,
24
+ DummyLatent,
25
+ calculate_iou,
26
+ calculate_metrics,
27
+ construct_grid_points,
28
+ extract_mesh,
29
+ sync_timer,
30
+ )
31
+
32
+
33
+ class Model(nn.Module):
34
+ def __init__(self, config: ModelConfig) -> None:
35
+ super().__init__()
36
+ self.config = config
37
+
38
+ self.precision = torch.bfloat16 # manually handle low-precision training, always use bf16
39
+
40
+ # point encoder
41
+ self.proj_input = nn.Linear(3 + config.point_fourier_dim, config.hidden_dim)
42
+
43
+ self.perceiver = AttentionBlock(
44
+ config.hidden_dim,
45
+ num_heads=config.num_heads,
46
+ dim_context=config.hidden_dim,
47
+ qknorm=config.qknorm,
48
+ qknorm_type=config.qknorm_type,
49
+ )
50
+
51
+ if self.config.salient_attn_mode == "dual":
52
+ self.perceiver_dorases = AttentionBlock(
53
+ config.hidden_dim,
54
+ num_heads=config.num_heads,
55
+ dim_context=config.hidden_dim,
56
+ qknorm=config.qknorm,
57
+ qknorm_type=config.qknorm_type,
58
+ )
59
+
60
+ # self-attention encoder
61
+ self.encoder = nn.ModuleList(
62
+ [
63
+ AttentionBlock(
64
+ config.hidden_dim, config.num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
65
+ )
66
+ for _ in range(config.num_enc_layers)
67
+ ]
68
+ )
69
+
70
+ # vae bottleneck
71
+ self.norm_down = nn.LayerNorm(config.hidden_dim)
72
+ self.proj_down_mean = nn.Linear(config.hidden_dim, config.latent_dim)
73
+ if not self.config.use_ae:
74
+ self.proj_down_std = nn.Linear(config.hidden_dim, config.latent_dim)
75
+ self.proj_up = nn.Linear(config.latent_dim, config.dec_hidden_dim)
76
+
77
+ # self-attention decoder
78
+ self.decoder = nn.ModuleList(
79
+ [
80
+ AttentionBlock(
81
+ config.dec_hidden_dim, config.dec_num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
82
+ )
83
+ for _ in range(config.num_dec_layers)
84
+ ]
85
+ )
86
+
87
+ # cross-attention query
88
+ self.proj_query = nn.Linear(3 + config.point_fourier_dim, config.query_hidden_dim)
89
+ if self.config.use_flash_query:
90
+ self.norm_query_context = nn.LayerNorm(config.hidden_dim, eps=1e-6, elementwise_affine=False)
91
+ self.attn_query = FlashQueryLayer(
92
+ config.query_hidden_dim,
93
+ num_heads=config.query_num_heads,
94
+ dim_context=config.hidden_dim,
95
+ qknorm=config.qknorm,
96
+ qknorm_type=config.qknorm_type,
97
+ )
98
+ else:
99
+ self.attn_query = AttentionBlock(
100
+ config.query_hidden_dim,
101
+ num_heads=config.query_num_heads,
102
+ dim_context=config.hidden_dim,
103
+ qknorm=config.qknorm,
104
+ qknorm_type=config.qknorm_type,
105
+ )
106
+ self.norm_out = nn.LayerNorm(config.query_hidden_dim)
107
+ self.proj_out = nn.Linear(config.query_hidden_dim, 1)
108
+
109
+ # preload from a checkpoint (NOTE: this happens BEFORE checkpointer loading latest checkpoint!)
110
+ if self.config.pretrain_path is not None:
111
+ try:
112
+ ckpt = torch.load(self.config.pretrain_path) # local path
113
+ self.load_state_dict(ckpt["model"], strict=True)
114
+ del ckpt
115
+ print(f"Loaded VAE from {self.config.pretrain_path}")
116
+ except Exception as e:
117
+ print(
118
+ f"Failed to load VAE from {self.config.pretrain_path}: {e}, make sure you resumed from a valid checkpoint!"
119
+ )
120
+
121
+ # log
122
+ n_params = 0
123
+ for p in self.parameters():
124
+ n_params += p.numel()
125
+ print(f"Number of parameters in VAE: {n_params / 1e6:.2f}M")
126
+
127
+ # override to support tolerant loading (only load matched shape)
128
+ def load_state_dict(self, state_dict, strict=True, assign=False):
129
+ local_state_dict = self.state_dict()
130
+ seen_keys = {k: False for k in local_state_dict.keys()}
131
+ for k, v in state_dict.items():
132
+ if k in local_state_dict:
133
+ seen_keys[k] = True
134
+ if local_state_dict[k].shape == v.shape:
135
+ local_state_dict[k].copy_(v)
136
+ else:
137
+ print(f"mismatching shape for key {k}: checkpoint has {v.shape} but model expects {local_state_dict[k].shape}")
138
+ else:
139
+ print(f"unexpected key {k} in loaded state dict")
140
+ for k in seen_keys:
141
+ if not seen_keys[k]:
142
+ print(f"missing key {k} in loaded state dict")
143
+
144
+ def fourier_encoding(self, points: torch.Tensor):
145
+ # points: [B, N, 3], float32 for precision
146
+ # assert points.dtype == torch.float32, "Query points must be float32"
147
+
148
+ F = self.config.point_fourier_dim // (2 * points.shape[-1])
149
+
150
+ if self.config.fourier_version == "v1": # default
151
+ exponent = torch.arange(1, F + 1, device=points.device, dtype=torch.float32) / F # [F], range from 1/F to 1
152
+ freq_band = 512**exponent # [F], frequencies from 512**(1/F) up to 512 (scaled by pi below)
153
+ freq_band *= torch.pi
154
+ elif self.config.fourier_version == "v2":
155
+ exponent = torch.arange(F, device=points.device, dtype=torch.float32) / (F - 1) # [F], range from 0 to 1
156
+ freq_band = 1024**exponent # [F]
157
+ freq_band *= torch.pi
158
+ elif self.config.fourier_version == "v3": # hunyuan3d-2
159
+ freq_band = 2 ** torch.arange(F, device=points.device, dtype=torch.float32) # [F]
160
+
161
+ spectrum = points.unsqueeze(-1) * freq_band # [B,...,3,F]
162
+ sin, cos = spectrum.sin(), spectrum.cos() # [B,...,3,F]
163
+ input_enc = torch.stack([sin, cos], dim=-2) # [B,...,3,2,F]
164
+ input_enc = input_enc.view(*points.shape[:-1], -1) # [B,...,6F] = [B,...,dim]
165
+ return torch.cat([input_enc, points], dim=-1).to(dtype=self.precision) # [B,...,dim+input_dim]
166
+
167
+ def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
168
+ super().on_train_start(memory_format=memory_format)
169
+ self.to(dtype=self.precision, memory_format=memory_format) # use bfloat16 for training
170
+
171
+ def encode(self, data: dict[str, torch.Tensor]):
172
+ # uniform points
173
+ pointcloud = data["pointcloud"] # [B, N, 3]
174
+
175
+ # fourier embed and project
176
+ pointcloud = self.fourier_encoding(pointcloud) # [B, N, 3+C]
177
+ pointcloud = self.proj_input(pointcloud) # [B, N, hidden_dim]
178
+
179
+ # salient points
180
+ if self.config.use_salient_point:
181
+ pointcloud_dorases = data["pointcloud_dorases"] # [B, M, 3]
182
+
183
+ # fourier embed and project (shared weights)
184
+ pointcloud_dorases = self.fourier_encoding(pointcloud_dorases) # [B, M, 3+C]
185
+ pointcloud_dorases = self.proj_input(pointcloud_dorases) # [B, M, hidden_dim]
186
+
187
+ # gather fps point
188
+ fps_indices = data["fps_indices"] # [B, N']
189
+ pointcloud_query = torch.gather(pointcloud, 1, fps_indices.unsqueeze(-1).expand(-1, -1, pointcloud.shape[-1]))
190
+
191
+ if self.config.use_salient_point:
192
+ fps_indices_dorases = data["fps_indices_dorases"] # [B, M']
193
+
194
+ if fps_indices_dorases.shape[1] > 0:
195
+ pointcloud_query_dorases = torch.gather(
196
+ pointcloud_dorases,
197
+ 1,
198
+ fps_indices_dorases.unsqueeze(-1).expand(-1, -1, pointcloud_dorases.shape[-1]),
199
+ )
200
+
201
+ # combine both fps points as the query
202
+ pointcloud_query = torch.cat(
203
+ [pointcloud_query, pointcloud_query_dorases], dim=1
204
+ ) # [B, N'+M', hidden_dim]
205
+
206
+ # dual cross-attention
207
+ if self.config.salient_attn_mode == "dual_shared":
208
+ hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver(
209
+ pointcloud_query, pointcloud_dorases
210
+ ) # [B, N'+M', hidden_dim]
211
+ elif self.config.salient_attn_mode == "dual":
212
+ hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver_dorases(
213
+ pointcloud_query, pointcloud_dorases
214
+ )
215
+ else: # single, hunyuan3d-2 style
216
+ hidden_states = self.perceiver(pointcloud_query, torch.cat([pointcloud, pointcloud_dorases], dim=1))
217
+ else:
218
+ hidden_states = self.perceiver(pointcloud_query, pointcloud) # [B, N', hidden_dim]
219
+
220
+ # encoder
221
+ for block in self.encoder:
222
+ hidden_states = block(hidden_states)
223
+
224
+ # bottleneck
225
+ hidden_states = self.norm_down(hidden_states)
226
+ latent_mean = self.proj_down_mean(hidden_states).float()
227
+ if not self.config.use_ae:
228
+ latent_std = self.proj_down_std(hidden_states).float()
229
+ posterior = DiagonalGaussianDistribution(latent_mean, latent_std)
230
+ else:
231
+ posterior = DummyLatent(latent_mean)
232
+
233
+ return posterior
234
+
235
+ def decode(self, latent: torch.Tensor):
236
+ latent = latent.to(dtype=self.precision)
237
+ hidden_states = self.proj_up(latent)
238
+
239
+ for block in self.decoder:
240
+ hidden_states = block(hidden_states)
241
+
242
+ return hidden_states
243
+
244
+ def query(self, query_points: torch.Tensor, hidden_states: torch.Tensor):
245
+ # query_points: [B, N, 3], float32 to keep the precision
246
+
247
+ query_points = self.fourier_encoding(query_points) # [B, N, 3+C]
248
+ query_points = self.proj_query(query_points) # [B, N, hidden_dim]
249
+
250
+ # cross attention
251
+ query_output = self.attn_query(query_points, hidden_states) # [B, N, hidden_dim]
252
+
253
+ # output linear
254
+ query_output = self.norm_out(query_output)
255
+ pred = self.proj_out(query_output) # [B, N, 1]
256
+
257
+ return pred
258
+
259
+ def training_step(
260
+ self,
261
+ data: dict[str, torch.Tensor],
262
+ iteration: int,
263
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
264
+ output = {}
265
+
266
+ # cut off fps point during training for progressive flow
267
+ if self.training:
268
+ # randomly choose from a set of cutoff candidates
269
+ cutoff_index = np.random.choice(len(self.config.cutoff_fps_prob), p=self.config.cutoff_fps_prob)
270
+ cutoff_fps_point = self.config.cutoff_fps_point[cutoff_index]
271
+ cutoff_fps_salient_point = self.config.cutoff_fps_salient_point[cutoff_index]
272
+ # prefix of FPS points are still FPS points
273
+ data["fps_indices"] = data["fps_indices"][:, :cutoff_fps_point]
274
+ if self.config.use_salient_point:
275
+ data["fps_indices_dorases"] = data["fps_indices_dorases"][:, :cutoff_fps_salient_point]
276
+
277
+ loss = 0
278
+
279
+ # encode
280
+ posterior = self.encode(data)
281
+ latent_geom = posterior.sample() if self.training else posterior.mode()
282
+
283
+ # decode
284
+ hidden_states = self.decode(latent_geom)
285
+
286
+ # cross-attention query
287
+ query_points = data["query_points"] # [B, N, 3], float32
288
+
289
+ # the context norm can be moved out to avoid repeated computation
290
+ if self.config.use_flash_query:
291
+ hidden_states = self.norm_query_context(hidden_states)
292
+
293
+ pred = self.query(query_points, hidden_states).squeeze(-1).float() # [B, N]
294
+ gt = data["query_gt"].float() # [B, N], in [-1, 1]
295
+
296
+ # main loss
297
+ loss_mse = F.mse_loss(pred, gt, reduction="mean")
298
+ loss += loss_mse
299
+
300
+ loss_l1 = F.l1_loss(pred, gt, reduction="mean")
301
+ loss += loss_l1
302
+
303
+ # kl loss
304
+ loss_kl = posterior.kl().mean()
305
+ loss += self.config.kl_weight * loss_kl
306
+
307
+ # metrics
308
+ with torch.no_grad():
309
+ output["scalar"] = {} # for wandb logging
310
+ output["scalar"]["loss_mse"] = loss_mse.detach()
311
+ output["scalar"]["loss_l1"] = loss_l1.detach()
312
+ output["scalar"]["loss_kl"] = loss_kl.detach()
313
+ output["scalar"]["iou_fg"] = calculate_iou(pred, gt, target_value=1)
314
+ output["scalar"]["iou_bg"] = calculate_iou(pred, gt, target_value=0)
315
+ output["scalar"]["precision"], output["scalar"]["recall"], output["scalar"]["f1"] = calculate_metrics(
316
+ pred, gt, target_value=1
317
+ )
318
+
319
+ return output, loss
320
+
321
+ @torch.no_grad()
322
+ def validation_step(
323
+ self,
324
+ data: dict[str, torch.Tensor],
325
+ iteration: int,
326
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
327
+ return self.training_step(data, iteration)
328
+
329
+ @torch.inference_mode()
330
+ @sync_timer("vae forward")
331
+ def forward(
332
+ self,
333
+ data: dict[str, torch.Tensor],
334
+ mode: Literal["dense", "hierarchical"] = "hierarchical",
335
+ max_samples_per_iter: int = 512**2,
336
+ resolution: int = 512,
337
+ min_resolution: int = 64, # for hierarchical
338
+ ) -> dict[str, torch.Tensor]:
339
+ output = {}
340
+
341
+ # encode
342
+ if "latent" in data:
343
+ latent = data["latent"]
344
+ else:
345
+ posterior = self.encode(data)
346
+ output["posterior"] = posterior
347
+ latent = posterior.mode()
348
+
349
+ output["latent"] = latent
350
+ B = latent.shape[0]
351
+
352
+ # decode
353
+ hidden_states = self.decode(latent)
354
+ output["hidden_states"] = hidden_states # [B, N, hidden_dim] for the last cross-attention decoder
355
+
356
+ # the context norm can be moved out to avoid repeated computation
357
+ if self.config.use_flash_query:
358
+ hidden_states = self.norm_query_context(hidden_states)
359
+
360
+ # query
361
+ def chunked_query(grid_points):
362
+ if grid_points.shape[0] <= max_samples_per_iter:
363
+ return self.query(grid_points.unsqueeze(0), hidden_states).squeeze(-1) # [B, N]
364
+ all_pred = []
365
+ for i in range(0, grid_points.shape[0], max_samples_per_iter):
366
+ grid_chunk = grid_points[i : i + max_samples_per_iter]
367
+ pred_chunk = self.query(grid_chunk.unsqueeze(0), hidden_states)
368
+ all_pred.append(pred_chunk)
369
+ return torch.cat(all_pred, dim=1).squeeze(-1) # [B, N]
370
+
371
+ if mode == "dense":
372
+ grid_points = construct_grid_points(resolution).to(latent.device)
373
+ grid_points = grid_points.contiguous().view(-1, 3)
374
+ grid_vals = chunked_query(grid_points).float().view(B, resolution + 1, resolution + 1, resolution + 1)
375
+
376
+ elif mode == "hierarchical":
377
+ assert resolution >= min_resolution, "Resolution must be greater than or equal to min_resolution"
378
+ assert B == 1, "Only one batch is supported for hierarchical mode"
379
+
380
+ resolutions = []
381
+ res = resolution
382
+ while res >= min_resolution:
383
+ resolutions.append(res)
384
+ res = res // 2
385
+ resolutions.reverse() # e.g., [64, 128, 256, 512]
386
+
387
+ # dense-query the coarsest resolution
388
+ res = resolutions[0]
389
+ grid_points = construct_grid_points(res).to(latent.device)
390
+ grid_points = grid_points.contiguous().view(-1, 3)
391
+ grid_vals = chunked_query(grid_points).float().view(res + 1, res + 1, res + 1)
392
+
393
+ # sparse-query finer resolutions
394
+ dilate_kernel_3 = torch.ones(1, 1, 3, 3, 3, dtype=torch.float32, device=latent.device)
395
+ dilate_kernel_5 = torch.ones(1, 1, 5, 5, 5, dtype=torch.float32, device=latent.device)
396
+ for i in range(1, len(resolutions)):
397
+ res = resolutions[i]
398
+ # get the boundary grid mask in the coarser grid (where the grid_vals have different signs with at least one of its neighbors)
399
+ grid_signs = grid_vals >= 0
400
+ mask = torch.zeros_like(grid_signs)
401
+ mask[1:, :, :] += grid_signs[1:, :, :] != grid_signs[:-1, :, :]
402
+ mask[:-1, :, :] += grid_signs[:-1, :, :] != grid_signs[1:, :, :]
403
+ mask[:, 1:, :] += grid_signs[:, 1:, :] != grid_signs[:, :-1, :]
404
+ mask[:, :-1, :] += grid_signs[:, :-1, :] != grid_signs[:, 1:, :]
405
+ mask[:, :, 1:] += grid_signs[:, :, 1:] != grid_signs[:, :, :-1]
406
+ mask[:, :, :-1] += grid_signs[:, :, :-1] != grid_signs[:, :, 1:]
407
+ # empirical: also add those with abs(grid_vals) < 0.95
408
+ mask += grid_vals.abs() < 0.95
409
+ mask = (mask > 0).float()
410
+ # empirical: dilate the coarse mask
411
+ if res < 512:
412
+ mask = mask.unsqueeze(0).unsqueeze(0)
413
+ mask = F.conv3d(mask, weight=dilate_kernel_3, padding=1)
414
+ mask = mask.squeeze(0).squeeze(0)
415
+ # get the coarse coordinates
416
+ cidx_x, cidx_y, cidx_z = torch.nonzero(mask, as_tuple=True)
417
+ # fill to the fine indices
418
+ mask_fine = torch.zeros(res + 1, res + 1, res + 1, dtype=torch.float32, device=latent.device)
419
+ mask_fine[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
420
+ # empirical: dilate the fine mask
421
+ if res < 512:
422
+ mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
423
+ mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_3, padding=1)
424
+ mask_fine = mask_fine.squeeze(0).squeeze(0)
425
+ else:
426
+ mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
427
+ mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_5, padding=2)
428
+ mask_fine = mask_fine.squeeze(0).squeeze(0)
429
+ # get the fine coordinates
430
+ fidx_x, fidx_y, fidx_z = torch.nonzero(mask_fine, as_tuple=True)
431
+ # convert to float query points
432
+ query_points = torch.stack([fidx_x, fidx_y, fidx_z], dim=-1) # [N, 3]
433
+ query_points = query_points * 2 / res - 1 # [N, 3], in [-1, 1]
434
+ # query
435
+ pred = chunked_query(query_points).float()
436
+ # fill to the fine indices
437
+ grid_vals = torch.full((res + 1, res + 1, res + 1), -100.0, dtype=torch.float32, device=latent.device)
438
+ grid_vals[fidx_x, fidx_y, fidx_z] = pred
439
+ # print(f"[INFO] hierarchical: resolution: {res}, valid coarse points: {len(cidx_x)}, valid fine points: {len(fidx_x)}")
440
+
441
+ grid_vals = grid_vals.unsqueeze(0) # [1, res+1, res+1, res+1]
442
+ grid_vals[grid_vals <= -100.0] = float("nan") # use nans to ignore invalid regions
443
+
444
+ # extract mesh
445
+ meshes = []
446
+ for b in range(B):
447
+ vertices, faces = extract_mesh(grid_vals[b], resolution)
448
+ meshes.append((vertices, faces))
449
+ output["meshes"] = meshes
450
+
451
+ return output
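Note (not part of the commit): a minimal usage sketch of the VAE defined in vae/model.py above, assuming `model` and the `data` dict are prepared as in vae/scripts/infer.py further down.

```python
# Illustrative sketch only -- `model` and `data` are assumed to be built
# as in vae/scripts/infer.py; this is not code from the commit.
import torch

with torch.inference_mode():
    out = model(data, mode="hierarchical", resolution=512)  # hierarchical occupancy query up to 512^3
    latent = out["latent"]              # [B, tokens, latent_dim], posterior mode of the VAE
    vertices, faces = out["meshes"][0]  # extracted mesh, vertices in [-1, 1]^3
```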
vae/modules/__init__.py ADDED
File without changes
vae/modules/attention.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange
17
+
18
+ try:
19
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
20
+ from flash_attn.bert_padding import ( # , unpad_input # noqa
21
+ index_first_axis,
22
+ pad_input,
23
+ )
24
+
25
+ FLASH_ATTN_AVAILABLE = True
26
+ except Exception as e:
27
+ print("[WARN] flash_attn not available, using torch/naive implementation")
28
+ FLASH_ATTN_AVAILABLE = False
29
+
30
+
31
+ # Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py#L98
32
+ # flashattn 2.7.0 changes the API, we are overriding it here
33
+ def unpad_input(hidden_states, attention_mask):
34
+ """
35
+ Arguments:
36
+ hidden_states: (batch, seqlen, ...)
37
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
38
+ Return:
39
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
40
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
41
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
42
+ max_seqlen_in_batch: int
43
+ """
44
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
45
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
46
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
47
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
48
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
49
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
50
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
51
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
52
+ # so we write custom forward and backward to make it a bit faster.
53
+ return (
54
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
55
+ indices,
56
+ cu_seqlens,
57
+ max_seqlen_in_batch,
58
+ )
59
+
60
+
61
+ def attention(q, k, v, mask_q=None, mask_kv=None, dropout=0, causal=False, window_size=(-1, -1), backend="torch"):
62
+ # q: (B, N, H, D)
63
+ # k: (B, M, H, D)
64
+ # v: (B, M, H, D)
65
+ # mask_q: (B, N)
66
+ # mask_kv: (B, M)
67
+ # return: (B, N, H, D)
68
+
69
+ B, N, H, D = q.shape
70
+ M = k.shape[1]
71
+
72
+ if causal:
73
+ assert N == 1 or N == M, "Causal mask only supports self-attention"
74
+
75
+ # unmasked case (usually inference)
76
+ # window_size is only honored by the flash-attn backend (pass only the effective window); other backends ignore it
77
+ if mask_q is None and mask_kv is None:
78
+ if backend == "flash-attn" and FLASH_ATTN_AVAILABLE:
79
+ return flash_attn_func(q, k, v, dropout, causal=causal, window_size=window_size) # [B, N, H, D]
80
+ elif backend == "torch": # torch implementation
81
+ q = q.permute(0, 2, 1, 3)
82
+ k = k.permute(0, 2, 1, 3)
83
+ v = v.permute(0, 2, 1, 3)
84
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout, is_causal=causal)
85
+ out = out.permute(0, 2, 1, 3).contiguous()
86
+ return out
87
+ else: # naive implementation
88
+ q = q.transpose(1, 2).reshape(B * H, N, D)
89
+ k = k.transpose(1, 2).reshape(B * H, M, D)
90
+ v = v.transpose(1, 2).reshape(B * H, M, D)
91
+ w = torch.bmm(q, k.transpose(1, 2)) / (D**0.5) # [B*H, N, M]
92
+ if causal and N > 1:
93
+ causal_mask = torch.full((N, M), float("-inf"), device=w.device, dtype=w.dtype)
94
+ causal_mask = torch.triu(causal_mask, diagonal=1)
95
+ w = w + causal_mask.unsqueeze(0)
96
+ w = F.softmax(w, dim=-1)
97
+ if dropout > 0:
98
+ w = F.dropout(w, p=dropout)
99
+ out = torch.bmm(w, v) # [B*H, N, D]
100
+ out = out.reshape(B, H, N, D).transpose(1, 2).contiguous() # [B, N, H, D]
101
+ return out
102
+
103
+ # at least one of q or kv is masked (training)
104
+ # only support flash-attn for now...
105
+ if mask_q is None:
106
+ mask_q = torch.ones(B, N, dtype=torch.bool, device=q.device)
107
+ elif mask_kv is None:
108
+ mask_kv = torch.ones(B, M, dtype=torch.bool, device=q.device)
109
+
110
+ if FLASH_ATTN_AVAILABLE:
111
+ # unpad (gather) input
112
+ # mask_q: [B, N], first row has N1 1s, second row has N2 1s, ...
113
+ # indices: [Ns,], Ns = N1 + N2 + ...
114
+ # cu_seqlens_q: [B+1,], (0, N1, N1+N2, ...), cu=cumulative
115
+ # max_len_q: scalar, max(N1, N2, ...)
116
+ q, indices_q, cu_seqlens_q, max_len_q = unpad_input(q, mask_q)
117
+ k, indices_kv, cu_seqlens_kv, max_len_kv = unpad_input(k, mask_kv)
118
+ v = index_first_axis(v.reshape(-1, H, D), indices_kv) # same indices as k
119
+
120
+ # call varlen_func
121
+ out = flash_attn_varlen_func(
122
+ q,
123
+ k,
124
+ v,
125
+ cu_seqlens_q=cu_seqlens_q,
126
+ cu_seqlens_k=cu_seqlens_kv,
127
+ max_seqlen_q=max_len_q,
128
+ max_seqlen_k=max_len_kv,
129
+ dropout_p=dropout,
130
+ causal=causal,
131
+ window_size=window_size,
132
+ )
133
+
134
+ # pad (put back) output
135
+ out = pad_input(out, indices_q, B, N)
136
+ return out
137
+ else:
138
+ raise NotImplementedError("masked attention requires flash_attn!")
139
+
140
+
141
+ class RMSNorm(nn.Module):
142
+ def __init__(self, dim, eps=1e-6):
143
+ super().__init__()
144
+ self.weight = nn.Parameter(torch.ones(dim))
145
+ self.eps = eps
146
+
147
+ def forward(self, x):
148
+ rnorm = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
149
+ return (x * rnorm).to(dtype=self.weight.dtype) * self.weight
150
+
151
+
152
+ class SelfAttention(nn.Module):
153
+ def __init__(
154
+ self,
155
+ hidden_dim,
156
+ num_heads,
157
+ input_dim=None,
158
+ output_dim=None,
159
+ dropout=0,
160
+ causal=False,
161
+ qknorm=False,
162
+ qknorm_type="LayerNorm",
163
+ ):
164
+ super().__init__()
165
+ self.hidden_dim = hidden_dim
166
+ self.input_dim = input_dim if input_dim is not None else hidden_dim
167
+ self.output_dim = output_dim if output_dim is not None else hidden_dim
168
+ self.num_heads = num_heads
169
+ assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
170
+ self.head_dim = hidden_dim // num_heads
171
+ self.causal = causal
172
+ self.dropout = dropout
173
+ self.qknorm = qknorm
174
+
175
+ self.qkv_proj = nn.Linear(self.input_dim, 3 * self.hidden_dim)
176
+ self.out_proj = nn.Linear(self.hidden_dim, self.output_dim)
177
+
178
+ if self.qknorm:
179
+ if qknorm_type == "RMSNorm":
180
+ self.q_norm = RMSNorm(self.hidden_dim, eps=1e-6)
181
+ self.k_norm = RMSNorm(self.hidden_dim, eps=1e-6)
182
+ else:
183
+ self.q_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
184
+ self.k_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
185
+
186
+ def forward(self, x, mask=None):
187
+ # x: [B, N, C]
188
+ # mask: [B, N]
189
+ B, N, C = x.shape
190
+ qkv = self.qkv_proj(x) # [B, N, C] -> [B, N, 3 * D]
191
+ qkv = qkv.reshape(B, N, 3, -1).permute(2, 0, 1, 3) # [3, B, N, D]
192
+ q, k, v = qkv.chunk(3, dim=0) # [3, B, N, D] -> 3 * [1, B, N, D]
193
+ q = q.squeeze(0)
194
+ k = k.squeeze(0)
195
+ v = v.squeeze(0)
196
+ if self.qknorm:
197
+ q = self.q_norm(q)
198
+ k = self.k_norm(k)
199
+ q = q.reshape(B, N, self.num_heads, self.head_dim)
200
+ k = k.reshape(B, N, self.num_heads, self.head_dim)
201
+ v = v.reshape(B, N, self.num_heads, self.head_dim)
202
+ x = attention(q, k, v, mask_q=mask, mask_kv=mask, dropout=self.dropout, causal=self.causal) # [B, N, H, D]
203
+ x = self.out_proj(x.reshape(B, N, -1))
204
+ return x
205
+
206
+
207
+ class CrossAttention(nn.Module):
208
+ def __init__(
209
+ self,
210
+ hidden_dim,
211
+ num_heads,
212
+ input_dim=None,
213
+ context_dim=None,
214
+ output_dim=None,
215
+ dropout=0,
216
+ qknorm=False,
217
+ qknorm_type="LayerNorm",
218
+ ):
219
+ super().__init__()
220
+ self.hidden_dim = hidden_dim
221
+ self.input_dim = input_dim if input_dim is not None else hidden_dim
222
+ self.context_dim = context_dim if context_dim is not None else hidden_dim
223
+ self.output_dim = output_dim if output_dim is not None else hidden_dim
224
+ self.num_heads = num_heads
225
+ assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
226
+ self.head_dim = hidden_dim // num_heads
227
+ self.dropout = dropout
228
+ self.qknorm = qknorm
229
+
230
+ self.q_proj = nn.Linear(self.input_dim, self.hidden_dim)
231
+ self.k_proj = nn.Linear(self.context_dim, self.hidden_dim)
232
+ self.v_proj = nn.Linear(self.context_dim, self.hidden_dim)
233
+ self.out_proj = nn.Linear(self.hidden_dim, self.output_dim)
234
+
235
+ if self.qknorm:
236
+ if qknorm_type == "RMSNorm":
237
+ self.q_norm = RMSNorm(self.hidden_dim, eps=1e-6)
238
+ self.k_norm = RMSNorm(self.hidden_dim, eps=1e-6)
239
+ else:
240
+ self.q_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
241
+ self.k_norm = nn.LayerNorm(self.hidden_dim, eps=1e-6, elementwise_affine=False)
242
+
243
+ def forward(self, x, context, mask_q=None, mask_kv=None):
244
+ # x: [B, N, C]
245
+ # context: [B, M, C']
246
+ # mask_q: [B, N]
247
+ # mask_kv: [B, M]
248
+ B, N, C = x.shape
249
+ M = context.shape[1]
250
+ q = self.q_proj(x)
251
+ k = self.k_proj(context)
252
+ v = self.v_proj(context)
253
+ if self.qknorm:
254
+ q = self.q_norm(q)
255
+ k = self.k_norm(k)
256
+ q = q.reshape(B, N, self.num_heads, self.head_dim)
257
+ k = k.reshape(B, M, self.num_heads, self.head_dim)
258
+ v = v.reshape(B, M, self.num_heads, self.head_dim)
259
+ x = attention(q, k, v, mask_q=mask_q, mask_kv=mask_kv, dropout=self.dropout, causal=False) # [B, N, H, D]
260
+ x = self.out_proj(x.reshape(B, N, -1))
261
+ return x
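Note (not part of the commit): a small shape sanity-check for the attention primitives above; with no masks and the default `backend="torch"`, `attention()` falls back to `F.scaled_dot_product_attention`, so flash_attn is not required.

```python
# Hypothetical usage sketch; module paths follow this commit's layout.
import torch
from vae.modules.attention import CrossAttention, SelfAttention

self_attn = SelfAttention(hidden_dim=64, num_heads=4, qknorm=True, qknorm_type="RMSNorm")
cross_attn = CrossAttention(hidden_dim=64, num_heads=4, context_dim=32)

x = torch.randn(2, 10, 64)    # [B, N, C] queries / hidden states
ctx = torch.randn(2, 20, 32)  # [B, M, C'] context tokens

y = self_attn(x)              # [2, 10, 64]
z = cross_attn(x, ctx)        # [2, 10, 64]
```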
vae/modules/transformer.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ import torch.nn as nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from vae.modules.attention import CrossAttention, SelfAttention
17
+
18
+
19
+ class FeedForward(nn.Module):
20
+ def __init__(self, dim, mult=4):
21
+ super().__init__()
22
+ self.net = nn.Sequential(nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim))
23
+
24
+ def forward(self, x):
25
+ return self.net(x)
26
+
27
+
28
+ class AttentionBlock(nn.Module):
29
+ def __init__(
30
+ self,
31
+ dim,
32
+ num_heads,
33
+ dim_context=None,
34
+ qknorm=False,
35
+ gradient_checkpointing=True,
36
+ qknorm_type="LayerNorm",
37
+ ):
38
+ super().__init__()
39
+ self.dim = dim
40
+ self.num_heads = num_heads
41
+ self.dim_context = dim_context
42
+ self.gradient_checkpointing = gradient_checkpointing
43
+
44
+ self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
45
+ if dim_context is not None:
46
+ self.norm_context = nn.LayerNorm(dim_context, eps=1e-6, elementwise_affine=False)
47
+ self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
48
+ else:
49
+ self.attn = SelfAttention(dim, num_heads, qknorm=qknorm, qknorm_type=qknorm_type)
50
+
51
+ self.norm_ff = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
52
+ self.ff = FeedForward(dim)
53
+
54
+ def forward(self, x, c=None, mask=None, mask_c=None):
55
+ if self.training and self.gradient_checkpointing:
56
+ return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
57
+ else:
58
+ return self._forward(x, c, mask, mask_c)
59
+
60
+ def _forward(self, x, c=None, mask=None, mask_c=None):
61
+ # x: [B, N, C], hidden states
62
+ # c: [B, M, C'], condition (assume normed and projected to C)
63
+ # mask: [B, N], mask for x
64
+ # mask_c: [B, M], mask for c
65
+ # return: [B, N, C], updated hidden states
66
+
67
+ if c is not None:
68
+ x = x + self.attn(self.norm_attn(x), self.norm_context(c), mask_q=mask, mask_kv=mask_c)
69
+ else:
70
+ x = x + self.attn(self.norm_attn(x), mask=mask)
71
+
72
+ x = x + self.ff(self.norm_ff(x))
73
+
74
+ return x
75
+
76
+
77
+ # special attention block for the last cross-attn query layer
78
+ # 1. simple feed-forward (mult=1, no post ln)
79
+ # 2. no residual connection
80
+ # 3. no context ln
81
+ class FlashQueryLayer(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim,
85
+ num_heads,
86
+ dim_context,
87
+ qknorm=False,
88
+ gradient_checkpointing=True,
89
+ qknorm_type="LayerNorm",
90
+ ):
91
+ super().__init__()
92
+ self.dim = dim
93
+ self.num_heads = num_heads
94
+ self.dim_context = dim_context
95
+ self.gradient_checkpointing = gradient_checkpointing
96
+
97
+ self.norm_attn = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
98
+ self.attn = CrossAttention(dim, num_heads, context_dim=dim_context, qknorm=qknorm, qknorm_type=qknorm_type)
99
+ self.ff = FeedForward(dim, mult=1)
100
+
101
+ def forward(self, x, c=None, mask=None, mask_c=None):
102
+ if self.training and self.gradient_checkpointing:
103
+ return checkpoint(self._forward, x, c, mask, mask_c, use_reentrant=False)
104
+ else:
105
+ return self._forward(x, c, mask, mask_c)
106
+
107
+ def _forward(self, x, c, mask=None, mask_c=None):
108
+ # x: [B, N, C], hidden states
109
+ # c: [B, M, C'], condition (assume normed and projected to C)
110
+ # mask: [B, N], mask for x
111
+ # mask_c: [B, M], mask for c
112
+ # return: [B, N, C], updated hidden states
113
+
114
+ x = self.attn(self.norm_attn(x), c, mask_q=mask, mask_kv=mask_c)
115
+ x = self.ff(x)
116
+
117
+ return x
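Note (not part of the commit): `AttentionBlock` is a pre-norm attention + feed-forward block with residuals that switches to cross-attention when `dim_context` is given, while `FlashQueryLayer` drops the residuals and context norm for the final occupancy query. A quick illustrative shape check:

```python
# Hypothetical shape check, assuming the module paths from this commit.
import torch
from vae.modules.transformer import AttentionBlock, FlashQueryLayer

self_block = AttentionBlock(64, num_heads=4).eval()                    # self-attention block
cross_block = AttentionBlock(64, num_heads=4, dim_context=32).eval()   # cross-attention block
query_layer = FlashQueryLayer(64, num_heads=4, dim_context=32).eval()  # no residual, no context norm

x = torch.randn(1, 16, 64)
c = torch.randn(1, 8, 32)

print(self_block(x).shape)      # torch.Size([1, 16, 64])
print(cross_block(x, c).shape)  # torch.Size([1, 16, 64])
print(query_layer(x, c).shape)  # torch.Size([1, 16, 64])
```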
vae/scripts/infer.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ import argparse
14
+ import glob
15
+ import importlib
16
+ import os
17
+ from datetime import datetime
18
+
19
+ import fpsample
20
+ import kiui
21
+ import meshiki
22
+ import numpy as np
23
+ import torch
24
+ import trimesh
25
+
26
+ from vae.model import Model
27
+ from vae.utils import box_normalize, postprocess_mesh, sphere_normalize, sync_timer
28
+
29
+ # PYTHONPATH=. python vae/scripts/infer.py
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("--config", type=str, help="config file path", default="vae.configs.part_woenc")
32
+ parser.add_argument(
33
+ "--ckpt_path",
34
+ type=str,
35
+ help="checkpoint path",
36
+ default="pretrained/vae.pt",
37
+ )
38
+ parser.add_argument("--input", type=str, help="input directory", default="assets/meshes/")
39
+ parser.add_argument("--output_dir", type=str, help="output directory", default="output/")
40
+ parser.add_argument("--limit", type=int, help="how many samples to test", default=-1)
41
+ parser.add_argument("--num_fps_point", type=int, help="number of fps points", default=1024)
42
+ parser.add_argument("--num_fps_salient_point", type=int, help="number of fps salient points", default=1024)
43
+ parser.add_argument("--grid_res", type=int, help="grid resolution", default=512)
44
+ parser.add_argument("--seed", type=int, help="seed", default=42)
45
+ args = parser.parse_args()
46
+
47
+
48
+ TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)
49
+
50
+ kiui.seed_everything(args.seed)
51
+
52
+
53
+ @sync_timer("prepare_input_from_mesh")
54
+ def prepare_input_from_mesh(mesh_path, use_salient_point=True, num_fps_point=1024, num_fps_salient_point=1024):
55
+ # load mesh, assume it's already processed to be watertight.
56
+
57
+ mesh_name = mesh_path.split("/")[-1].split(".")[0]
58
+ vertices, faces = meshiki.load_mesh(mesh_path)
59
+
60
+ # vertices = sphere_normalize(vertices)
61
+ vertices = box_normalize(vertices)
62
+
63
+ mesh = meshiki.Mesh(vertices, faces)
64
+
65
+ uniform_surface_points = mesh.uniform_point_sample(200000)
66
+ uniform_surface_points = meshiki.fps(uniform_surface_points, 32768) # hardcoded...
67
+ salient_surface_points = mesh.salient_point_sample(16384, thresh_bihedral=15)
68
+
69
+ # save points
70
+ # trimesh.PointCloud(vertices=uniform_surface_points).export(os.path.join(workspace, mesh_name + "_uniform.ply"))
71
+ # trimesh.PointCloud(vertices=salient_surface_points).export(os.path.join(workspace, mesh_name + "_salient.ply"))
72
+
73
+ sample = {}
74
+
75
+ sample["pointcloud"] = torch.from_numpy(uniform_surface_points)
76
+
77
+ # fps subsample
78
+ fps_indices = fpsample.bucket_fps_kdline_sampling(uniform_surface_points, num_fps_point, h=5, start_idx=0)
79
+ sample["fps_indices"] = torch.from_numpy(fps_indices).long() # [num_fps_point,]
80
+
81
+ if use_salient_point:
82
+ sample["pointcloud_dorases"] = torch.from_numpy(salient_surface_points) # [N', 3]
83
+
84
+ # fps subsample
85
+ fps_indices_dorases = fpsample.bucket_fps_kdline_sampling(
86
+ salient_surface_points, num_fps_salient_point, h=5, start_idx=0
87
+ )
88
+ sample["fps_indices_dorases"] = torch.from_numpy(fps_indices_dorases).long() # [num_fps_point,]
89
+
90
+ return sample
91
+
92
+
93
+ print(f"Loading checkpoint from {args.ckpt_path}")
94
+ ckpt_dict = torch.load(args.ckpt_path, weights_only=True)
95
+
96
+ # delete all keys other than model
97
+ if "model" in ckpt_dict:
98
+ ckpt_dict = ckpt_dict["model"]
99
+
100
+ # instantiate model
101
+ print(f"Instantiating model from {args.config}")
102
+ model_config = importlib.import_module(args.config).make_config()
103
+ model = Model(model_config).eval().cuda().bfloat16()
104
+
105
+ # load weight
106
+ print(f"Loading weights from {args.ckpt_path}")
107
+ model.load_state_dict(ckpt_dict, strict=True)
108
+
109
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
110
+ workspace = os.path.join(args.output_dir, "vae_" + args.config.split(".")[-1] + "_" + timestamp)
111
+ if not os.path.exists(workspace):
112
+ os.makedirs(workspace)
113
+ else:
114
+ os.system(f"rm {workspace}/*")
115
+ print(f"Output directory: {workspace}")
116
+
117
+ # load dataset
118
+ mesh_list = glob.glob(os.path.join(args.input, "*"))
119
+ mesh_list = mesh_list[: args.limit] if args.limit > 0 else mesh_list
120
+
121
+ for i, mesh_path in enumerate(mesh_list):
122
+ print(f"Processing {i}/{len(mesh_list)}: {mesh_path}")
123
+
124
+ mesh_name = mesh_path.split("/")[-1].split(".")[0]
125
+
126
+ sample = prepare_input_from_mesh(
127
+ mesh_path, num_fps_point=args.num_fps_point, num_fps_salient_point=args.num_fps_salient_point
128
+ )
129
+ for k in sample:
130
+ sample[k] = sample[k].unsqueeze(0).cuda()
131
+
132
+ # call vae
133
+ with torch.inference_mode():
134
+ output = model(sample, resolution=args.grid_res)
135
+
136
+ latent = output["latent"]
137
+ vertices, faces = output["meshes"][0]
138
+
139
+ mesh = trimesh.Trimesh(vertices, faces)
140
+ mesh = postprocess_mesh(mesh, 5e5)
141
+
142
+ mesh.export(f"{workspace}/{mesh_name}.glb")
vae/utils.py ADDED
@@ -0,0 +1,315 @@
1
+ """
2
+ -----------------------------------------------------------------------------
3
+ Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+
5
+ NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ and proprietary rights in and to this software, related documentation
7
+ and any modifications thereto. Any use, reproduction, disclosure or
8
+ distribution of this software and related documentation without an express
9
+ license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+ -----------------------------------------------------------------------------
11
+ """
12
+
13
+ import os
14
+ from functools import wraps
15
+ from typing import Literal
16
+
17
+ import numpy as np
18
+ import torch
19
+ import trimesh
20
+ from kiui.mesh_utils import clean_mesh, decimate_mesh
21
+
22
+
23
+ # Adapted from https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/utils.py#L38
24
+ class sync_timer:
25
+ """
26
+ Synchronized timer to count the inference time of `nn.Module.forward` or else.
27
+ set env var TIMER=1 to enable logging!
28
+
29
+ Example as context manager:
30
+ ```python
31
+ with timer('name'):
32
+ run()
33
+ ```
34
+
35
+ Example as decorator:
36
+ ```python
37
+ @timer('name')
38
+ def run():
39
+ pass
40
+ ```
41
+ """
42
+
43
+ def __init__(self, name=None, flag_env="TIMER"):
44
+ self.name = name
45
+ self.flag_env = flag_env
46
+
47
+ def __enter__(self):
48
+ if os.environ.get(self.flag_env, "0") == "1":
49
+ self.start = torch.cuda.Event(enable_timing=True)
50
+ self.end = torch.cuda.Event(enable_timing=True)
51
+ self.start.record()
52
+ return lambda: self.time
53
+
54
+ def __exit__(self, exc_type, exc_value, exc_tb):
55
+ if os.environ.get(self.flag_env, "0") == "1":
56
+ self.end.record()
57
+ torch.cuda.synchronize()
58
+ self.time = self.start.elapsed_time(self.end)
59
+ if self.name is not None:
60
+ print(f"{self.name} takes {self.time} ms")
61
+
62
+ def __call__(self, func):
63
+ @wraps(func)
64
+ def wrapper(*args, **kwargs):
65
+ with self:
66
+ result = func(*args, **kwargs)
67
+ return result
68
+
69
+ return wrapper
70
+
71
+
72
+ @torch.no_grad()
73
+ def calculate_iou(pred: torch.Tensor, gt: torch.Tensor, target_value: int, thresh: float = 0) -> torch.Tensor:
74
+ """Calculate the Intersection over Union (IoU) between two volumes.
75
+
76
+ Args:
77
+ pred (torch.Tensor): [*] continuous value between 0 and 1
78
+ gt (torch.Tensor): [*] discrete value of 0 or 1
79
+ target_value (int): The value to be considered as the target class
80
+
81
+ Returns:
82
+ torch.Tensor: IoU value
83
+ """
84
+ # Ensure volumes have the same shape
85
+ assert pred.shape == gt.shape, "Volumes must have the same shape"
86
+
87
+ # binarize
88
+ pred_binary = pred > thresh
89
+ gt = gt > thresh
90
+
91
+ # Convert the volumes to boolean tensors for logical operations
92
+ intersection = torch.logical_and(pred_binary == target_value, gt == target_value).sum().float()
93
+ union = torch.logical_or(pred_binary == target_value, gt == target_value).sum().float()
94
+
95
+ # Compute IoU
96
+ iou = intersection / union if union != 0 else torch.tensor(0.0)
97
+ return iou
98
+
99
+
100
+ @torch.no_grad()
101
+ def calculate_metrics(
102
+ pred: torch.Tensor, gt: torch.Tensor, target_value: int = 1, thresh: float = 0.5
103
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
104
+ """Calculate Precision, Recall, and F1 between two volumes.
105
+
106
+ Args:
107
+ pred (torch.Tensor): [*] continuous value between 0 and 1
108
+ gt (torch.Tensor): [*] discrete value of 0 or 1
109
+ target_value (int): The value to be considered as the target class
110
+
111
+ Returns:
112
+ tuple: Precision, Recall, F1 values
113
+ """
114
+ assert pred.shape == gt.shape, f"Pred {pred.shape} and gt {gt.shape} must have the same shape"
115
+
116
+ # Binarize prediction
117
+ pred_binary = pred > thresh
118
+ gt = gt > thresh
119
+
120
+ # True Positive (TP): pred == target_value and gt == target_value
121
+ true_positive = torch.logical_and(pred_binary == target_value, gt == target_value).sum().float()
122
+
123
+ # False Positive (FP): pred == target_value and gt != target_value
124
+ false_positive = torch.logical_and(pred_binary == target_value, gt != target_value).sum().float()
125
+
126
+ # False Negative (FN): pred != target_value and gt == target_value
127
+ false_negative = torch.logical_and(pred_binary != target_value, gt == target_value).sum().float()
128
+
129
+ # Precision: TP / (TP + FP), best to detect False Positives
130
+ precision = (
131
+ true_positive / (true_positive + false_positive) if (true_positive + false_positive) != 0 else torch.tensor(0.0)
132
+ )
133
+
134
+ # Recall: TP / (TP + FN), best to detect False Negatives
135
+ recall = (
136
+ true_positive / (true_positive + false_negative) if (true_positive + false_negative) != 0 else torch.tensor(0.0)
137
+ )
138
+
139
+ # f1: 2 / (1 / precision + 1 / recall)
140
+ f1 = 2 / (1 / precision + 1 / recall) if (precision != 0 and recall != 0) else torch.tensor(0.0)
141
+
142
+ return precision, recall, f1
143
+
144
+
145
+ # Adapted from https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/distributions/distributions.py#L24
146
+ class DiagonalGaussianDistribution:
147
+ """VAE latent"""
148
+
149
+ def __init__(self, mean, logvar, deterministic=False):
150
+ # mean, logvar: [B, L, D] x 2
151
+ self.mean, self.logvar = mean, logvar
152
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
153
+ self.deterministic = deterministic
154
+ self.std = torch.exp(0.5 * self.logvar)
155
+ self.var = torch.exp(self.logvar)
156
+ if self.deterministic:
157
+ self.var = self.std = torch.zeros_like(self.mean, device=self.mean.device, dtype=self.mean.dtype)
158
+
159
+ def sample(self, weight: float = 1.0):
160
+ sample = weight * torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype)
161
+ x = self.mean + self.std * sample
162
+ return x
163
+
164
+ def kl(self, other=None, dims=[1, 2]):
165
+ if self.deterministic:
166
+ return torch.Tensor([0.0])
167
+ else:
168
+ if other is None:
169
+ return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)
170
+ else:
171
+ return 0.5 * torch.mean(
172
+ torch.pow(self.mean - other.mean, 2) / other.var
173
+ + self.var / other.var
174
+ - 1.0
175
+ - self.logvar
176
+ + other.logvar,
177
+ dim=dims,
178
+ )
179
+
180
+ def nll(self, sample, dims=[1, 2]):
181
+ if self.deterministic:
182
+ return torch.Tensor([0.0])
183
+ logtwopi = np.log(2.0 * np.pi)
184
+ return 0.5 * torch.mean(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
185
+
186
+ def mode(self):
187
+ return self.mean
188
+
189
+
190
+ class DummyLatent:
191
+ def __init__(self, mean):
192
+ self.mean = mean
193
+
194
+ def sample(self, weight=0):
195
+ # simply perturb the mean
196
+ if weight > 0:
197
+ noise = torch.randn_like(self.mean) * weight
198
+ else:
199
+ noise = 0
200
+ return self.mean + noise
201
+
202
+ def mode(self):
203
+ return self.mean
204
+
205
+ def kl(self):
206
+ # just an l2 penalty
207
+ return 0.5 * torch.mean(torch.pow(self.mean, 2))
208
+
209
+
210
+ def construct_grid_points(
211
+ resolution: int,
212
+ indexing: str = "ij",
213
+ ):
214
+ """Generate dense grid points in [-1, 1]^3.
215
+
216
+ Args:
217
+ resolution (int): resolution of the grid
218
+ indexing (str, optional): indexing of the grid. Defaults to "ij".
219
+
220
+ Returns:
221
+ torch.Tensor: grid points (resolution + 1, resolution + 1, resolution + 1, 3), inside bbox.
222
+ """
223
+ x = np.linspace(-1, 1, resolution + 1, dtype=np.float32)
224
+ y = np.linspace(-1, 1, resolution + 1, dtype=np.float32)
225
+ z = np.linspace(-1, 1, resolution + 1, dtype=np.float32)
226
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
227
+ xyzs = np.stack((xs, ys, zs), axis=-1)
228
+ xyzs = torch.from_numpy(xyzs).float()
229
+ return xyzs
230
+
231
+
232
+ _diso_session = None # lazy session for reuse
233
+
234
+
235
+ @sync_timer("extract_mesh")
236
+ def extract_mesh(
237
+ grid_vals: torch.Tensor,
238
+ resolution: int,
239
+ isosurface_level: float = 0,
240
+ backend: Literal["mcubes", "diso"] = "mcubes",
241
+ ):
242
+ """Extract mesh from grid occupancy.
243
+
244
+ Args:
245
+ grid_vals (torch.Tensor): [resolution + 1, resolution + 1, resolution + 1], assume to be TSDF in [-1, 1] (inner is positive)
246
+ resolution (int, optional): Grid resolution.
247
+ isosurface_level (float, optional): Iso-surface level. Defaults to 0.
248
+ backend (Literal["mcubes", "diso"], optional): Backend for mesh extraction. Defaults to "diso", which uses GPU and is faster.
249
+ Returns:
250
+ vertices (np.ndarray): [N, 3], float32, in [-1, 1]
251
+ faces (np.ndarray): [M, 3], int32
252
+ """
253
+
254
+ grid_vals = grid_vals.view(resolution + 1, resolution + 1, resolution + 1)
255
+
256
+ if backend == "mcubes":
257
+ try:
258
+ import mcubes
259
+ except ImportError:
260
+ os.system("pip install pymcubes")
261
+ import mcubes
262
+ grid_vals = grid_vals.float().cpu().numpy()
263
+ verts, faces = mcubes.marching_cubes(grid_vals, isosurface_level)
264
+ verts = 2 * verts / resolution - 1.0 # normalize to [-1, 1]
265
+ elif backend == "diso":
266
+ try:
267
+ import diso
268
+ except ImportError:
269
+ os.system("pip install diso")
270
+ import diso
271
+ global _diso_session
272
+ if _diso_session is None:
273
+ _diso_session = diso.DiffDMC(dtype=torch.float32).cuda()
274
+
275
+ grid_vals = -grid_vals.float().cuda() # diso assumes inner is NEGATIVE!
276
+ verts, faces = _diso_session(grid_vals, deform=None, normalize=True) # verts in [0, 1]
277
+ verts = verts.cpu().numpy() * 2 - 1.0 # normalize to [-1, 1]
278
+ faces = faces.cpu().numpy()
279
+
280
+ return verts, faces
281
+
282
+
283
+ @sync_timer("postprocess_mesh")
284
+ def postprocess_mesh(mesh: trimesh.Trimesh, decimate_target=100000):
285
+ vertices = mesh.vertices
286
+ triangles = mesh.faces
287
+
288
+ if vertices.shape[0] > 0 and triangles.shape[0] > 0:
289
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)
290
+ if triangles.shape[0] > decimate_target:
291
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
292
+ if vertices.shape[0] > 0 and triangles.shape[0] > 0:
293
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False, min_f=25, min_d=5)
294
+
295
+ mesh.vertices = vertices
296
+ mesh.faces = triangles
297
+
298
+ return mesh
299
+
300
+
301
+ def sphere_normalize(vertices):
302
+ bmin = vertices.min(axis=0)
303
+ bmax = vertices.max(axis=0)
304
+ bcenter = (bmax + bmin) / 2
305
+ radius = np.linalg.norm(vertices - bcenter, axis=-1).max()
306
+ vertices = (vertices - bcenter) / radius # to [-1, 1]
307
+ return vertices
308
+
309
+
310
+ def box_normalize(vertices, bound=0.95):
311
+ bmin = vertices.min(axis=0)
312
+ bmax = vertices.max(axis=0)
313
+ bcenter = (bmax + bmin) / 2
314
+ vertices = bound * (vertices - bcenter) / (bmax - bmin).max()
315
+ return vertices
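Note (not part of the commit): an illustrative sketch of the latent and grid helpers defined in vae/utils.py above.

```python
# Hypothetical example exercising the helpers in vae/utils.py.
import numpy as np
import torch
from vae.utils import DiagonalGaussianDistribution, box_normalize, construct_grid_points

mean = torch.zeros(1, 1024, 8)     # [B, L, D]
logvar = torch.zeros(1, 1024, 8)
posterior = DiagonalGaussianDistribution(mean, logvar)
z = posterior.sample()             # mean + std * noise, same shape as mean
kl = posterior.kl()                # [B], KL averaged over token and channel dims

grid = construct_grid_points(64)   # [65, 65, 65, 3] grid covering [-1, 1]^3

verts = np.random.rand(100, 3) * 10.0
verts = box_normalize(verts)       # recentered; longest bounding-box edge rescaled to 0.95
```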