svjack's picture
Upload folder using huggingface_hub
e13f5a4 verified
from dataclasses import dataclass
import json
from typing import Optional, Tuple, Union
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
import logging
# Module-level logger used by load_vae.
logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time affects every other module in the
# process; consider moving basicConfig to the application entry point.
logging.basicConfig(level=logging.INFO)
# Latent scaling factor; same value as "scaling_factor" in load_vae's embedded config.
SCALING_FACTOR: float = 0.476986
# Identifier of the Hunyuan VAE variant; same string as load_vae's default vae_type.
VAE_VER: str = "884-16c-hy"
def load_vae(
    vae_type: str = "884-16c-hy",
    vae_dtype: Optional[Union[str, torch.dtype]] = None,
    sample_size: tuple = None,
    vae_path: str = None,
    device=None,
):
    """Load the Hunyuan causal 3D VAE from a checkpoint file.

    The architecture always comes from the fixed config embedded below; only the
    weights are read from ``vae_path``.

    Args:
        vae_type (str): identifier of the VAE variant. Only used to look up a default
            checkpoint path (module-level ``VAE_PATH`` mapping, if one exists) when
            ``vae_path`` is None. Defaults to "884-16c-hy".
        vae_dtype (str or torch.dtype, optional): dtype to cast the VAE to, either a
            ``torch.dtype`` or its name (``"bfloat16"``, ``"torch.float16"``, ``"bf16"``,
            ``"fp16"``, ``"fp32"``). Defaults to None (keep the dtype from construction).
        sample_size (tuple, optional): forwarded to ``from_config`` as ``sample_size``
            (the tiling size). Defaults to None.
        vae_path (str, optional): checkpoint file passed to ``torch.load``. Defaults to None.
        device (optional): device to move the VAE to after loading. Defaults to None.

    Returns:
        tuple: ``(vae, vae_path, spatial_compression_ratio, time_compression_ratio)``;
        the VAE is frozen (``requires_grad_(False)``) and in eval mode.

    Raises:
        ValueError: if ``vae_path`` is None and no default path is known for
            ``vae_type``, or if ``vae_dtype`` is a string that names no torch dtype.
    """
    if vae_path is None:
        # VAE_PATH is not defined in this module's visible code; look it up lazily so a
        # missing mapping gives a clear error instead of a NameError.
        default_paths = globals().get("VAE_PATH") or {}
        if vae_type not in default_paths:
            raise ValueError(
                f"vae_path is required: no default checkpoint path is known for vae_type {vae_type!r}"
            )
        vae_path = default_paths[vae_type]

    # nn.Module.to() interprets a bare string as a device, so dtype names must be
    # resolved to a torch.dtype before casting. Validate early, before any I/O.
    if isinstance(vae_dtype, str):
        dtype_aliases = {"fp16": "float16", "bf16": "bfloat16", "fp32": "float32"}
        dtype_name = dtype_aliases.get(vae_dtype, vae_dtype)
        if dtype_name.startswith("torch."):
            dtype_name = dtype_name[len("torch."):]
        resolved_dtype = getattr(torch, dtype_name, None)
        if not isinstance(resolved_dtype, torch.dtype):
            raise ValueError(f"Unsupported vae_dtype: {vae_dtype!r}")
        vae_dtype = resolved_dtype

    logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")

    # use fixed config for Hunyuan's VAE
    CONFIG_JSON = """{
        "_class_name": "AutoencoderKLCausal3D",
        "_diffusers_version": "0.4.2",
        "act_fn": "silu",
        "block_out_channels": [
            128,
            256,
            512,
            512
        ],
        "down_block_types": [
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D",
            "DownEncoderBlockCausal3D"
        ],
        "in_channels": 3,
        "latent_channels": 16,
        "layers_per_block": 2,
        "norm_num_groups": 32,
        "out_channels": 3,
        "sample_size": 256,
        "sample_tsize": 64,
        "up_block_types": [
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D",
            "UpDecoderBlockCausal3D"
        ],
        "scaling_factor": 0.476986,
        "time_compression_ratio": 4,
        "mid_block_add_attention": true
    }"""
    config = json.loads(CONFIG_JSON)

    # import here to avoid circular import
    from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D

    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    # When keys carry a "vae." prefix (presumably a checkpoint saved from a wrapper
    # module), keep only those keys and strip the prefix. Slice instead of str.replace
    # so that "vae." appearing later inside a key is left intact.
    prefix = "vae."
    if any(k.startswith(prefix) for k in ckpt):
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_dtype is not None:
        vae = vae.to(vae_dtype)

    vae.requires_grad_(False)

    logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor`):
            The decoded output sample from the last layer of the model.
            NOTE(review): this was documented as `(batch_size, num_channels, height, width)`,
            but `DecoderCausal3D.forward` in this module asserts a 5D input — confirm whether
            the decoded sample is 5D (presumably batch, channels, frames, height, width).
    """

    # Decoded sample tensor.
    sample: torch.FloatTensor
class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Pipeline: input causal conv -> down blocks -> mid block -> GroupNorm -> SiLU -> output
    causal conv. The output conv emits ``2 * out_channels`` channels when ``double_z`` is set
    (presumably the mean / log-variance halves) and ``out_channels`` otherwise.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        """Build the encoder.

        Args:
            in_channels: channels of the input sample.
            out_channels: latent channels (doubled at the output when ``double_z``).
            down_block_types: block type name for each down block.
            block_out_channels: output width of each down block.
            layers_per_block: resnet layers per down block.
            norm_num_groups: GroupNorm groups used throughout.
            act_fn: activation for the resnet blocks (the output head always uses SiLU).
            double_z: double the output channel count.
            mid_block_add_attention: whether the mid block includes attention.
            time_compression_ratio: temporal downsampling factor; only 4 is supported.
            spatial_compression_ratio: spatial factor; the first log2(ratio) blocks
                downsample H and W.

        Raises:
            ValueError: if ``time_compression_ratio`` is not 4 and at least one down block
                is requested.
        """
        super().__init__()
        self.layers_per_block = layers_per_block

        # Creation order fixes submodule/parameter registration order; keep it stable.
        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        last_idx = len(block_out_channels) - 1
        out_ch = block_out_channels[0]
        for idx, block_type in enumerate(down_block_types):
            in_ch, out_ch = out_ch, block_out_channels[idx]
            is_last = idx == last_idx

            n_spatial = int(np.log2(spatial_compression_ratio))
            n_temporal = int(np.log2(time_compression_ratio))
            if time_compression_ratio != 4:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # The first n_spatial blocks halve H/W; the n_temporal blocks right before the
            # last one halve T.
            spatial_down = bool(idx < n_spatial)
            temporal_down = bool(idx >= (last_idx - n_temporal) and not is_last)
            stride = (2 if temporal_down else 1,) + ((2, 2) if spatial_down else (1, 1))

            self.down_blocks.append(
                get_down_block3d(
                    block_type,
                    num_layers=self.layers_per_block,
                    in_channels=in_ch,
                    out_channels=out_ch,
                    add_downsample=bool(spatial_down or temporal_down),
                    downsample_stride=stride,
                    resnet_eps=1e-6,
                    downsample_padding=0,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    attention_head_dim=out_ch,
                    temb_channels=None,
                )
            )

        top_ch = block_out_channels[-1]
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=top_ch,
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=top_ch,
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=top_ch, num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(top_ch, 2 * out_channels if double_z else out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""Encode a 5D ``sample`` into the pre-distribution latent tensor."""
        assert sample.ndim == 5, "The input tensor should have 5 dimensions"

        hidden = self.conv_in(sample)
        for block in self.down_blocks:
            hidden = block(hidden)
        hidden = self.mid_block(hidden)

        # Output head: normalize, activate, project.
        for layer in (self.conv_norm_out, self.conv_act, self.conv_out):
            hidden = layer(hidden)
        return hidden
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Pipeline: input causal conv -> mid block -> up blocks -> output norm -> SiLU -> output
    causal conv. With ``norm_type="spatial"`` the norms take ``latent_embeds`` as a second input.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        """Build the decoder.

        Args:
            in_channels: latent channels of the input.
            out_channels: channels of the decoded sample.
            up_block_types: block type name for each up block.
            block_out_channels: widths listed encoder-style; they are consumed in reverse.
            layers_per_block: up blocks get ``layers_per_block + 1`` resnet layers.
            norm_num_groups: GroupNorm groups used throughout.
            act_fn: activation for the resnet blocks (the output head always uses SiLU).
            norm_type: "group" or "spatial".
            mid_block_add_attention: whether the mid block includes attention.
            time_compression_ratio: temporal upsampling factor; only 4 is supported.
            spatial_compression_ratio: spatial factor; the first log2(ratio) blocks
                upsample H and W.

        Raises:
            ValueError: if ``time_compression_ratio`` is not 4 and at least one up block
                is requested.
        """
        super().__init__()
        self.layers_per_block = layers_per_block

        # Creation order fixes submodule/parameter registration order; keep it stable.
        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        is_spatial_norm = norm_type == "spatial"
        # Only spatial norm takes a conditioning input; its width is in_channels.
        temb_channels = in_channels if is_spatial_norm else None

        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        widths = list(reversed(block_out_channels))
        last_idx = len(block_out_channels) - 1
        out_ch = widths[0]
        for idx, block_type in enumerate(up_block_types):
            in_ch, out_ch = out_ch, widths[idx]
            is_last = idx == last_idx

            n_spatial = int(np.log2(spatial_compression_ratio))
            n_temporal = int(np.log2(time_compression_ratio))
            if time_compression_ratio != 4:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # The first n_spatial blocks double H/W; the n_temporal blocks right before the
            # last one double T.
            spatial_up = bool(idx < n_spatial)
            temporal_up = bool(idx >= last_idx - n_temporal and not is_last)
            scale = (2 if temporal_up else 1,) + ((2, 2) if spatial_up else (1, 1))

            self.up_blocks.append(
                get_up_block3d(
                    block_type,
                    num_layers=self.layers_per_block + 1,
                    in_channels=in_ch,
                    out_channels=out_ch,
                    prev_output_channel=None,
                    add_upsample=bool(spatial_up or temporal_up),
                    upsample_scale_factor=scale,
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    attention_head_dim=out_ch,
                    temb_channels=temb_channels,
                    resnet_time_scale_shift=norm_type,
                )
            )

        self.conv_norm_out = (
            SpatialNorm(block_out_channels[0], temb_channels)
            if is_spatial_norm
            else nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        )
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""Decode a 5D latent ``sample``; ``latent_embeds`` conditions spatial norms when given."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        hidden = self.conv_in(sample)
        # The mid block output is cast to the up blocks' parameter dtype.
        up_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:
            # use_reentrant is only passed on torch versions that accept it.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            def run(module, *args):
                def call(*inputs):
                    return module(*inputs)

                return torch.utils.checkpoint.checkpoint(call, *args, **ckpt_kwargs)

        else:

            def run(module, *args):
                return module(*args)

        hidden = run(self.mid_block, hidden, latent_embeds).to(up_dtype)
        for block in self.up_blocks:
            hidden = run(block, hidden, latent_embeds)

        if latent_embeds is None:
            hidden = self.conv_norm_out(hidden)
        else:
            hidden = self.conv_norm_out(hidden, latent_embeds)
        hidden = self.conv_act(hidden)
        return self.conv_out(hidden)
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterized by one tensor holding mean and log-variance.

    ``parameters`` is split in two along the channel axis (dim 2 for 3D tensors,
    dim 1 for 4D/5D tensors) into ``mean`` and ``logvar``. ``logvar`` is clamped to
    [-30, 20] before ``std``/``var`` are derived. With ``deterministic=True``,
    ``std`` and ``var`` are zero, ``sample()`` returns the mean, and ``kl``/``nll``
    return a one-element zero tensor.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim == 5 or parameters.ndim == 4:
            dim = 1  # (B, C, T, H ,W) / (B, C, H, W)
        else:
            raise NotImplementedError(f"Unsupported parameters.ndim: {parameters.ndim}")
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        # Clamp so exp() below stays finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)

    def _zero(self) -> torch.Tensor:
        """One-element zero tensor on the same device/dtype as the parameters."""
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def _non_batch_dims(self) -> list:
        """Every dimension of ``mean`` except the leading batch dimension."""
        return list(range(1, self.mean.ndim))

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Return ``mean + std * noise`` with noise drawn by ``randn_tensor``."""
        # make sure sample is on the same device as the parameters and has same dtype
        noise = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        """KL divergence to ``other`` (a standard normal when None), summed per batch item.

        Returns a tensor of shape ``(B,)``; a one-element zero tensor on the parameters'
        device/dtype when deterministic.
        """
        if self.deterministic:
            # Previously torch.Tensor([0.0]): always CPU float32, which mismatched GPU / half inputs.
            return self._zero()
        reduce_dim = self._non_batch_dims()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=reduce_dim,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=reduce_dim,
        )

    def nll(self, sample: torch.Tensor, dims: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``.

        Args:
            sample: tensor with the same shape as ``mean``.
            dims: dimensions to reduce. Defaults to every non-batch dimension, which
                equals the former hard-coded ``[1, 2, 3]`` for 4D parameters and also
                works for 3D and 5D parameters.

        Returns:
            The summed NLL; a one-element zero tensor on the parameters' device/dtype
            when deterministic.
        """
        if self.deterministic:
            return self._zero()
        if dims is None:
            dims = self._non_batch_dims()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """The distribution's mode, i.e. its mean."""
        return self.mean