# Page residue from the hosting site (not part of the module): "Spaces: Runtime error".
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl_cogvideox import (
    CogVideoXCausalConv3d,
    CogVideoXDownBlock3D,
    CogVideoXMidBlock3D,
    CogVideoXSafeConv3d,
    CogVideoXSpatialNorm3D,
    CogVideoXUpBlock3D,
)
from diffusers.utils import logging


# Module-level logger obtained through the diffusers logging helper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class CogVideoXEncoder3D(nn.Module):
    r"""
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            The number of latent channels. The output convolution is built with `2 * out_channels` channels
            (presumably the mean and log-variance of the latent distribution; confirm against the caller).
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXDownBlock3D"` entries):
            The types of down blocks to use. Only `"CogVideoXDownBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 3):
            The number of layers per down block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name forwarded to the down blocks and the mid block.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon forwarded as `resnet_eps` to the down blocks and the mid block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout value forwarded to the down blocks and the mid block.
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode forwarded to the input/output causal convolutions and the mid block.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` down blocks are built with `compress_time=True`.

    Raises:
        ValueError: If an entry of `down_block_types` is not `"CogVideoXDownBlock3D"`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        down_block_types: Tuple[str, ...] = (
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
            "CogVideoXDownBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # Number of leading down blocks that also compress the time axis (log2 of the ratio).
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)

        # down blocks
        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            # Each block consumes the previous block's width; the first one keeps block_out_channels[0].
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if down_block_type == "CogVideoXDownBlock3D":
                down_block = CogVideoXDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    # The last block keeps its resolution: no downsampler is attached to it.
                    add_downsample=not is_final_block,
                    compress_time=compress_time,
                )
            else:
                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")

            self.down_blocks.append(down_block)

        # mid block
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=block_out_channels[-1],
            temb_channels=0,
            dropout=dropout,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            pad_mode=pad_mode,
        )

        # NOTE(review): eps is fixed at 1e-6 here and does not follow `norm_eps`; kept as-is to
        # preserve existing behavior.
        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def _clear_fake_context_parallel_cache(self):
        """Call `_clear_fake_context_parallel_cache()` on every `CogVideoXCausalConv3d` submodule."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug("Clearing fake Context Parallel cache for layer: %s", name)
                module._clear_fake_context_parallel_cache()

    def _run_block(self, block, *args):
        """Run `block(*args)`, through activation checkpointing when training with it enabled."""
        if self.training and self.gradient_checkpointing:
            # NOTE(review): `use_reentrant` is not passed, so the PyTorch default applies; recent
            # PyTorch releases ask for it to be explicit -- confirm the desired mode.
            return torch.utils.checkpoint.checkpoint(block, *args)
        return block(*args)

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXEncoder3D` class.

        Args:
            sample (`torch.Tensor`): Input passed to `conv_in`.
            temb (`torch.Tensor`, *optional*): Forwarded unchanged to every down block and the mid block.

        Returns:
            `torch.Tensor`: The output of `conv_out`.
        """
        hidden_states = self.conv_in(sample)

        # 1. Down (the third positional block argument is always `None` in the encoder)
        for down_block in self.down_blocks:
            hidden_states = self._run_block(down_block, hidden_states, temb, None)

        # 2. Mid
        hidden_states = self._run_block(self.mid_block, hidden_states, temb, None)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states
class CogVideoXDecoder3D(nn.Module):
    r"""
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            The number of input (latent) channels. Also used as `spatial_norm_dim` of the mid block and up blocks
            and as the conditioning width of the output normalization.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to four `"CogVideoXUpBlock3D"` entries):
            The types of up blocks to use. Only `"CogVideoXUpBlock3D"` is accepted.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 256, 512)`):
            The number of output channels for each block; consumed in reverse order.
        layers_per_block (`int`, *optional*, defaults to 3):
            Each up block is built with `layers_per_block + 1` layers.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name forwarded to the mid block and the up blocks.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon forwarded as `resnet_eps` to the mid block and the up blocks.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout value forwarded to the up blocks (the mid block is built without it).
        pad_mode (`str`, *optional*, defaults to `"first"`):
            Padding mode forwarded to the causal convolutions, the mid block and the up blocks.
        temporal_compression_ratio (`float`, *optional*, defaults to 4):
            The first `int(log2(temporal_compression_ratio))` up blocks are built with `compress_time=True`.

    Raises:
        ValueError: If an entry of `up_block_types` is not `"CogVideoXUpBlock3D"`.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = (
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
            "CogVideoXUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
        layers_per_block: int = 3,
        act_fn: str = "silu",
        norm_eps: float = 1e-6,
        norm_num_groups: int = 32,
        dropout: float = 0.0,
        pad_mode: str = "first",
        temporal_compression_ratio: float = 4,
    ):
        super().__init__()

        # The decoder walks the channel widths from widest to narrowest.
        reversed_block_out_channels = list(reversed(block_out_channels))

        self.conv_in = CogVideoXCausalConv3d(
            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
        )

        # mid block
        # NOTE(review): `dropout` is not forwarded here although the up blocks receive it;
        # kept as-is to preserve existing behavior.
        self.mid_block = CogVideoXMidBlock3D(
            in_channels=reversed_block_out_channels[0],
            temb_channels=0,
            num_layers=2,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            spatial_norm_dim=in_channels,
            pad_mode=pad_mode,
        )

        # up blocks
        self.up_blocks = nn.ModuleList([])
        output_channel = reversed_block_out_channels[0]
        # Number of leading up blocks built with `compress_time=True` (log2 of the ratio).
        temporal_compress_level = int(np.log2(temporal_compression_ratio))

        for i, up_block_type in enumerate(up_block_types):
            # Each block consumes the previous block's width; the first one keeps the widest value.
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            compress_time = i < temporal_compress_level

            if up_block_type == "CogVideoXUpBlock3D":
                up_block = CogVideoXUpBlock3D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    temb_channels=0,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    spatial_norm_dim=in_channels,
                    # The last block keeps its resolution: no upsampler is attached to it.
                    add_upsample=not is_final_block,
                    compress_time=compress_time,
                    pad_mode=pad_mode,
                )
            else:
                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")

            self.up_blocks.append(up_block)

        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
        self.conv_act = nn.SiLU()
        self.conv_out = CogVideoXCausalConv3d(
            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
        )

        self.gradient_checkpointing = False

    def _clear_fake_context_parallel_cache(self):
        """Call `_clear_fake_context_parallel_cache()` on every `CogVideoXCausalConv3d` submodule."""
        for name, module in self.named_modules():
            if isinstance(module, CogVideoXCausalConv3d):
                logger.debug("Clearing fake Context Parallel cache for layer: %s", name)
                module._clear_fake_context_parallel_cache()

    def _run_block(self, block, *args):
        """Run `block(*args)`, through activation checkpointing when training with it enabled."""
        if self.training and self.gradient_checkpointing:
            # NOTE(review): `use_reentrant` is not passed, so the PyTorch default applies; recent
            # PyTorch releases ask for it to be explicit -- confirm the desired mode.
            return torch.utils.checkpoint.checkpoint(block, *args)
        return block(*args)

    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""
        The forward method of the `CogVideoXDecoder3D` class.

        Args:
            sample (`torch.Tensor`):
                Input passed to `conv_in`; it is also handed to the mid block, every up block and `norm_out` as
                their last positional argument.
            temb (`torch.Tensor`, *optional*): Forwarded unchanged to the mid block and every up block.

        Returns:
            `torch.Tensor`: The output of `conv_out`.
        """
        hidden_states = self.conv_in(sample)

        # 1. Mid
        hidden_states = self._run_block(self.mid_block, hidden_states, temb, sample)

        # 2. Up
        for up_block in self.up_blocks:
            hidden_states = self._run_block(up_block, hidden_states, temb, sample)

        # 3. Post-process
        hidden_states = self.norm_out(hidden_states, sample)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        return hidden_states