import torch
import torch.nn as nn

from .attention import Attention3D, SpatialAttention, TemporalAttention
from .common import ResidualBlock3D


def get_mid_block(
    mid_block_type: str,
    in_channels: int,
    num_layers: int,
    act_fn: str,
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    dropout: float = 0.0,
    add_attention: bool = True,
    attention_type: str = "3d",
    num_attention_heads: int = 1,
    output_scale_factor: float = 1.0,
) -> nn.Module:
    if mid_block_type == "MidBlock3D":
        return MidBlock3D(
            in_channels=in_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            add_attention=add_attention,
            attention_type=attention_type,
            attention_head_dim=in_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
        )
    else:
        raise ValueError(f"Unknown mid block type: {mid_block_type}")


class MidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`MidBlock3D`] with multiple residual blocks and optional attention blocks.

    Args:
        in_channels (`int`): The number of input channels.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function for the resnet blocks.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        norm_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_type (`str`, *optional*, defaults to `"3d"`):
            The type of attention to use: `"3d"`, `"spatial_temporal"`, `"spatial"`, or `"temporal"`.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is determined based on this value and
            the number of input channels.
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, temporal_length, height, width)`.
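
    Example (a minimal sketch; the shape convention follows the `Returns` section above, and the attention and
    residual modules imported at the top of this file are assumed to be importable):

        >>> block = MidBlock3D(in_channels=64, num_layers=2, attention_type="3d", attention_head_dim=16)
        >>> sample = torch.randn(2, 64, 8, 32, 32)  # (batch, channels, frames, height, width)
        >>> block(sample).shape
        torch.Size([2, 64, 8, 32, 32])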
""" def __init__( self, in_channels: int, num_layers: int = 1, act_fn: str = "silu", norm_num_groups: int = 32, norm_eps: float = 1e-6, dropout: float = 0.0, add_attention: bool = True, attention_type: str = "3d", attention_head_dim: int = 1, output_scale_factor: float = 1.0, ): super().__init__() self.attention_type = attention_type norm_num_groups = norm_num_groups if norm_num_groups is not None else min(in_channels // 4, 32) self.convs = nn.ModuleList([ ResidualBlock3D( in_channels=in_channels, out_channels=in_channels, non_linearity=act_fn, norm_num_groups=norm_num_groups, norm_eps=norm_eps, dropout=dropout, output_scale_factor=output_scale_factor, ) ]) self.attentions = nn.ModuleList([]) for _ in range(num_layers - 1): if add_attention: if attention_type == "3d": self.attentions.append( Attention3D( in_channels, nheads=in_channels // attention_head_dim, head_dim=attention_head_dim, bias=True, upcast_softmax=True, norm_num_groups=norm_num_groups, eps=norm_eps, rescale_output_factor=output_scale_factor, residual_connection=True, ) ) elif attention_type == "spatial_temporal": self.attentions.append( nn.ModuleList([ SpatialAttention( in_channels, nheads=in_channels // attention_head_dim, head_dim=attention_head_dim, bias=True, upcast_softmax=True, norm_num_groups=norm_num_groups, eps=norm_eps, rescale_output_factor=output_scale_factor, residual_connection=True, ), TemporalAttention( in_channels, nheads=in_channels // attention_head_dim, head_dim=attention_head_dim, bias=True, upcast_softmax=True, norm_num_groups=norm_num_groups, eps=norm_eps, rescale_output_factor=output_scale_factor, residual_connection=True, ), ]) ) elif attention_type == "spatial": self.attentions.append( SpatialAttention( in_channels, nheads=in_channels // attention_head_dim, head_dim=attention_head_dim, bias=True, upcast_softmax=True, norm_num_groups=norm_num_groups, eps=norm_eps, rescale_output_factor=output_scale_factor, residual_connection=True, ) ) elif attention_type == "temporal": self.attentions.append( TemporalAttention( in_channels, nheads=in_channels // attention_head_dim, head_dim=attention_head_dim, bias=True, upcast_softmax=True, norm_num_groups=norm_num_groups, eps=norm_eps, rescale_output_factor=output_scale_factor, residual_connection=True, ) ) else: raise ValueError(f"Unknown attention type: {attention_type}") else: self.attentions.append(None) self.convs.append( ResidualBlock3D( in_channels=in_channels, out_channels=in_channels, non_linearity=act_fn, norm_num_groups=norm_num_groups, norm_eps=norm_eps, dropout=dropout, output_scale_factor=output_scale_factor, ) ) def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.convs[0](hidden_states) for attn, resnet in zip(self.attentions, self.convs[1:]): if attn is not None: if self.attention_type == "spatial_temporal": spatial_attn, temporal_attn = attn hidden_states = spatial_attn(hidden_states) hidden_states = temporal_attn(hidden_states) else: hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states) return hidden_states