""" | |
----------------------------------------------------------------------------- | |
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
NVIDIA CORPORATION and its licensors retain all intellectual property | |
and proprietary rights in and to this software, related documentation | |
and any modifications thereto. Any use, reproduction, disclosure or | |
distribution of this software and related documentation without an express | |
license agreement from NVIDIA CORPORATION is strictly prohibited. | |
----------------------------------------------------------------------------- | |
""" | |
import importlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torchvision import transforms
from transformers import Dinov2Model

from flow.configs.schema import ModelConfig
from flow.flow_matching import FlowMatchingScheduler
from flow.modules.dit import DiT
from vae.model import Model as VAE
from vae.utils import sync_timer

class Model(nn.Module):
    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        self.precision = torch.bfloat16

        # image condition model (dinov2)
        if self.config.dino_model == "dinov2_vitg14":
            self.dino = Dinov2Model.from_pretrained("facebook/dinov2-giant")
        elif self.config.dino_model == "dinov2_vitl14_reg":
            self.dino = Dinov2Model.from_pretrained("facebook/dinov2-with-registers-large")
        else:
            raise ValueError(f"DINOv2 model {self.config.dino_model} is not supported")
        # drop the final LayerNorm to match our implementation; get_cond re-applies it in fp32
        self.dino.layernorm = torch.nn.Identity()
        self.dino.eval().to(dtype=self.precision)
        self.dino.requires_grad_(False)
        # DINO features are used directly as conditioning tokens, so their dim must match the DiT
        cond_dim = 1024 if self.config.dino_model == "dinov2_vitl14_reg" else 1536
        assert cond_dim == config.hidden_dim, "DINOv2 dim must match backbone dim"
        # cond images are expected in [0, 1]; only ImageNet normalization is applied here
        self.preprocess_cond_image = transforms.Compose(
            [
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

        # vae encoder
        vae_config = importlib.import_module(config.vae_conf).make_config()
        self.vae = VAE(vae_config).eval().to(dtype=self.precision)
        self.vae.requires_grad_(False)
        # load vae
        if self.config.preload_vae:
            try:
                vae_ckpt = torch.load(self.config.vae_ckpt_path, weights_only=True)  # local path
                if "model" in vae_ckpt:
                    vae_ckpt = vae_ckpt["model"]
                self.vae.load_state_dict(vae_ckpt, strict=True)
                del vae_ckpt
                print(f"Loaded VAE from {self.config.vae_ckpt_path}")
            except Exception as e:
                print(
                    f"Failed to load VAE from {self.config.vae_ckpt_path}: {e}, make sure you resume from a valid checkpoint!"
                )
        # fill in latent shape info from the vae config if not set explicitly
        if config.latent_size is None:
            config.latent_size = self.vae.config.latent_size
        if config.latent_dim is None:
            config.latent_dim = self.vae.config.latent_dim

        # dit
        self.dit = DiT(
            hidden_dim=config.hidden_dim,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            latent_size=config.latent_size,
            latent_dim=config.latent_dim,
            qknorm=config.qknorm,
            qknorm_type=config.qknorm_type,
            use_pos_embed=config.use_pos_embed,
            use_parts=config.use_parts,
            part_embed_mode=config.part_embed_mode,
        )

        # num_part condition
        if self.config.use_num_parts_cond:
            assert self.config.use_parts, "use_num_parts_cond requires use_parts"
            self.num_part_embed = nn.Embedding(5, config.hidden_dim)

        # preload from a checkpoint (NOTE: this happens BEFORE the checkpointer loads the latest checkpoint!)
        if self.config.pretrain_path is not None:
            try:
                ckpt = torch.load(self.config.pretrain_path)  # local path
                self.load_state_dict(ckpt["model"], strict=True)
                del ckpt
                print(f"Loaded DiT from {self.config.pretrain_path}")
            except Exception as e:
                print(
                    f"Failed to load DiT from {self.config.pretrain_path}: {e}, make sure you resume from a valid checkpoint!"
                )

        # sampler
        self.scheduler = FlowMatchingScheduler(shift=config.flow_shift)

        n_params = sum(p.numel() for p in self.dit.parameters())
        print(f"Number of parameters in DiT: {n_params / 1e6:.2f}M")

    # override state_dict to exclude vae and dino, so we only save the trainable params.
    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        keys_to_del = [k for k in state_dict.keys() if "vae" in k or "dino" in k]
        for k in keys_to_del:
            del state_dict[k]
        return state_dict

    # override to support tolerant loading (only load params with matching shapes)
    def load_state_dict(self, state_dict, strict=True, assign=False):
        local_state_dict = self.state_dict()
        seen_keys = {k: False for k in local_state_dict.keys()}
        for k, v in state_dict.items():
            if k in local_state_dict:
                seen_keys[k] = True
                if local_state_dict[k].shape == v.shape:
                    local_state_dict[k].copy_(v)
                else:
                    print(f"mismatching shape for key {k}: model has {local_state_dict[k].shape} but loaded {v.shape}")
            else:
                print(f"unexpected key {k} in loaded state dict")
        for k, seen in seen_keys.items():
            if not seen:
                print(f"missing key {k} in loaded state dict")

    # this happens before the checkpointer loads old models !!!
    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        super().on_train_start(memory_format=memory_format)
        device = next(self.dit.parameters()).device
        self.dit.to(dtype=self.precision)
        if self.config.use_num_parts_cond:
            self.num_part_embed.to(dtype=self.precision)
        # cast scheduler to device
        self.scheduler.to(device)

    def get_cond(self, cond_image, num_part=None):
        # image condition
        cond_image = cond_image.to(dtype=self.precision)
        with torch.no_grad():
            cond = self.dino(cond_image).last_hidden_state
        # re-apply the final LayerNorm (removed in __init__) in fp32 for numerical stability
        cond = F.layer_norm(cond.float(), cond.shape[-1:]).to(dtype=self.precision)  # [B, L, C]

        # num_part condition
        if self.config.use_num_parts_cond:
            # coarse buckets: 1 -> exactly 2 parts, 2 -> 3-10, 3 -> 11-100, 4 -> >100 (index 0 is unused)
            if num_part is None:
                # default to the mid-range bucket (3-10 parts)
                num_part_coarse = torch.ones(cond.shape[0], dtype=torch.int64, device=cond.device) * 2
            else:
                num_part_coarse = torch.ones(cond.shape[0], dtype=torch.int64, device=cond.device)
                num_part_coarse[num_part == 2] = 1
                num_part_coarse[(num_part > 2) & (num_part <= 10)] = 2
                num_part_coarse[(num_part > 10) & (num_part <= 100)] = 3
                num_part_coarse[num_part > 100] = 4
            num_part_embed = self.num_part_embed(num_part_coarse).unsqueeze(1)  # [B, 1, C]
            # append the part-count embedding as one extra conditioning token
            cond = torch.cat([cond, num_part_embed], dim=1)  # [B, L+1, C]
        return cond

    def training_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        output = {}
        # we may load multiple (N) cond images for the same shape
        cond_images = self.preprocess_cond_image(data["cond_images"])  # [B, N, 3, 518, 518]
        B, N, C, H, W = cond_images.shape
        if self.config.use_num_parts_cond:
            cond_num_part = data["num_part"].repeat_interleave(N, dim=0)
        else:
            cond_num_part = None
        cond = self.get_cond(cond_images.view(-1, C, H, W), cond_num_part)  # [B*N, L, C]

        # random CFG dropout: zero out the condition for ~10% of samples
        if self.training:
            mask = torch.rand((B * N, 1, 1), device=cond.device, dtype=cond.dtype) >= 0.1
            cond = cond * mask
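            # the zeroed-out rows train the unconditional branch that
            # classifier-free guidance queries at inference (see forward)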

        with torch.no_grad():
            # encode latent
            if self.config.use_parts:
                # encode two parts and concat their latents
                part0_data = {k.replace("_part0", ""): v for k, v in data.items() if "_part0" in k}
                part1_data = {k.replace("_part1", ""): v for k, v in data.items() if "_part1" in k}
                posterior0 = self.vae.encode(part0_data)
                posterior1 = self.vae.encode(part1_data)
                # randomly swap the two parts so the model does not learn a fixed order
                if self.training and self.config.shuffle_parts:
                    if np.random.rand() < 0.5:
                        posterior0, posterior1 = posterior1, posterior0
                latent = torch.cat(
                    [
                        posterior0.mode().float().nan_to_num_(0),
                        posterior1.mode().float().nan_to_num_(0),
                    ],
                    dim=1,
                )  # [B, 2L, C]
            else:
                posterior = self.vae.encode(data)
                latent = posterior.mode().float().nan_to_num_(0)  # use the mode (= mean for a Gaussian posterior) as the latent, [B, L, C]
            # repeat latent for each cond image
            if N != 1:
                latent = latent.repeat_interleave(N, dim=0)

        # randomly sample timesteps and add noise
        noisy_latent, noise, timesteps = self.scheduler.add_noise(
            latent, self.config.logitnorm_mean, self.config.logitnorm_std
        )
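        # with the rectified-flow interpolation x_t = (1 - sigma) * latent + sigma * noise
        # (the convention assumed for FlowMatchingScheduler), the velocity target is noise - latent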
        noisy_latent = noisy_latent.to(dtype=self.precision)
        model_pred = self.dit(noisy_latent, cond, timesteps)

        # flow-matching (velocity) loss
        target = noise - latent
        loss = F.mse_loss(model_pred.float(), target.float())

        # metrics
        with torch.no_grad():
            output["scalar"] = {}  # for wandb logging
            output["scalar"]["loss_mse"] = loss.detach()
        return output, loss

    def validation_step(
        self,
        data: dict[str, torch.Tensor],
        iteration: int,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        return self.training_step(data, iteration)

    def forward(
        self,
        data: dict[str, torch.Tensor],
        num_steps: int = 30,
        cfg_scale: float = 7.0,
        verbose: bool = True,
        generator: torch.Generator | None = None,
    ) -> dict[str, torch.Tensor]:
        # inference sampling
        cond_images = self.preprocess_cond_image(data["cond_images"])  # [B, 3, 518, 518]
        B = cond_images.shape[0]
        assert B == 1, "Only batch size 1 is supported for now."
        # num_part condition
        if self.config.use_num_parts_cond and "num_part" in data:
            cond_num_part = data["num_part"]  # [B,], int
        else:
            cond_num_part = None
        cond = self.get_cond(cond_images, cond_num_part)

        if self.config.use_parts:
            x = torch.randn(
                B,
                self.config.latent_size * 2,
                self.config.latent_dim,
                device=cond.device,
                dtype=torch.float32,
                generator=generator,
            )
        else:
            x = torch.randn(
                B,
                self.config.latent_size,
                self.config.latent_dim,
                device=cond.device,
                dtype=torch.float32,
                generator=generator,
            )
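        # x is the sigma = 1 (pure noise) starting point; with use_parts the token
        # length is doubled because two part latents are concatenated in training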

        # batch the conditional and unconditional (zeroed condition) branches for CFG
        cond_input = torch.cat([cond, torch.zeros_like(cond)], dim=0)

        # flow-matching sampling: sigma schedule runs from 1 (pure noise) to 0 (data)
        sigmas = np.linspace(1, 0, num_steps + 1)
        sigmas = self.scheduler.shift * sigmas / (1 + (self.scheduler.shift - 1) * sigmas)
        sigmas_pair = [(sigmas[i], sigmas[i + 1]) for i in range(num_steps)]
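        # the shift remaps sigma -> s * sigma / (1 + (s - 1) * sigma), which for
        # s > 1 clusters the steps toward the high-noise end of the schedule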
        for sigma, sigma_prev in tqdm.tqdm(sigmas_pair, desc="Flow Sampling", disable=not verbose):
            # classifier-free guidance: duplicate x for the cond/uncond branches
            timesteps = torch.tensor([1000 * sigma] * B * 2, device=x.device, dtype=x.dtype)
            x_input = torch.cat([x, x], dim=0)
            # predict v
            x_input = x_input.to(dtype=self.precision)
            pred = self.dit(x_input, cond_input, timesteps).float()
            cond_v, uncond_v = pred.chunk(2, dim=0)
            pred_v = uncond_v + (cond_v - uncond_v) * cfg_scale
            # Euler step: integrate dx/dsigma = v from sigma to sigma_prev
            x = x - (sigma - sigma_prev) * pred_v

        output = {}
        output["latent"] = x
        # leave mesh extraction to the vae
        return output
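

if __name__ == "__main__":
    # Minimal inference sketch, not part of the original module. The config module
    # path "flow.configs.example", the image file "example.png", and the availability
    # of a CUDA device are all hypothetical placeholders.
    from PIL import Image

    config = importlib.import_module("flow.configs.example").make_config()  # hypothetical config module
    model = Model(config).eval().cuda()
    model.dit.to(dtype=model.precision)  # on_train_start normally performs this cast

    # DINOv2 expects 518x518 inputs in [0, 1]; preprocess_cond_image (called inside
    # forward) only applies ImageNet normalization on top.
    image = Image.open("example.png").convert("RGB")
    to_tensor = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])
    cond_images = to_tensor(image).unsqueeze(0).cuda()  # [1, 3, 518, 518]

    with torch.no_grad():
        out = model({"cond_images": cond_images}, num_steps=30, cfg_scale=7.0)
    print(out["latent"].shape)  # [1, latent_size (x2 if use_parts), latent_dim]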