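"""Down blocks for a 3D convolutional encoder.

Every block stacks ``ResidualBlock3D`` layers and can optionally append a
``GlobalContextBlock`` and a spatial and/or temporal downsampler; the
attention variants interleave spatial or temporal self-attention between the
residual blocks. ``get_down_block`` is the factory that selects a block class
by its string name.
"""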
import torch
import torch.nn as nn

from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .downsamplers import (SpatialDownsampler3D, SpatialTemporalDownsampler3D,
                           TemporalDownsampler3D)
from .gc_block import GlobalContextBlock

def get_down_block(
    down_block_type: str,
    in_channels: int,
    out_channels: int,
    num_layers: int,
    act_fn: str,
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    dropout: float = 0.0,
    num_attention_heads: int = 1,
    output_scale_factor: float = 1.0,
    add_gc_block: bool = False,
    add_downsample: bool = True,
) -> nn.Module:
    """Instantiate and return the down block matching ``down_block_type``.

    ``num_attention_heads`` is only used by the attention variants, where the
    per-head dimension is derived as ``out_channels // num_attention_heads``.
    """
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
        )
    elif down_block_type == "SpatialDownBlock3D":
        return SpatialDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "SpatialAttnDownBlock3D":
        return SpatialAttnDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            attention_head_dim=out_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "TemporalDownBlock3D":
        return TemporalDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "TemporalAttnDownBlock3D":
        return TemporalAttnDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            attention_head_dim=out_channels // num_attention_heads,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    elif down_block_type == "SpatialTemporalDownBlock3D":
        return SpatialTemporalDownBlock3D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            dropout=dropout,
            output_scale_factor=output_scale_factor,
            add_gc_block=add_gc_block,
            add_downsample=add_downsample,
        )
    else:
        raise ValueError(f"Unknown down block type: {down_block_type}")
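
# Example of driving the factory (an illustrative sketch only; the channel
# widths, layer count, and head count below are made up, not taken from any
# config in this file):
#
#     block = get_down_block(
#         "SpatialAttnDownBlock3D",
#         in_channels=128,
#         out_channels=256,
#         num_layers=2,
#         act_fn="silu",
#         num_attention_heads=8,  # per-head dim = 256 // 8 = 32
#     )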

class DownBlock3D(nn.Module):
    """A stack of ``ResidualBlock3D`` layers with an optional global-context
    block; performs no spatial or temporal downsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            # Only the first residual block changes the channel count.
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        self.spatial_downsample_factor = 1
        self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        return x

class SpatialDownBlock3D(nn.Module):
    """Residual stack with an optional global-context block, followed by a
    2x spatial downsample when ``add_downsample`` is set."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            # Only the first residual block changes the channel count.
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
            self.spatial_downsample_factor = 2
        else:
            self.downsampler = None
            self.spatial_downsample_factor = 1

        self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x

class TemporalDownBlock3D(nn.Module):
    """Residual stack with an optional global-context block, followed by a
    2x temporal downsample when ``add_downsample`` is set."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            # Only the first residual block changes the channel count.
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
            self.temporal_downsample_factor = 2
        else:
            self.downsampler = None
            self.temporal_downsample_factor = 1

        self.spatial_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x

class SpatialTemporalDownBlock3D(nn.Module):
    """Residual stack with an optional global-context block, followed by a
    joint 2x spatial and 2x temporal downsample when ``add_downsample`` is
    set."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            # Only the first residual block changes the channel count.
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = SpatialTemporalDownsampler3D(out_channels, out_channels)
            self.spatial_downsample_factor = 2
            self.temporal_downsample_factor = 2
        else:
            self.downsampler = None
            self.spatial_downsample_factor = 1
            self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x

class SpatialAttnDownBlock3D(nn.Module):
    """Like ``SpatialDownBlock3D``, but interleaves a spatial self-attention
    layer after every residual block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            # Only the first residual block changes the channel count.
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                SpatialAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = SpatialDownsampler3D(out_channels, out_channels)
            self.spatial_downsample_factor = 2
        else:
            self.downsampler = None
            self.spatial_downsample_factor = 1

        self.temporal_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x

class TemporalAttnDownBlock3D(nn.Module):
    """Like ``TemporalDownBlock3D``, but interleaves a temporal self-attention
    layer after every residual block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])
        for i in range(num_layers):
            # Only the first residual block changes the channel count.
            in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            self.attentions.append(
                TemporalAttention(
                    out_channels,
                    nheads=out_channels // attention_head_dim,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )

        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
        else:
            self.gc_block = None

        if add_downsample:
            self.downsampler = TemporalDownsampler3D(out_channels, out_channels)
            self.temporal_downsample_factor = 2
        else:
            self.downsampler = None
            self.temporal_downsample_factor = 1

        self.spatial_downsample_factor = 1

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for conv, attn in zip(self.convs, self.attentions):
            x = conv(x)
            x = attn(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.downsampler is not None:
            x = self.downsampler(x)

        return x
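
# Minimal smoke test: a sketch, not part of the original module. It assumes
# the blocks consume 5D tensors laid out as (batch, channels, frames, height,
# width) and that ``ResidualBlock3D`` preserves the spatiotemporal size.
# Because of the relative imports above, this only runs inside the package,
# e.g. via ``python -m <package>.<module>``.
if __name__ == "__main__":
    block = SpatialDownBlock3D(in_channels=32, out_channels=64, num_layers=2)
    x = torch.randn(1, 32, 4, 16, 16)  # (B, C, T, H, W) -- assumed layout
    y = block(x)
    # Under the assumptions above, the 2x spatial downsample should yield
    # torch.Size([1, 64, 4, 8, 8]).
    print(y.shape)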