# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _apply_group_norm(norm: nn.GroupNorm, hidden_states: torch.Tensor, spatial_group_norm: bool) -> torch.Tensor:
    """
    Apply `norm` to a `[B, C, T, H, W]` tensor.

    When `spatial_group_norm` is True the frame axis is folded into the batch axis first, so every frame is
    normalized on its own; otherwise `norm` is applied to the 5D tensor directly.
    """
    if not spatial_group_norm:
        return norm(hidden_states)

    batch_size = hidden_states.size(0)
    hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)  # [B, C, T, H, W] -> [B * T, C, H, W]
    hidden_states = norm(hidden_states)
    # [B * T, C, H, W] -> [B, C, T, H, W]
    return hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)


class EasyAnimateCausalConv3d(nn.Conv3d):
    r"""
    A `nn.Conv3d` that is causal along the temporal axis and can be fed a video chunk by chunk.

    The parent convolution is built with zero temporal padding. On the first call after the cache was cleared, the
    input is front-padded with `temporal_padding` replicated frames; on later calls the trailing frames of the
    previous call (kept in `prev_features`) are prepended instead. Call `_clear_conv_cache` before starting a new,
    unrelated sequence.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int` or `Tuple[int, ...]`, defaults to `3`): Kernel size as an int or a `(t, h, w)` tuple.
        stride (`int` or `Tuple[int, ...]`, defaults to `1`): Stride as an int or a `(t, h, w)` tuple.
        padding (`int` or `None`, defaults to `1`):
            Spatial padding applied to height and width. `None` derives it from kernel size, dilation and stride.
            Tuples are not supported.
        dilation (`int` or `Tuple[int, ...]`, defaults to `1`): Dilation as an int or a `(t, h, w)` tuple.
        groups (`int`, defaults to `1`): Forwarded to `nn.Conv3d`.
        bias (`bool`, defaults to `True`): Forwarded to `nn.Conv3d`.
        padding_mode (`str`, defaults to `"zeros"`): Forwarded to `nn.Conv3d` (affects the spatial padding only).

    Raises:
        NotImplementedError: If `padding` is neither `None` nor an `int`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
    ):
        # Ensure kernel_size, stride, and dilation are tuples of length 3
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        # Unpack kernel size, stride, and dilation for temporal, height, and width dimensions
        t_ks, h_ks, w_ks = kernel_size
        t_stride, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        # Temporal padding is applied on the "past" side only (see `forward`) to keep the convolution causal
        t_pad = (t_ks - 1) * t_dilation

        # Calculate padding for height and width dimensions based on the padding parameter
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            # The original `assert NotImplementedError` never fired and fell through to a NameError on `h_pad`.
            raise NotImplementedError(f"Only `None` or `int` padding is supported, got {padding!r} instead.")

        # Initialize the parent class with no temporal padding; it is handled manually in `forward`
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )

        # Plain (non-tensor) state is set after `nn.Module.__init__` has run.
        self.t_stride = t_stride
        self.temporal_padding = t_pad
        # Symmetric-padding equivalent along time. This used the width dilation/stride by mistake; the value is
        # unchanged for every configuration built in this file. It is not read anywhere in this file.
        self.temporal_padding_origin = math.ceil(((t_ks - 1) * t_dilation + (1 - t_stride)) / 2)
        # Trailing frames of the previous `forward` call, or None at the start of a sequence
        self.prev_features = None

    def _clear_conv_cache(self):
        """Drop the cached frames so the next `forward` call is treated as the start of a new sequence."""
        del self.prev_features
        self.prev_features = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Convolve one chunk of frames, using (and then updating) the cached frames of the previous chunk.

        Args:
            hidden_states (`torch.Tensor`): Input of shape `[B, C, T, H, W]`.

        Returns:
            `torch.Tensor`: The per-window convolution outputs concatenated along the frame axis (dim 2).
        """
        dtype = hidden_states.dtype

        if self.prev_features is None:
            # First chunk: pad the past side of the temporal axis by replicating the first frame
            hidden_states = F.pad(
                hidden_states,
                pad=(0, 0, 0, 0, self.temporal_padding, 0),
                mode="replicate",  # TODO: check if this is necessary
            )
        elif self.t_stride == 2:
            # With temporal stride 2 only the last `temporal_padding - 1` cached frames are prepended
            hidden_states = torch.concat(
                [self.prev_features[:, :, -(self.temporal_padding - 1) :], hidden_states], dim=2
            )
        else:
            hidden_states = torch.concat([self.prev_features, hidden_states], dim=2)
        hidden_states = hidden_states.to(dtype=dtype)

        # Replace the cache with the trailing frames of the current (already extended) input
        self._clear_conv_cache()
        self.prev_features = hidden_states[:, :, -self.temporal_padding :].clone()

        # Slide a window of `temporal_padding + 1` frames over time and convolve each window separately.
        # An explicit loop is used (not a comprehension) so that zero-argument `super()` works on all Python versions.
        window = self.temporal_padding + 1
        num_frames = hidden_states.size(2)
        outputs = []
        for start in range(0, num_frames - window + 1, self.t_stride):
            outputs.append(super().forward(hidden_states[:, :, start : start + window]))
        return torch.concat(outputs, 2)


class EasyAnimateResidualBlock3D(nn.Module):
    r"""
    Residual block: (GroupNorm -> activation -> causal conv) twice, with dropout before the second convolution and a
    1x1x1 convolution on the skip path when the channel count changes.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        non_linearity (`str`, defaults to `"silu"`): Activation name passed to `get_activation`.
        norm_num_groups (`int`, defaults to `32`): Number of groups of both GroupNorm layers.
        norm_eps (`float`, defaults to `1e-6`): Epsilon of both GroupNorm layers.
        spatial_group_norm (`bool`, defaults to `True`): Normalize every frame independently when True.
        dropout (`float`, defaults to `0.0`): Dropout probability.
        output_scale_factor (`float`, defaults to `1.0`): The sum of residual and skip path is divided by this value.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        non_linearity: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        spatial_group_norm: bool = True,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        # Group normalization for input tensor
        self.norm1 = nn.GroupNorm(
            num_groups=norm_num_groups,
            num_channels=in_channels,
            eps=norm_eps,
            affine=True,
        )
        self.nonlinearity = get_activation(non_linearity)
        self.conv1 = EasyAnimateCausalConv3d(in_channels, out_channels, kernel_size=3)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = EasyAnimateCausalConv3d(out_channels, out_channels, kernel_size=3)

        if in_channels != out_channels:
            self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

        self.spatial_group_norm = spatial_group_norm

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the block on a `[B, C, T, H, W]` tensor and return `(residual + shortcut) / output_scale_factor`."""
        shortcut = self.shortcut(hidden_states)

        hidden_states = _apply_group_norm(self.norm1, hidden_states, self.spatial_group_norm)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = _apply_group_norm(self.norm2, hidden_states, self.spatial_group_norm)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        return (hidden_states + shortcut) / self.output_scale_factor


class EasyAnimateDownsampler3D(nn.Module):
    r"""
    Strided causal convolution used for downsampling.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels.
        kernel_size (`int`, defaults to `3`): Kernel size of the convolution.
        stride (`tuple`, defaults to `(2, 2, 2)`): `(t, h, w)` stride of the convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: tuple = (2, 2, 2)):
        super().__init__()

        self.conv = EasyAnimateCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The convolution itself has no spatial padding: pad one column on the right and one row at the bottom.
        hidden_states = F.pad(hidden_states, (0, 1, 0, 1))
        hidden_states = self.conv(hidden_states)
        return hidden_states


class EasyAnimateUpsampler3D(nn.Module):
    r"""
    Nearest-neighbour 2x spatial upsampling followed by a causal convolution, with optional 2x temporal upsampling.

    When `temporal_upsample` is True, the first chunk seen after the cache was cleared is *not* upsampled in time;
    every later chunk is. Call `_clear_conv_cache` before starting a new sequence.

    Args:
        in_channels (`int`): Number of input channels.
        out_channels (`int`): Number of output channels; falls back to `in_channels` when falsy.
        kernel_size (`int`, defaults to `3`): Kernel size of the convolution.
        temporal_upsample (`bool`, defaults to `False`): Whether to also upsample the frame axis.
        spatial_group_norm (`bool`, defaults to `True`):
            Selects the temporal interpolation mode: `"nearest"` when True, `"trilinear"` otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        temporal_upsample: bool = False,
        spatial_group_norm: bool = True,
    ):
        super().__init__()
        out_channels = out_channels or in_channels

        self.temporal_upsample = temporal_upsample
        self.spatial_group_norm = spatial_group_norm

        self.conv = EasyAnimateCausalConv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size
        )
        # Only tested against None in this file: marks whether the first chunk has already been processed
        self.prev_features = None

    def _clear_conv_cache(self):
        """Forget that a first chunk was seen, so the next chunk is again treated as the first one."""
        del self.prev_features
        self.prev_features = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = F.interpolate(hidden_states, scale_factor=(1, 2, 2), mode="nearest")
        hidden_states = self.conv(hidden_states)

        if self.temporal_upsample:
            if self.prev_features is None:
                # First chunk: keep its frame count and remember that it was seen
                self.prev_features = hidden_states
            else:
                hidden_states = F.interpolate(
                    hidden_states,
                    scale_factor=(2, 1, 1),
                    mode="trilinear" if not self.spatial_group_norm else "nearest",
                )
        return hidden_states


class EasyAnimateDownBlock3D(nn.Module):
    r"""
    A stack of `num_layers` residual blocks followed by an optional downsampler.

    Args:
        in_channels (`int`): Number of input channels of the first residual block.
        out_channels (`int`): Number of output channels of every residual block and of the downsampler.
        num_layers (`int`, defaults to `1`): Number of residual blocks.
        act_fn (`str`, defaults to `"silu"`): Activation name.
        norm_num_groups (`int`, defaults to `32`): GroupNorm groups.
        norm_eps (`float`, defaults to `1e-6`): GroupNorm epsilon.
        spatial_group_norm (`bool`, defaults to `True`): Normalize every frame independently when True.
        dropout (`float`, defaults to `0.0`): Dropout probability.
        output_scale_factor (`float`, defaults to `1.0`): Forwarded to the residual blocks.
        add_downsample (`bool`, defaults to `True`): Whether to append a downsampler (spatial stride 2).
        add_temporal_downsample (`bool`, defaults to `True`):
            Whether the downsampler also uses temporal stride 2. Ignored when `add_downsample` is False.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        spatial_group_norm: bool = True,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        add_temporal_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                EasyAnimateResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    spatial_group_norm=spatial_group_norm,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_downsample:
            temporal_factor = 2 if add_temporal_downsample else 1
            self.downsampler = EasyAnimateDownsampler3D(
                out_channels, out_channels, kernel_size=3, stride=(temporal_factor, 2, 2)
            )
            self.spatial_downsample_factor = 2
            self.temporal_downsample_factor = temporal_factor
        else:
            self.downsampler = None
            self.spatial_downsample_factor = 1
            self.temporal_downsample_factor = 1

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            hidden_states = conv(hidden_states)
        if self.downsampler is not None:
            hidden_states = self.downsampler(hidden_states)
        return hidden_states


class EasyAnimateUpBlock3d(nn.Module):
    r"""
    A stack of `num_layers` residual blocks followed by an optional upsampler.

    Args:
        in_channels (`int`): Number of input channels of the first residual block.
        out_channels (`int`): Number of output channels of every residual block and of the upsampler.
        num_layers (`int`, defaults to `1`): Number of residual blocks.
        act_fn (`str`, defaults to `"silu"`): Activation name.
        norm_num_groups (`int`, defaults to `32`): GroupNorm groups.
        norm_eps (`float`, defaults to `1e-6`): GroupNorm epsilon.
        spatial_group_norm (`bool`, defaults to `False`): Normalize every frame independently when True.
        dropout (`float`, defaults to `0.0`): Dropout probability.
        output_scale_factor (`float`, defaults to `1.0`): Forwarded to the residual blocks.
        add_upsample (`bool`, defaults to `True`): Whether to append an upsampler (2x spatial).
        add_temporal_upsample (`bool`, defaults to `True`): Whether the upsampler also upsamples the frame axis.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        spatial_group_norm: bool = False,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        add_temporal_upsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                EasyAnimateResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    spatial_group_norm=spatial_group_norm,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_upsample:
            # The upsampler consumes the output of the last residual block, which always has `out_channels`
            # channels. The loop-mutated `in_channels` used previously only matched that for `num_layers >= 2`.
            self.upsampler = EasyAnimateUpsampler3D(
                out_channels,
                out_channels,
                temporal_upsample=add_temporal_upsample,
                spatial_group_norm=spatial_group_norm,
            )
        else:
            self.upsampler = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for conv in self.convs:
            hidden_states = conv(hidden_states)
        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states)
        return hidden_states


class EasyAnimateMidBlock3d(nn.Module):
    r"""
    A stack of residual blocks that keep the channel count (`in_channels -> in_channels`).

    Args:
        in_channels (`int`): Number of input and output channels.
        num_layers (`int`, defaults to `1`): Number of residual blocks; at least one block is always created.
        act_fn (`str`, defaults to `"silu"`): Activation name.
        norm_num_groups (`int`, defaults to `32`):
            GroupNorm groups. `None` selects `min(in_channels // 4, 32)`.
        norm_eps (`float`, defaults to `1e-6`): GroupNorm epsilon.
        spatial_group_norm (`bool`, defaults to `True`): Normalize every frame independently when True.
        dropout (`float`, defaults to `0.0`): Dropout probability.
        output_scale_factor (`float`, defaults to `1.0`): Forwarded to the residual blocks.
    """

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        spatial_group_norm: bool = True,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()

        norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32)

        # One block is always present, even for `num_layers < 1`
        self.convs = nn.ModuleList(
            [
                EasyAnimateResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    spatial_group_norm=spatial_group_norm,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
                for _ in range(max(num_layers, 1))
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for resnet in self.convs:
            hidden_states = resnet(hidden_states)
        return hidden_states


class EasyAnimateEncoder(nn.Module):
    r"""
    Causal encoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).

    Args:
        in_channels (`int`, defaults to `3`): Number of input channels.
        out_channels (`int`, defaults to `8`): Number of latent channels.
        down_block_types (`Tuple[str, ...]`):
            One of `"SpatialDownBlock3D"` / `"SpatialTemporalDownBlock3D"` per stage.
        block_out_channels (`Tuple[int, ...]`): Output channels of every stage.
        layers_per_block (`int`, defaults to `2`): Residual blocks per stage (and in the mid block).
        norm_num_groups (`int`, defaults to `32`): GroupNorm groups.
        act_fn (`str`, defaults to `"silu"`): Activation name.
        double_z (`bool`, defaults to `True`): Emit `2 * out_channels` channels when True.
        spatial_group_norm (`bool`, defaults to `False`): Normalize every frame independently when True.

    Raises:
        ValueError: If an entry of `down_block_types` is unknown.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 8,
        down_block_types: Tuple[str, ...] = (
            "SpatialDownBlock3D",
            "SpatialTemporalDownBlock3D",
            "SpatialTemporalDownBlock3D",
            "SpatialTemporalDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        spatial_group_norm: bool = False,
    ):
        super().__init__()

        # 1. Input convolution
        self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)

        # 2. Down blocks
        self.down_blocks = nn.ModuleList([])
        output_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channels = output_channels
            output_channels = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            # The two supported block types differ only in whether the frame axis is downsampled too
            if down_block_type == "SpatialDownBlock3D":
                add_temporal_downsample = False
            elif down_block_type == "SpatialTemporalDownBlock3D":
                add_temporal_downsample = True
            else:
                raise ValueError(f"Unknown down block type: {down_block_type}")

            self.down_blocks.append(
                EasyAnimateDownBlock3D(
                    in_channels=input_channels,
                    out_channels=output_channels,
                    num_layers=layers_per_block,
                    act_fn=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=1e-6,
                    spatial_group_norm=spatial_group_norm,
                    add_downsample=not is_final_block,
                    add_temporal_downsample=add_temporal_downsample,
                )
            )

        # 3. Middle block
        self.mid_block = EasyAnimateMidBlock3d(
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            spatial_group_norm=spatial_group_norm,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            dropout=0,
            output_scale_factor=1,
        )

        # 4. Output normalization & convolution
        self.spatial_group_norm = spatial_group_norm
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)

        # Initialize the output convolution layer
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = EasyAnimateCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, C, T, H, W)
        hidden_states = self.conv_in(hidden_states)

        for down_block in self.down_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
            else:
                hidden_states = down_block(hidden_states)

        hidden_states = self.mid_block(hidden_states)

        hidden_states = _apply_group_norm(self.conv_norm_out, hidden_states, self.spatial_group_norm)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states


class EasyAnimateDecoder(nn.Module):
    r"""
    Causal decoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).

    Args:
        in_channels (`int`, defaults to `8`): Number of latent channels.
        out_channels (`int`, defaults to `3`): Number of output channels.
        up_block_types (`Tuple[str, ...]`):
            One of `"SpatialUpBlock3D"` / `"SpatialTemporalUpBlock3D"` per stage.
        block_out_channels (`Tuple[int, ...]`): Channels per stage; traversed in reverse order.
        layers_per_block (`int`, defaults to `2`):
            The mid block uses this many residual blocks, every up block one more.
        norm_num_groups (`int`, defaults to `32`): GroupNorm groups.
        act_fn (`str`, defaults to `"silu"`): Activation name.
        spatial_group_norm (`bool`, defaults to `False`): Normalize every frame independently when True.

    Raises:
        ValueError: If an entry of `up_block_types` is unknown.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 8,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "SpatialUpBlock3D",
            "SpatialTemporalUpBlock3D",
            "SpatialTemporalUpBlock3D",
            "SpatialTemporalUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        spatial_group_norm: bool = False,
    ):
        super().__init__()

        # 1. Input convolution
        self.conv_in = EasyAnimateCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3)

        # 2. Middle block
        # NOTE(review): `spatial_group_norm` is not forwarded here, so the mid block always uses its own default
        # (True), unlike `EasyAnimateEncoder`. Left as is because changing it alters the numerical output for
        # models configured with `spatial_group_norm=False` - confirm whether this is intended.
        self.mid_block = EasyAnimateMidBlock3d(
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            dropout=0,
            output_scale_factor=1,
        )

        # 3. Up blocks
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            input_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            # The two supported block types differ only in whether the frame axis is upsampled too
            if up_block_type == "SpatialUpBlock3D":
                add_temporal_upsample = False
            elif up_block_type == "SpatialTemporalUpBlock3D":
                add_temporal_upsample = True
            else:
                raise ValueError(f"Unknown up block type: {up_block_type}")

            # Create and append up block to up_blocks
            self.up_blocks.append(
                EasyAnimateUpBlock3d(
                    in_channels=input_channels,
                    out_channels=output_channels,
                    num_layers=layers_per_block + 1,
                    act_fn=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=1e-6,
                    spatial_group_norm=spatial_group_norm,
                    add_upsample=not is_final_block,
                    add_temporal_upsample=add_temporal_upsample,
                )
            )

        # Output normalization and activation
        self.spatial_group_norm = spatial_group_norm
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)

        # Output convolution layer
        self.conv_out = EasyAnimateCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, C, T, H, W)
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
        else:
            hidden_states = self.mid_block(hidden_states)

        for up_block in self.up_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
            else:
                hidden_states = up_block(hidden_states)

        hidden_states = _apply_group_norm(self.conv_norm_out, hidden_states, self.spatial_group_norm)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states


class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
    model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 16,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = [128, 256, 512, 512],
        down_block_types: Tuple[str, ...] = [
            "SpatialDownBlock3D",
            "SpatialTemporalDownBlock3D",
            "SpatialTemporalDownBlock3D",
            "SpatialTemporalDownBlock3D",
        ],
        up_block_types: Tuple[str, ...] = [
            "SpatialUpBlock3D",
            "SpatialTemporalUpBlock3D",
            "SpatialTemporalUpBlock3D",
            "SpatialTemporalUpBlock3D",
        ],
        layers_per_block: int = 2,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        scaling_factor: float = 0.7125,
        spatial_group_norm: bool = True,
    ):
        super().__init__()

        # Initialize the encoder
        self.encoder = EasyAnimateEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            double_z=True,
            spatial_group_norm=spatial_group_norm,
        )

        # Initialize the decoder
        self.decoder = EasyAnimateDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            spatial_group_norm=spatial_group_norm,
        )

        # Initialize convolution layers for quantization and post-quantization
        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1)
        self.temporal_compression_ratio = 2 ** (len(block_out_channels) - 2)

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
        self.use_slicing = False

        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

        # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
        # at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
        self.use_framewise_encoding = False
        self.use_framewise_decoding = False

        # Assign mini-batch sizes for encoder and decoder
        self.num_sample_frames_batch_size = 4
        self.num_latent_frames_batch_size = 1

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 512
        self.tile_sample_min_width = 512
        self.tile_sample_min_num_frames = 4

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 448
        self.tile_sample_stride_width = 448
        self.tile_sample_stride_num_frames = 8

    def _clear_conv_cache(self):
        """Reset the chunk-to-chunk state of every causal convolution and upsampler in the model."""
        for module in self.modules():
            if isinstance(module, (EasyAnimateCausalConv3d, EasyAnimateUpsampler3D)):
                module._clear_conv_cache()

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_min_num_frames: Optional[int] = None,
        tile_sample_stride_height: Optional[int] = None,
        tile_sample_stride_width: Optional[int] = None,
        tile_sample_stride_num_frames: Optional[int] = None,
    ) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Every argument left as `None` (or passed as `0`) keeps the value currently stored on the model.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_min_num_frames (`int`, *optional*):
                Stored on the model; not read by the tiling code in this file.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. Tiles are `tile_sample_min_height` tall, so a
                stride smaller than that makes them overlap; the overlap is blended to avoid tiling artifacts across
                the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
            tile_sample_stride_num_frames (`int`, *optional*):
                Stored on the model; not read by the tiling code in this file.
        """
        self.use_tiling = True
        self.use_framewise_decoding = True
        self.use_framewise_encoding = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
        self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _encode_frames(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode one `[B, C, T, H, W]` clip (or spatial tile of it) chunk by chunk and return the moments.

        The first frame is encoded on its own, the remaining frames in chunks of `num_sample_frames_batch_size`.
        The causal caches are always cleared afterwards, also when encoding fails, so that no state leaks into the
        next clip.
        """
        try:
            chunks = [self.encoder(x[:, :, :1, :, :])]
            for i in range(1, x.shape[2], self.num_sample_frames_batch_size):
                chunks.append(self.encoder(x[:, :, i : i + self.num_sample_frames_batch_size, :, :]))
            return self.quant_conv(torch.cat(chunks, dim=2))
        finally:
            self._clear_conv_cache()

    def _decode_frames(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode one `[B, C, T, H, W]` latent clip (or spatial tile of it) chunk by chunk.

        The first latent frame is decoded on its own, the remaining frames in chunks of
        `num_latent_frames_batch_size`. The causal caches are always cleared afterwards, also when decoding fails,
        so that no state leaks into the next clip.
        """
        try:
            z = self.post_quant_conv(z)
            chunks = [self.decoder(z[:, :, :1, :, :])]
            for i in range(1, z.shape[2], self.num_latent_frames_batch_size):
                chunks.append(self.decoder(z[:, :, i : i + self.num_latent_frames_batch_size, :, :]))
            # Concatenate all processed frames along the frame dimension
            return torch.cat(chunks, dim=2)
        finally:
            self._clear_conv_cache()

    @apply_forward_hook
    def _encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
        """
        Encode a batch of videos into the moments of the latent distribution.

        Args:
            x (`torch.Tensor`): Input batch of shape `[B, C, T, H, W]`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Accepted for signature compatibility; the moments tensor is returned either way.

        Returns:
            `torch.Tensor`: The output of `quant_conv`, to be wrapped in a `DiagonalGaussianDistribution`.
        """
        # shape[-2] is the height and shape[-1] the width (the two were compared against the wrong thresholds before)
        if self.use_tiling and (x.shape[-2] > self.tile_sample_min_height or x.shape[-1] > self.tile_sample_min_width):
            return self.tiled_encode(x, return_dict=return_dict)

        return self._encode_frames(x)

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode one batch of latents, tiled when tiling is enabled and the latent is larger than one tile.

        The causal caches are cleared before returning, so consecutive calls (e.g. one per batch slice) do not
        influence each other.
        """
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

        # shape[-2] is the height and shape[-1] the width (the two were compared against the wrong thresholds before)
        if self.use_tiling and (z.shape[-2] > tile_latent_min_height or z.shape[-1] > tile_latent_min_width):
            return self.tiled_decode(z, return_dict=return_dict)

        dec = self._decode_frames(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            # `_decode` resets the causal caches after every slice, so the samples are decoded independently
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        self._clear_conv_cache()
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Linearly cross-fade the bottom `blend_extent` rows of `a` into the top rows of `b`.

        `b` is modified in place and returned. The extent is clamped to the heights of both tensors.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Linearly cross-fade the rightmost `blend_extent` columns of `a` into the leftmost columns of `b`.

        `b` is modified in place and returned. The extent is clamped to the widths of both tensors.
        """
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
        """
        Encode `x` as overlapping spatial tiles and blend the tile latents into one moments tensor.

        Args:
            x (`torch.Tensor`): Input batch of shape `[B, C, T, H, W]`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Accepted for signature compatibility; the moments tensor is returned either way.

        Returns:
            `torch.Tensor`: Moments cropped to `H // spatial_compression_ratio` x `W // spatial_compression_ratio`.
        """
        height, width = x.shape[-2:]
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split the input into overlapping tiles and encode them separately.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_height,
                    j : j + self.tile_sample_min_width,
                ]
                row.append(self._encode_frames(tile))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                # Keep only one stride worth of every tile; the remainder is the overlap that was blended into the
                # next tile. Cropping to the full latent size here concatenated the overlaps and misaligned the tiles.
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return moments

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """
        Decode `z` as overlapping spatial tiles and blend the decoded tiles into one sample.

        Args:
            z (`torch.Tensor`): Latent batch of shape `[B, C, T, H, W]`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`: The sample, cropped to `H * spatial_compression_ratio` x
            `W * spatial_compression_ratio`.
        """
        height, width = z.shape[-2:]
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + tile_latent_min_height,
                    j : j + tile_latent_min_width,
                ]
                row.append(self._decode_frames(tile))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Encode `sample`, take a latent from the posterior and decode it again.

        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior. The posterior mode is used otherwise.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator forwarded to `posterior.sample` when `sample_posterior` is True.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)