Spaces:
Running
Running
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from ..modules.vaemodules.activations import get_activation
from ..modules.vaemodules.common import CausalConv3d
from ..modules.vaemodules.down_blocks import get_down_block
from ..modules.vaemodules.mid_blocks import get_mid_block
from ..modules.vaemodules.up_blocks import get_up_block


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Pipeline: `conv_in` -> down blocks -> mid block -> GroupNorm -> activation -> `conv_out`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels fed to `conv_in`.
        out_channels (`int`, *optional*, defaults to 8):
            The number of latent channels. `conv_out` emits `2 * out_channels` channels when `double_z` is True.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialDownBlock3D",)`):
            The types of down blocks to use (resolved with `get_down_block`), one per entry of `ch_mult`.
        ch (`int`, *optional*, defaults to 128):
            Base channel width. Down block `i` outputs `ch * ch_mult[i]` channels.
        ch_mult (`Sequence[int]`, *optional*, defaults to `(1, 2, 4, 4)`):
            Per-block channel multipliers applied to `ch`.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to add a global context block to each down block. `None` disables it for every block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use (resolved with `get_mid_block`).
        mid_block_use_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
        mid_block_attention_type (`str`, *optional*, defaults to `"3d"`):
            Forwarded to the mid block as `attention_type`.
        mid_block_num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads of the mid block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers in each down block and in the mid block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for group normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name, resolved with the local `get_activation` helper.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads forwarded to each down block.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels of `conv_out`.
        slice_compression_vae (`bool`, *optional*, defaults to `False`):
            If True, `forward` encodes the clip in temporal chunks instead of in a single pass.
        mini_batch_encoder (`int`, *optional*, defaults to 9):
            The number of frames per chunk when `slice_compression_vae` is True.
        verbose (`bool`, *optional*, defaults to `False`):
            Print one line for every module whose `padding_flag` is changed.

    Note:
        `down_block_types` and `ch_mult` must have the same length. With the default arguments they do
        not (1 vs. 4), so callers have to pass a matching pair.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 8,
        down_block_types=("SpatialDownBlock3D",),
        ch: int = 128,
        ch_mult=(1, 2, 4, 4),
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        double_z: bool = True,
        slice_compression_vae: bool = False,
        mini_batch_encoder: int = 9,
        verbose: bool = False,
    ):
        super().__init__()
        block_out_channels = [ch * mult for mult in ch_mult]
        assert len(down_block_types) == len(block_out_channels), (
            "Number of down block types must match number of block output channels."
        )
        if use_gc_blocks is None:
            use_gc_blocks = [False] * len(down_block_types)
        else:
            assert len(use_gc_blocks) == len(down_block_types), (
                "Number of GC blocks must match number of down block types."
            )

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)

        # Each down block maps the previous block's width to its own; the first one keeps the width
        # produced by conv_in. Only the last block skips downsampling.
        self.down_blocks = nn.ModuleList([])
        num_blocks = len(block_out_channels)
        output_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channels = output_channels
            output_channels = block_out_channels[i]
            self.down_blocks.append(
                get_down_block(
                    down_block_type,
                    in_channels=input_channels,
                    out_channels=output_channels,
                    num_layers=layers_per_block,
                    act_fn=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=1e-6,
                    num_attention_heads=num_attention_heads,
                    add_gc_block=use_gc_blocks[i],
                    add_downsample=i != num_blocks - 1,
                )
            )

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

        self.slice_compression_vae = slice_compression_vae
        self.mini_batch_encoder = mini_batch_encoder
        # Context sharing between chunks is disabled for the encoder: single_forward only
        # concatenates neighbouring frames when this flag is True.
        self.features_share = False
        self.verbose = verbose

    def _set_padding_flag(self, flag: int) -> None:
        """Recursively set `padding_flag = flag` on every descendant module that defines it.

        Only the children of `self` (and their descendants) are visited, not `self` itself.
        """
        def _visit(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = flag
            for sub_name, sub_mod in module.named_children():
                _visit(sub_name, sub_mod)

        for name, module in self.named_children():
            _visit(name, module)

    def set_padding_one_frame(self):
        """Set `padding_flag = 1` on all descendant modules that have that attribute."""
        self._set_padding_flag(1)

    def set_padding_more_frame(self):
        """Set `padding_flag = 2` on all descendant modules that have that attribute."""
        self._set_padding_flag(2)

    def single_forward(
        self,
        x: torch.Tensor,
        previous_features: Optional[torch.Tensor] = None,
        after_features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the full encoder stack on one chunk `x` (B, C, T, H, W).

        When `features_share` is True, the given neighbouring frames are concatenated along dim 2
        before encoding and the output is cropped afterwards. Otherwise they are ignored.
        """
        use_prev = self.features_share and previous_features is not None
        use_after = self.features_share and after_features is not None

        if use_prev and use_after:
            x = torch.cat([previous_features, x, after_features], 2)
        elif use_prev:
            x = torch.cat([previous_features, x], 2)
        elif use_after:
            x = torch.cat([x, after_features], 2)

        x = self.conv_in(x)
        for down_block in self.down_blocks:
            x = down_block(x)
        x = self.mid_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)

        # NOTE(review): these fixed crops presumably assume a specific temporal compression of the
        # context frames -- confirm before enabling `features_share` on the encoder.
        if use_prev and use_after:
            x = x[:, :, 1:3]
        elif use_prev:
            x = x[:, :, 1:]
        elif use_after:
            x = x[:, :, :2]
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` (B, C, T, H, W), optionally in temporal chunks.

        Without `slice_compression_vae`, the whole clip goes through `single_forward` once and padding
        flags are left untouched. Otherwise:

        - If T is odd, the first frame is encoded alone with `padding_flag == 1`.
        - The remaining frames are encoded in chunks of `mini_batch_encoder` frames with
          `padding_flag == 2`.
        - The chunk outputs are concatenated along dim 2.
        """
        if not self.slice_compression_vae:
            return self.single_forward(x, None, None)

        num_frames = x.shape[2]
        chunk = self.mini_batch_encoder
        outputs = []
        start_index = 0
        if num_frames % 2 != 0:
            self.set_padding_one_frame()
            outputs.append(self.single_forward(x[:, :, 0:1, :, :], None, None))
            start_index = 1
        self.set_padding_more_frame()

        previous_features = None
        for i in range(start_index, num_frames, chunk):
            end = i + chunk
            # Up to 4 frames after / before the chunk are passed as context; single_forward only
            # uses them when `features_share` is True.
            after_features = x[:, :, end:end + 4, :, :] if end < num_frames else None
            outputs.append(self.single_forward(x[:, :, i:end, :, :], previous_features, after_features))
            previous_features = x[:, :, end - 4:end, :, :]
        return torch.cat(outputs, dim=2)


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Pipeline: `conv_in` -> mid block -> up blocks -> GroupNorm -> activation -> `conv_out`.

    Args:
        in_channels (`int`, *optional*, defaults to 8):
            The number of latent channels fed to `conv_in`.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels produced by `conv_out`.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("SpatialUpBlock3D",)`):
            The types of up blocks to use (resolved with `get_up_block`), one per entry of `ch_mult`.
        ch (`int`, *optional*, defaults to 128):
            Base channel width. Block widths are `ch * ch_mult[i]`, traversed in reverse order by the up blocks.
        ch_mult (`Sequence[int]`, *optional*, defaults to `(1, 2, 4, 4)`):
            Per-block channel multipliers applied to `ch`.
        use_gc_blocks (`Tuple[bool, ...]`, *optional*, defaults to `None`):
            Whether to add a global context block to each up block. `None` disables it for every block.
        mid_block_type (`str`, *optional*, defaults to `"MidBlock3D"`):
            The type of mid block to use (resolved with `get_mid_block`).
        mid_block_use_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
        mid_block_attention_type (`str`, *optional*, defaults to `"3d"`):
            Forwarded to the mid block as `attention_type`.
        mid_block_num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads of the mid block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers in the mid block. Each up block gets `layers_per_block + 1` layers.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for group normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function name, resolved with the local `get_activation` helper.
        num_attention_heads (`int`, *optional*, defaults to 1):
            The number of attention heads forwarded to each up block.
        slice_compression_vae (`bool`, *optional*, defaults to `False`):
            If True, `forward` decodes the latent in temporal chunks instead of in a single pass.
        mini_batch_decoder (`int`, *optional*, defaults to 3):
            The number of latent frames per chunk when `slice_compression_vae` is True.
        verbose (`bool`, *optional*, defaults to `False`):
            Print one line for every module whose `padding_flag` is changed.

    Note:
        `up_block_types` and `ch_mult` must have the same length. With the default arguments they do
        not (1 vs. 4), so callers have to pass a matching pair.
    """

    def __init__(
        self,
        in_channels: int = 8,
        out_channels: int = 3,
        up_block_types=("SpatialUpBlock3D",),
        ch: int = 128,
        ch_mult=(1, 2, 4, 4),
        use_gc_blocks=None,
        mid_block_type: str = "MidBlock3D",
        mid_block_use_attention: bool = True,
        mid_block_attention_type: str = "3d",
        mid_block_num_attention_heads: int = 1,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        num_attention_heads: int = 1,
        slice_compression_vae: bool = False,
        mini_batch_decoder: int = 3,
        verbose: bool = False,
    ):
        super().__init__()
        block_out_channels = [ch * mult for mult in ch_mult]
        assert len(up_block_types) == len(block_out_channels), (
            "Number of up block types must match number of block output channels."
        )
        if use_gc_blocks is None:
            use_gc_blocks = [False] * len(up_block_types)
        else:
            assert len(use_gc_blocks) == len(up_block_types), (
                "Number of GC blocks must match number of up block types."
            )

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3)

        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            num_layers=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=1e-6,
            add_attention=mid_block_use_attention,
            attention_type=mid_block_attention_type,
            num_attention_heads=mid_block_num_attention_heads,
        )

        # Up blocks walk the block widths from widest to narrowest; only the last block skips upsampling.
        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        num_blocks = len(block_out_channels)
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            input_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            self.up_blocks.append(
                get_up_block(
                    up_block_type,
                    in_channels=input_channels,
                    out_channels=output_channels,
                    num_layers=layers_per_block + 1,
                    act_fn=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=1e-6,
                    num_attention_heads=num_attention_heads,
                    add_gc_block=use_gc_blocks[i],
                    add_upsample=i != num_blocks - 1,
                )
            )

        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0],
            num_groups=norm_num_groups,
            eps=1e-6,
        )
        self.conv_act = get_activation(act_fn)
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.slice_compression_vae = slice_compression_vae
        self.mini_batch_decoder = mini_batch_decoder
        # Neighbouring latent chunks are concatenated as context for conv_in and the mid block
        # (see single_forward).
        self.features_share = True
        self.verbose = verbose

    def _set_padding_flag(self, flag: int) -> None:
        """Recursively set `padding_flag = flag` on every descendant module that defines it.

        Only the children of `self` (and their descendants) are visited, not `self` itself.
        """
        def _visit(name, module):
            if hasattr(module, 'padding_flag'):
                if self.verbose:
                    print('Set pad mode for module[%s] type=%s' % (name, str(type(module))))
                module.padding_flag = flag
            for sub_name, sub_mod in module.named_children():
                _visit(sub_name, sub_mod)

        for name, module in self.named_children():
            _visit(name, module)

    def set_padding_one_frame(self):
        """Set `padding_flag = 1` on all descendant modules that have that attribute."""
        self._set_padding_flag(1)

    def set_padding_more_frame(self):
        """Set `padding_flag = 2` on all descendant modules that have that attribute."""
        self._set_padding_flag(2)

    def single_forward(
        self,
        x: torch.Tensor,
        previous_features: Optional[torch.Tensor] = None,
        after_features: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode one latent chunk `x` (B, C, T, H, W).

        When `features_share` is True, the given neighbouring chunks are concatenated along dim 2 before
        `conv_in` and the mid block. The mid block output is then cropped back to the frames of `x`:

        - previous only: the last T frames.
        - after only: the first T frames.
        - both: frames `[T_prev, T_prev + T)`.

        The up blocks and output head then run on the cropped tensor.
        """
        use_prev = self.features_share and previous_features is not None
        use_after = self.features_share and after_features is not None

        if use_prev or use_after:
            num_frames = x.shape[2]
            if use_prev and use_after:
                num_prev = previous_features.shape[2]
                x = torch.cat([previous_features, x, after_features], 2)
            elif use_prev:
                x = torch.cat([previous_features, x], 2)
            else:
                x = torch.cat([x, after_features], 2)

        x = self.conv_in(x)
        x = self.mid_block(x)

        if use_prev and use_after:
            x = x[:, :, num_prev:num_prev + num_frames]
        elif use_prev:
            x = x[:, :, -num_frames:]
        elif use_after:
            x = x[:, :, :num_frames]

        for up_block in self.up_blocks:
            x = up_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latent `x` (B, C, T, H, W), optionally in temporal chunks.

        Without `slice_compression_vae`, the whole latent goes through `single_forward` once and padding
        flags are left untouched. Otherwise:

        - If T is odd, the first frame is decoded alone with `padding_flag == 1`.
        - The remaining frames are decoded in chunks of `mini_batch_decoder` frames with
          `padding_flag == 2`.
        - Each chunk receives the previous chunk and the next `mini_batch_decoder` frames as context.
        - The chunk outputs are concatenated along dim 2.
        """
        if not self.slice_compression_vae:
            return self.single_forward(x, None, None)

        num_frames = x.shape[2]
        chunk = self.mini_batch_decoder
        outputs = []
        start_index = 0
        if num_frames % 2 != 0:
            self.set_padding_one_frame()
            outputs.append(self.single_forward(x[:, :, 0:1, :, :], None, None))
            start_index = 1
        self.set_padding_more_frame()

        previous_features = None
        for i in range(start_index, num_frames, chunk):
            end = i + chunk
            after_features = x[:, :, end:end + chunk, :, :] if end < num_frames else None
            current = x[:, :, i:end, :, :]
            outputs.append(self.single_forward(current, previous_features, after_features))
            previous_features = current
        return torch.cat(outputs, dim=2)