""" | |
----------------------------------------------------------------------------- | |
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
NVIDIA CORPORATION and its licensors retain all intellectual property | |
and proprietary rights in and to this software, related documentation | |
and any modifications thereto. Any use, reproduction, disclosure or | |
distribution of this software and related documentation without an express | |
license agreement from NVIDIA CORPORATION is strictly prohibited. | |
----------------------------------------------------------------------------- | |
""" | |
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vae.configs.schema import ModelConfig
from vae.modules.transformer import AttentionBlock, FlashQueryLayer
from vae.utils import (
    DiagonalGaussianDistribution,
    DummyLatent,
    calculate_iou,
    calculate_metrics,
    construct_grid_points,
    extract_mesh,
    sync_timer,
)


class Model(nn.Module):
    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        self.precision = torch.bfloat16  # manually handle low-precision training, always use bf16

        # point encoder
        self.proj_input = nn.Linear(3 + config.point_fourier_dim, config.hidden_dim)
        self.perceiver = AttentionBlock(
            config.hidden_dim,
            num_heads=config.num_heads,
            dim_context=config.hidden_dim,
            qknorm=config.qknorm,
            qknorm_type=config.qknorm_type,
        )
        if self.config.salient_attn_mode == "dual":
            self.perceiver_dorases = AttentionBlock(
                config.hidden_dim,
                num_heads=config.num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )
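
        # salient_attn_mode controls how the salient point set is fused in encode():
        #   "single"      - one perceiver attends over the concatenated point sets (hunyuan3d-2 style)
        #   "dual_shared" - the same perceiver attends to each set separately and the outputs are summed
        #   "dual"        - a separate perceiver_dorases attends to the salient set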

        # self-attention encoder
        self.encoder = nn.ModuleList(
            [
                AttentionBlock(
                    config.hidden_dim, config.num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
                )
                for _ in range(config.num_enc_layers)
            ]
        )

        # vae bottleneck
        self.norm_down = nn.LayerNorm(config.hidden_dim)
        self.proj_down_mean = nn.Linear(config.hidden_dim, config.latent_dim)
        if not self.config.use_ae:
            self.proj_down_std = nn.Linear(config.hidden_dim, config.latent_dim)
        self.proj_up = nn.Linear(config.latent_dim, config.dec_hidden_dim)

        # self-attention decoder
        self.decoder = nn.ModuleList(
            [
                AttentionBlock(
                    config.dec_hidden_dim, config.dec_num_heads, qknorm=config.qknorm, qknorm_type=config.qknorm_type
                )
                for _ in range(config.num_dec_layers)
            ]
        )

        # cross-attention query
        self.proj_query = nn.Linear(3 + config.point_fourier_dim, config.query_hidden_dim)
        if self.config.use_flash_query:
            self.norm_query_context = nn.LayerNorm(config.hidden_dim, eps=1e-6, elementwise_affine=False)
            self.attn_query = FlashQueryLayer(
                config.query_hidden_dim,
                num_heads=config.query_num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )
        else:
            self.attn_query = AttentionBlock(
                config.query_hidden_dim,
                num_heads=config.query_num_heads,
                dim_context=config.hidden_dim,
                qknorm=config.qknorm,
                qknorm_type=config.qknorm_type,
            )
        self.norm_out = nn.LayerNorm(config.query_hidden_dim)
        self.proj_out = nn.Linear(config.query_hidden_dim, 1)
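
        # the query head maps each query point's feature to a single scalar occupancy/SDF-style
        # value (training targets query_gt lie in [-1, 1], see training_step)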

        # preload from a checkpoint (NOTE: this happens BEFORE the checkpointer loads the latest checkpoint!)
        if self.config.pretrain_path is not None:
            try:
                ckpt = torch.load(self.config.pretrain_path)  # local path
                self.load_state_dict(ckpt["model"], strict=True)
                del ckpt
                print(f"Loaded VAE from {self.config.pretrain_path}")
            except Exception as e:
                print(
                    f"Failed to load VAE from {self.config.pretrain_path}: {e}, make sure you resumed from a valid checkpoint!"
                )

        # log
        n_params = sum(p.numel() for p in self.parameters())
        print(f"Number of parameters in VAE: {n_params / 1e6:.2f}M")

    # override to support tolerant loading (only load tensors whose shapes match)
    def load_state_dict(self, state_dict, strict=True, assign=False):
        local_state_dict = self.state_dict()
        seen_keys = {k: False for k in local_state_dict.keys()}
        for k, v in state_dict.items():
            if k in local_state_dict:
                seen_keys[k] = True
                if local_state_dict[k].shape == v.shape:
                    local_state_dict[k].copy_(v)
                else:
                    print(f"mismatching shape for key {k}: loaded {v.shape} but model has {local_state_dict[k].shape}")
            else:
                print(f"unexpected key {k} in loaded state dict")
        for k, seen in seen_keys.items():
            if not seen:
                print(f"missing key {k} in loaded state dict")

    def fourier_encoding(self, points: torch.Tensor):
        # points: [B, N, 3], kept in float32 for precision
        num_freqs = self.config.point_fourier_dim // (2 * points.shape[-1])  # F frequencies per coordinate
        if self.config.fourier_version == "v1":  # default
            exponent = torch.arange(1, num_freqs + 1, device=points.device, dtype=torch.float32) / num_freqs  # [F], in (0, 1]
            freq_band = 512**exponent  # [F], frequencies from 512^(1/F) to 512
            freq_band *= torch.pi
        elif self.config.fourier_version == "v2":
            exponent = torch.arange(num_freqs, device=points.device, dtype=torch.float32) / (num_freqs - 1)  # [F], in [0, 1]
            freq_band = 1024**exponent  # [F], frequencies from 1 to 1024
            freq_band *= torch.pi
        elif self.config.fourier_version == "v3":  # hunyuan3d-2
            freq_band = 2 ** torch.arange(num_freqs, device=points.device, dtype=torch.float32)  # [F], powers of 2
        else:
            raise ValueError(f"unknown fourier_version: {self.config.fourier_version}")
        spectrum = points.unsqueeze(-1) * freq_band  # [B, ..., 3, F]
        sin, cos = spectrum.sin(), spectrum.cos()  # [B, ..., 3, F]
        input_enc = torch.stack([sin, cos], dim=-2)  # [B, ..., 3, 2, F]
        input_enc = input_enc.view(*points.shape[:-1], -1)  # [B, ..., 6F] = [B, ..., point_fourier_dim]
        return torch.cat([input_enc, points], dim=-1).to(dtype=self.precision)  # [B, ..., point_fourier_dim + 3]
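
    # shape sketch (hypothetical config with point_fourier_dim=12, so num_freqs=2):
    #   points [B, N, 3] -> spectrum [B, N, 3, 2] -> sin/cos stack [B, N, 3, 2, 2]
    #   -> flatten [B, N, 12] -> concat raw points -> [B, N, 15]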

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        # nn.Module defines no on_train_start; only forward the call if a base class provides it
        if hasattr(super(), "on_train_start"):
            super().on_train_start(memory_format=memory_format)
        self.to(dtype=self.precision, memory_format=memory_format)  # use bfloat16 for training

    def encode(self, data: dict[str, torch.Tensor]):
        # uniform points
        pointcloud = data["pointcloud"]  # [B, N, 3]
        # fourier embed and project
        pointcloud = self.fourier_encoding(pointcloud)  # [B, N, 3+C]
        pointcloud = self.proj_input(pointcloud)  # [B, N, hidden_dim]
        # salient points
        if self.config.use_salient_point:
            pointcloud_dorases = data["pointcloud_dorases"]  # [B, M, 3]
            # fourier embed and project (shared weights)
            pointcloud_dorases = self.fourier_encoding(pointcloud_dorases)  # [B, M, 3+C]
            pointcloud_dorases = self.proj_input(pointcloud_dorases)  # [B, M, hidden_dim]
        # gather fps points
        fps_indices = data["fps_indices"]  # [B, N']
        pointcloud_query = torch.gather(pointcloud, 1, fps_indices.unsqueeze(-1).expand(-1, -1, pointcloud.shape[-1]))
        if self.config.use_salient_point:
            fps_indices_dorases = data["fps_indices_dorases"]  # [B, M']
            if fps_indices_dorases.shape[1] > 0:
                pointcloud_query_dorases = torch.gather(
                    pointcloud_dorases,
                    1,
                    fps_indices_dorases.unsqueeze(-1).expand(-1, -1, pointcloud_dorases.shape[-1]),
                )
                # combine both fps point sets as the query
                pointcloud_query = torch.cat([pointcloud_query, pointcloud_query_dorases], dim=1)  # [B, N'+M', hidden_dim]
            # dual cross-attention
            if self.config.salient_attn_mode == "dual_shared":
                hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver(
                    pointcloud_query, pointcloud_dorases
                )  # [B, N'+M', hidden_dim]
            elif self.config.salient_attn_mode == "dual":
                hidden_states = self.perceiver(pointcloud_query, pointcloud) + self.perceiver_dorases(
                    pointcloud_query, pointcloud_dorases
                )
            else:  # single, hunyuan3d-2 style
                hidden_states = self.perceiver(pointcloud_query, torch.cat([pointcloud, pointcloud_dorases], dim=1))
        else:
            hidden_states = self.perceiver(pointcloud_query, pointcloud)  # [B, N', hidden_dim]
        # encoder
        for block in self.encoder:
            hidden_states = block(hidden_states)
        # bottleneck
        hidden_states = self.norm_down(hidden_states)
        latent_mean = self.proj_down_mean(hidden_states).float()
        if not self.config.use_ae:
            latent_std = self.proj_down_std(hidden_states).float()
            posterior = DiagonalGaussianDistribution(latent_mean, latent_std)
        else:
            posterior = DummyLatent(latent_mean)
        return posterior
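
    # note: encode() returns a distribution object rather than a tensor; callers draw from it, e.g.
    #   posterior = model.encode(data)
    #   z = posterior.sample() if training else posterior.mode()  # [B, N', latent_dim]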

    def decode(self, latent: torch.Tensor):
        latent = latent.to(dtype=self.precision)
        hidden_states = self.proj_up(latent)
        for block in self.decoder:
            hidden_states = block(hidden_states)
        return hidden_states

    def query(self, query_points: torch.Tensor, hidden_states: torch.Tensor):
        # query_points: [B, N, 3], float32 to keep the precision
        query_points = self.fourier_encoding(query_points)  # [B, N, 3+C]
        query_points = self.proj_query(query_points)  # [B, N, hidden_dim]
        # cross attention
        query_output = self.attn_query(query_points, hidden_states)  # [B, N, hidden_dim]
        # output linear
        query_output = self.norm_out(query_output)
        pred = self.proj_out(query_output)  # [B, N, 1]
        return pred
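
    # note: when config.use_flash_query is set, callers are expected to normalize the context
    # once via self.norm_query_context before calling query(), to avoid recomputing it per chunk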

    def training_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        output = {}
        # cut off FPS points during training for progressive flow
        if self.training:
            # randomly choose from a set of cutoff candidates
            cutoff_index = np.random.choice(len(self.config.cutoff_fps_prob), p=self.config.cutoff_fps_prob)
            cutoff_fps_point = self.config.cutoff_fps_point[cutoff_index]
            cutoff_fps_salient_point = self.config.cutoff_fps_salient_point[cutoff_index]
            # a prefix of FPS points is still a valid FPS point set
            data["fps_indices"] = data["fps_indices"][:, :cutoff_fps_point]
            if self.config.use_salient_point:
                data["fps_indices_dorases"] = data["fps_indices_dorases"][:, :cutoff_fps_salient_point]
        loss = 0
        # encode
        posterior = self.encode(data)
        latent_geom = posterior.sample() if self.training else posterior.mode()
        # decode
        hidden_states = self.decode(latent_geom)
        # cross-attention query
        query_points = data["query_points"]  # [B, N, 3], float32
        # the context norm is applied once here to avoid repeated computation inside query()
        if self.config.use_flash_query:
            hidden_states = self.norm_query_context(hidden_states)
        pred = self.query(query_points, hidden_states).squeeze(-1).float()  # [B, N]
        gt = data["query_gt"].float()  # [B, N], in [-1, 1]
        # main loss
        loss_mse = F.mse_loss(pred, gt, reduction="mean")
        loss += loss_mse
        loss_l1 = F.l1_loss(pred, gt, reduction="mean")
        loss += loss_l1
        # kl loss
        loss_kl = posterior.kl().mean()
        loss += self.config.kl_weight * loss_kl
        # metrics
        with torch.no_grad():
            output["scalar"] = {}  # for wandb logging
            output["scalar"]["loss_mse"] = loss_mse.detach()
            output["scalar"]["loss_l1"] = loss_l1.detach()
            output["scalar"]["loss_kl"] = loss_kl.detach()
            output["scalar"]["iou_fg"] = calculate_iou(pred, gt, target_value=1)
            output["scalar"]["iou_bg"] = calculate_iou(pred, gt, target_value=0)
            output["scalar"]["precision"], output["scalar"]["recall"], output["scalar"]["f1"] = calculate_metrics(
                pred, gt, target_value=1
            )
        return output, loss

    def validation_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        return self.training_step(data, iteration)

    def forward(
        self,
        data: dict[str, torch.Tensor],
        mode: Literal["dense", "hierarchical"] = "hierarchical",
        max_samples_per_iter: int = 512**2,
        resolution: int = 512,
        min_resolution: int = 64,  # for hierarchical
    ) -> dict[str, torch.Tensor]:
        output = {}
        # encode
        if "latent" in data:
            latent = data["latent"]
        else:
            posterior = self.encode(data)
            output["posterior"] = posterior
            latent = posterior.mode()
        output["latent"] = latent
        B = latent.shape[0]
        # decode
        hidden_states = self.decode(latent)
        output["hidden_states"] = hidden_states  # [B, N', hidden_dim], context for the cross-attention query
        # the context norm is applied once here to avoid repeated computation inside query()
        if self.config.use_flash_query:
            hidden_states = self.norm_query_context(hidden_states)

        # query the implicit field in chunks to bound peak memory
        def chunked_query(grid_points):
            # grid_points: [N, 3]; expand across the batch so query and context shapes match
            batched_points = grid_points.unsqueeze(0).expand(B, -1, -1)  # [B, N, 3]
            if grid_points.shape[0] <= max_samples_per_iter:
                return self.query(batched_points, hidden_states).squeeze(-1)  # [B, N]
            all_pred = []
            for i in range(0, grid_points.shape[0], max_samples_per_iter):
                chunk = batched_points[:, i : i + max_samples_per_iter]
                pred_chunk = self.query(chunk, hidden_states)
                all_pred.append(pred_chunk)
            return torch.cat(all_pred, dim=1).squeeze(-1)  # [B, N]
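
        # "dense" evaluates every vertex of the full (resolution+1)^3 grid; "hierarchical"
        # dense-queries only the coarsest grid, then refines near the predicted surface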
if mode == "dense": | |
grid_points = construct_grid_points(resolution).to(latent.device) | |
grid_points = grid_points.contiguous().view(-1, 3) | |
grid_vals = chunked_query(grid_points).float().view(B, resolution + 1, resolution + 1, resolution + 1) | |
elif mode == "hierarchical": | |
assert resolution >= min_resolution, "Resolution must be greater than or equal to min_resolution" | |
assert B == 1, "Only one batch is supported for hierarchical mode" | |
resolutions = [] | |
res = resolution | |
while res >= min_resolution: | |
resolutions.append(res) | |
res = res // 2 | |
resolutions.reverse() # e.g., [64, 128, 256, 512] | |
# dense-query the coarsest resolution | |
res = resolutions[0] | |
grid_points = construct_grid_points(res).to(latent.device) | |
grid_points = grid_points.contiguous().view(-1, 3) | |
grid_vals = chunked_query(grid_points).float().view(res + 1, res + 1, res + 1) | |
            # sparse-query finer resolutions
            dilate_kernel_3 = torch.ones(1, 1, 3, 3, 3, dtype=torch.float32, device=latent.device)
            dilate_kernel_5 = torch.ones(1, 1, 5, 5, 5, dtype=torch.float32, device=latent.device)
            for i in range(1, len(resolutions)):
                res = resolutions[i]
                # boundary mask in the coarser grid: vertices whose sign differs from at least one neighbor
                grid_signs = grid_vals >= 0
                mask = torch.zeros_like(grid_signs)  # bool tensor; += acts as logical OR
                mask[1:, :, :] += grid_signs[1:, :, :] != grid_signs[:-1, :, :]
                mask[:-1, :, :] += grid_signs[:-1, :, :] != grid_signs[1:, :, :]
                mask[:, 1:, :] += grid_signs[:, 1:, :] != grid_signs[:, :-1, :]
                mask[:, :-1, :] += grid_signs[:, :-1, :] != grid_signs[:, 1:, :]
                mask[:, :, 1:] += grid_signs[:, :, 1:] != grid_signs[:, :, :-1]
                mask[:, :, :-1] += grid_signs[:, :, :-1] != grid_signs[:, :, 1:]
                # empirical: also include near-surface vertices with abs(grid_vals) < 0.95
                mask += grid_vals.abs() < 0.95
                mask = (mask > 0).float()
                # empirical: dilate the coarse mask
                if res < 512:
                    mask = mask.unsqueeze(0).unsqueeze(0)
                    mask = F.conv3d(mask, weight=dilate_kernel_3, padding=1)
                    mask = mask.squeeze(0).squeeze(0)
                # get the coarse coordinates
                cidx_x, cidx_y, cidx_z = torch.nonzero(mask, as_tuple=True)
                # lift surviving coarse vertices to the fine grid (coarse index * 2)
                mask_fine = torch.zeros(res + 1, res + 1, res + 1, dtype=torch.float32, device=latent.device)
                mask_fine[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
                # empirical: dilate the fine mask (wider kernel at the final resolution)
                if res < 512:
                    mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
                    mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_3, padding=1)
                    mask_fine = mask_fine.squeeze(0).squeeze(0)
                else:
                    mask_fine = mask_fine.unsqueeze(0).unsqueeze(0)
                    mask_fine = F.conv3d(mask_fine, weight=dilate_kernel_5, padding=2)
                    mask_fine = mask_fine.squeeze(0).squeeze(0)
                # get the fine coordinates
                fidx_x, fidx_y, fidx_z = torch.nonzero(mask_fine, as_tuple=True)
                # convert to float query points
                query_points = torch.stack([fidx_x, fidx_y, fidx_z], dim=-1)  # [N, 3]
                query_points = query_points * 2 / res - 1  # [N, 3], in [-1, 1]
                # query
                pred = chunked_query(query_points).float().squeeze(0)  # [N], batch size is 1 here
                # scatter predictions into the fine grid; never-queried vertices keep the sentinel -100
                grid_vals = torch.full((res + 1, res + 1, res + 1), -100.0, dtype=torch.float32, device=latent.device)
                grid_vals[fidx_x, fidx_y, fidx_z] = pred
            grid_vals = grid_vals.unsqueeze(0)  # [1, res+1, res+1, res+1]
            grid_vals[grid_vals <= -100.0] = float("nan")  # use nans to mark invalid (never-queried) regions
        # extract mesh
        meshes = []
        for b in range(B):
            vertices, faces = extract_mesh(grid_vals[b], resolution)
            meshes.append((vertices, faces))
        output["meshes"] = meshes
        return output
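
# usage sketch (hypothetical config and data; key names follow the code above):
#   config = ModelConfig(...)  # hidden_dim, latent_dim, query_hidden_dim, ...
#   model = Model(config).cuda()
#   data = {"pointcloud": pts, "fps_indices": idx}  # plus the *_dorases keys if use_salient_point
#   out = model(data, mode="hierarchical", resolution=512)
#   vertices, faces = out["meshes"][0]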