# bubbliiiing
# Update V5.1
# c2a6cd2
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .activations import get_activation
# `need_to_float` tells CausalConv3d.forward whether the input has to be cast
# to float32 before temporal padding. It is enabled for torch releases older
# than 2.2 (presumably because replicate padding lacks reduced-precision
# support there -- confirm against the torch release notes).
try:
    current_version = torch.__version__
    # Only "major.minor" matters; local suffixes such as "+cu121" live in the
    # third component and are never parsed.
    version_numbers = [int(x) for x in current_version.split('.')[:2]]
    need_to_float = (version_numbers[0], version_numbers[1]) < (2, 2)
except Exception:
    # Best-effort fallback: when the version cannot be parsed, take the
    # conservative path the message announces and pad in float32.
    print("Encountered an error with Torch version. Set the data type to float in the VAE. ")
    need_to_float = True
def cast_tuple(t, length = 1):
    """Return ``t`` as a tuple.

    A tuple is handed back untouched (its length is not checked); any other
    value is repeated ``length`` times inside a new tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t,) * length
def divisible_by(num, den):
    """Tell whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0
def is_odd(n):
    """Return ``True`` when ``n`` is not divisible by two."""
    if divisible_by(n, 2):
        return False
    return True
class CausalConv3d(nn.Conv3d):
    """3D convolution whose temporal padding is applied manually in ``forward``.

    ``nn.Conv3d`` only receives the spatial padding; the time axis is padded
    (or extended with cached frames) here so that, in the causal modes, an
    output frame depends only on the current and earlier input frames.

    ``self.padding_flag`` selects the behaviour of :meth:`forward`:

    * ``0`` -- replicate-pad the front of the clip and convolve it in one call.
    * ``3`` -- first chunk of a longer clip: replicate-pad the front, cache the
      trailing frames in ``prev_features`` and convolve one temporal window at
      a time.
    * ``4`` -- follow-up chunk: prepend the cached frames instead of padding,
      refresh the cache and convolve one temporal window at a time.
    * ``5`` / ``6`` -- like ``3`` / ``4`` but the chunk is convolved in a
      single call.
    * any other value -- zero-pad ``temporal_padding_origin`` frames on both
      sides of the time axis (not causal).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or 3-tuple ``(t, h, w)``.
        stride: Int or 3-tuple ``(t, h, w)``.
        padding: Spatial padding. ``None`` derives it from kernel size,
            dilation and stride; an int is used for both height and width.
        dilation: Int or 3-tuple ``(t, h, w)``.
        **kwargs: Forwarded to ``nn.Conv3d``.

    Raises:
        NotImplementedError: If ``padding`` is neither ``None`` nor an int.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3, # : int | tuple[int, int, int],
        stride=1, # : int | tuple[int, int, int] = 1,
        padding=1, # : int | tuple[int, int, int], # TODO: change it to 0.
        dilation=1, # : int | tuple[int, int, int] = 1,
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        t_stride, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            # The previous `assert NotImplementedError` could never fail, so an
            # unsupported value only surfaced later as an UnboundLocalError.
            raise NotImplementedError(
                f"Padding must be None or an int, got {padding!r} instead."
            )

        # The temporal axis gets no padding from nn.Conv3d; forward() owns it.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )

        self.t_stride = t_stride
        # Frames needed in front of a clip so the first output frame only sees
        # the first input frame (and its replicas).
        self.temporal_padding = (t_ks - 1) * t_dilation
        # Symmetric padding for the non-causal fallback; computed from the
        # temporal dilation/stride (the width values were used by mistake).
        self.temporal_padding_origin = math.ceil(((t_ks - 1) * t_dilation + (1 - t_stride)) / 2)
        self.padding_flag = 0
        self.prev_features = None

    def _clear_conv_cache(self):
        """Drop the frames cached by the chunked modes."""
        self.prev_features = None

    @staticmethod
    def _last_frames(x, count):
        """Return the trailing ``count`` frames of ``x`` (none if ``count <= 0``).

        A plain ``x[:, :, -count:]`` would return every frame for ``count == 0``.
        """
        total = x.shape[2]
        start = max(total - max(count, 0), 0)
        return x[:, :, start:]

    def _pad_time(self, x, front, back, mode):
        """Pad the time axis of ``x`` and return it in its original dtype."""
        dtype = x.dtype
        if need_to_float:
            # Older torch releases: pad in float32, then cast back.
            x = x.float()
        x = F.pad(x, pad=(0, 0, 0, 0, front, back), mode=mode)
        return x.to(dtype=dtype)

    def _prepend_cache(self, x):
        """Put the frames cached from the previous chunk in front of ``x``."""
        if self.prev_features is None:
            raise RuntimeError(
                "CausalConv3d has no cached frames: run the first chunk with "
                "padding_flag 3 or 5 before using padding_flag 4 or 6."
            )
        cached = self.prev_features
        if self.t_stride == 2:
            # With a temporal stride of 2 one frame fewer is carried over.
            cached = self._last_frames(cached, self.temporal_padding - 1)
        return torch.concat([cached, x], dim=2).to(dtype=x.dtype)

    def _store_cache(self, x):
        """Remember the trailing ``temporal_padding`` frames of ``x``."""
        # Clear cache before
        self._clear_conv_cache()
        # We could move these to the cpu for a lower VRAM
        self.prev_features = self._last_frames(x, self.temporal_padding).clone()

    def _forward_windowed(self, x):
        """Convolve ``x`` one temporal window at a time and join the results."""
        conv = super().forward
        window = self.temporal_padding + 1
        frames = x.size(2)
        outputs = [
            conv(x[:, :, start:start + window])
            for start in range(0, frames - window + 1, self.t_stride)
        ]
        return torch.concat(outputs, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` according to ``self.padding_flag``.

        Args:
            x: Input of shape (B, C, T, H, W).

        Returns:
            The convolved tensor in the dtype of ``x``.

        Raises:
            RuntimeError: If a follow-up chunk (flag 4 or 6) is processed
                while no frames are cached.
        """
        # x: (B, C, T, H, W)
        flag = self.padding_flag

        if flag == 0:
            x = self._pad_time(x, self.temporal_padding, 0, "replicate")
            return super().forward(x)

        if flag in (3, 5):
            x = self._pad_time(x, self.temporal_padding, 0, "replicate")
        elif flag in (4, 6):
            x = self._prepend_cache(x)
        else:
            origin = self.temporal_padding_origin
            x = self._pad_time(x, origin, origin, "constant")
            return super().forward(x)

        # Chunked modes: keep the tail of this chunk for the next one.
        self._store_cache(x)
        if flag in (3, 4):
            return self._forward_windowed(x)
        return super().forward(x)
class ResidualBlock2D(nn.Module):
    """Residual block made of two 3x3 ``nn.Conv2d`` layers.

    Each convolution is preceded by group normalisation and the configured
    activation; dropout sits in front of the second convolution. The input is
    added back through ``shortcut`` (a 1x1 convolution when the channel count
    changes, identity otherwise) and the sum is divided by
    ``output_scale_factor``.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        non_linearity: Activation name passed to ``get_activation``.
        norm_num_groups: Number of groups for both ``nn.GroupNorm`` layers.
        norm_eps: Epsilon for both ``nn.GroupNorm`` layers.
        dropout: Dropout probability.
        output_scale_factor: Divisor applied to the block output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        # Settings shared by the two normalisation layers.
        norm_kwargs = dict(num_groups=norm_num_groups, eps=norm_eps, affine=True)

        self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
        self.nonlinearity = get_activation(non_linearity)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

        # When switched on, normalisation runs frame by frame on 5-D input.
        self.set_3dgroupnorm = False

    def _group_norm(self, norm, x):
        """Apply ``norm`` directly, or per frame when ``set_3dgroupnorm`` is on."""
        if not self.set_3dgroupnorm:
            return norm(x)
        batch_size = x.shape[0]
        frames = rearrange(x, "b c t h w -> (b t) c h w")
        frames = norm(frames)
        return rearrange(frames, "(b t) c h w -> b c t h w", b=batch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` and return the scaled residual sum."""
        residual = self.shortcut(x)

        hidden = self._group_norm(self.norm1, x)
        hidden = self.conv1(self.nonlinearity(hidden))

        hidden = self._group_norm(self.norm2, hidden)
        hidden = self.conv2(self.dropout(self.nonlinearity(hidden)))

        return (hidden + residual) / self.output_scale_factor
class ResidualBlock3D(nn.Module):
    """Residual block made of two kernel-size-3 ``CausalConv3d`` layers.

    Each convolution is preceded by group normalisation and the configured
    activation; dropout sits in front of the second convolution. The input is
    added back through ``shortcut`` (a 1x1x1 ``nn.Conv3d`` when the channel
    count changes, identity otherwise) and the sum is divided by
    ``output_scale_factor``.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        non_linearity: Activation name passed to ``get_activation``.
        norm_num_groups: Number of groups for both ``nn.GroupNorm`` layers.
        norm_eps: Epsilon for both ``nn.GroupNorm`` layers.
        dropout: Dropout probability.
        output_scale_factor: Divisor applied to the block output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        # Settings shared by the two normalisation layers.
        norm_kwargs = dict(num_groups=norm_num_groups, eps=norm_eps, affine=True)

        self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
        self.nonlinearity = get_activation(non_linearity)
        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3)

        self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3)

        self.shortcut = (
            nn.Conv3d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

        # When switched on, normalisation runs frame by frame instead of over
        # the whole (T, H, W) volume.
        self.set_3dgroupnorm = False

    def _group_norm(self, norm, x):
        """Apply ``norm`` directly, or per frame when ``set_3dgroupnorm`` is on."""
        if not self.set_3dgroupnorm:
            return norm(x)
        batch_size = x.shape[0]
        frames = rearrange(x, "b c t h w -> (b t) c h w")
        frames = norm(frames)
        return rearrange(frames, "(b t) c h w -> b c t h w", b=batch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` and return the scaled residual sum."""
        residual = self.shortcut(x)

        hidden = self._group_norm(self.norm1, x)
        hidden = self.conv1(self.nonlinearity(hidden))

        hidden = self._group_norm(self.norm2, hidden)
        hidden = self.conv2(self.dropout(self.nonlinearity(hidden)))

        return (hidden + residual) / self.output_scale_factor
class SpatialNorm2D(nn.Module):
    """
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.

    The feature map is group-normalised (32 groups), then scaled and shifted
    by two 1x1 convolutions of the conditioning tensor, which is first resized
    to the feature map's spatial size with nearest-neighbour interpolation.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
    """

    def __init__(
        self,
        f_channels: int,
        zq_channels: int,
    ):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=f_channels, eps=1e-6, affine=True)
        # 1x1 convolutions producing the per-position scale (y) and shift (b).
        conv_kwargs = dict(kernel_size=1, stride=1, padding=0)
        self.conv_y = nn.Conv2d(zq_channels, f_channels, **conv_kwargs)
        self.conv_b = nn.Conv2d(zq_channels, f_channels, **conv_kwargs)
        # When switched on, the feature map is normalised frame by frame.
        self.set_3dgroupnorm = False

    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
        """Normalise ``f`` and modulate it with scale/shift derived from ``zq``."""
        zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")

        if self.set_3dgroupnorm:
            batch_size = f.shape[0]
            frames = rearrange(f, "b c t h w -> (b t) c h w")
            normed = rearrange(self.norm(frames), "(b t) c h w -> b c t h w", b=batch_size)
        else:
            normed = self.norm(f)

        scale = self.conv_y(zq)
        shift = self.conv_b(zq)
        return normed * scale + shift
class SpatialNorm3D(SpatialNorm2D):
    """Apply :class:`SpatialNorm2D` to 5-D inputs by folding time into the batch axis."""

    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
        """Fold frames into the batch, run the 2-D forward, then unfold again."""
        n_batch = f.shape[0]
        flat_f = rearrange(f, "b c t h w -> (b t) c h w")
        flat_zq = rearrange(zq, "b c t h w -> (b t) c h w")
        modulated = super().forward(flat_f, flat_zq)
        return rearrange(modulated, "(b t) c h w -> b c t h w", b=n_batch)