# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl_cogvideox import (
    CogVideoXCausalConv3d, CogVideoXDownBlock3D, CogVideoXMidBlock3D,
    CogVideoXSafeConv3d, CogVideoXSpatialNorm3D, CogVideoXUpBlock3D)
from diffusers.utils import logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    The module is a pipeline of `conv_in` -> down blocks -> mid block -> `GroupNorm` -> `SiLU` -> `conv_out`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent channels. `conv_out` produces `2 * out_channels` channels (presumably the
            parameters of the latent distribution; confirm against the caller).
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("CogVideoXDownBlock3D",) * 4`):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block. Entry `i` is used for down block `i`, so it needs at
            least as many entries as `down_block_types`.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per down block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function passed to the down blocks and the mid block. The activation applied before
            `conv_out` is always `nn.SiLU`.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon passed to the down blocks and the mid block as `resnet_eps`.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout value passed to the down blocks and the mid block.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode passed to `conv_in`, the mid block and `conv_out`.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` down blocks are built with `compress_time=True`.

    Raises:
        ValueError: If an entry of `down_block_types` is not `"CogVideoXDownBlock3D"`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # log2 of temporal_compress_times; truncated, so non power-of-two ratios round down.
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
        self.down_blocks = nn.ModuleList([])

        # down blocks
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # NOTE(review): "final" is decided from `block_out_channels`, not `down_block_types`; the two are
            # expected to have the same length.
            is_final_block = i == len(block_out_channels) - 1
            # Only the first `temporal_compress_level` blocks are asked to compress time.
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        # NOTE(review): eps is hard-coded here rather than taken from `norm_eps`; the two only agree for the
        # default `norm_eps`. Left unchanged to keep behaviour identical for existing configs.
        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def _clear_fake_context_parallel_cache(self):
        """Call `_clear_fake_context_parallel_cache` on every `CogVideoXCausalConv3d` submodule."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                module._clear_fake_context_parallel_cache()

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXEncoder3D` class.

        Args:
            sample (`torch.Tensor`):
                Input tensor fed to `conv_in` (presumably a video batch laid out as
                `(batch, channels, frames, height, width)`; confirm against the caller).
            temb (`torch.Tensor`, *optional*):
                Forwarded unchanged as the second positional argument of every down block and the mid block.

        Returns:
            `torch.Tensor`: The output of `conv_out`, which has `2 * out_channels` channels.
        """
        hidden_states = self.conv_in(sample)

        # Checkpointing is only used while training, where trading recomputation for memory pays off.
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Down
            for down_block in self.down_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(down_block), hidden_states, temb, None
                )

            # 2. Mid
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, None
            )
        else:
            # 1. Down
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states, temb, None)

            # 2. Mid
            hidden_states = self.mid_block(hidden_states, temb, None)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states


class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    The module is a pipeline of `conv_in` -> mid block -> up blocks -> `CogVideoXSpatialNorm3D` -> `SiLU` ->
    `conv_out`. The input latent is additionally handed to the mid block, every up block and the output norm.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of input (latent) channels. Also used as `spatial_norm_dim` for the mid block, the up blocks
            and the output norm.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("CogVideoXUpBlock3D",) * 4`):
            The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block. The tuple is reversed internally, so up block `i` uses
            entry `-(i + 1)`; it needs at least as many entries as `up_block_types`.
        layers_per_block (`int`, *optional*, defaults to 3):
            Each up block is built with `layers_per_block + 1` layers.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function passed to the mid block and the up blocks. The activation applied before
            `conv_out` is always `nn.SiLU`.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon passed to the mid block and the up blocks as `resnet_eps`.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout value passed to the up blocks.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode passed to `conv_in`, the mid block, the up blocks and `conv_out`.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` up blocks are built with `compress_time=True`.

    Raises:
        ValueError: If an entry of `up_block_types` is not `"CogVideoXUpBlock3D"`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # The decoder walks the channel schedule from widest to narrowest.
        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        # NOTE(review): unlike the encoder, `dropout` is not forwarded to the decoder mid block, so the block's own
        # default applies. Left unchanged to keep behaviour identical.
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])

        output_channel = reversed_block_out_channels[0]
        # Truncated log2, so non power-of-two ratios round down.
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # NOTE(review): "final" is decided from `block_out_channels`, not `up_block_types`; the two are
            # expected to have the same length.
            is_final_block = i == len(block_out_channels) - 1
            # Only the first `temporal_compress_level` blocks get `compress_time=True`.
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def _clear_fake_context_parallel_cache(self):
        """Call `_clear_fake_context_parallel_cache` on every `CogVideoXCausalConv3d` submodule."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
                module._clear_fake_context_parallel_cache()

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXDecoder3D` class.

        Args:
            sample (`torch.Tensor`):
                Latent tensor fed to `conv_in`. The same tensor is also passed as the third positional argument of
                the mid block and every up block, and as the second argument of `norm_out` (presumably as the
                conditioning signal of the spatial norms; confirm against the block implementations).
            temb (`torch.Tensor`, *optional*):
                Forwarded unchanged as the second positional argument of the mid block and every up block.

        Returns:
            `torch.Tensor`: The output of `conv_out`, which has `out_channels` channels.
        """
        hidden_states = self.conv_in(sample)

        # Checkpointing is only used while training, where trading recomputation for memory pays off.
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # 1. Mid
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), hidden_states, temb, sample
            )

            # 2. Up
            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), hidden_states, temb, sample
                )
        else:
            # 1. Mid
            hidden_states = self.mid_block(hidden_states, temb, sample)

            # 2. Up
            for up_block in self.up_blocks:
                hidden_states = up_block(hidden_states, temb, sample)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states, sample)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states