# Provenance (appears to be version-control metadata: author / commit message / short hash):
# bubbliiiing
# Create Code
# 19fe404
import torch
import torch.nn as nn
from .attention import SpatialAttention, TemporalAttention
from .common import ResidualBlock3D
from .gc_block import GlobalContextBlock
from .upsamplers import (SpatialTemporalUpsampler3D, SpatialUpsampler3D,
TemporalUpsampler3D)
def get_up_block(
    up_block_type: str,
    in_channels: int,
    out_channels: int,
    num_layers: int,
    act_fn: str,
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    dropout: float = 0.0,
    num_attention_heads: int = 1,
    output_scale_factor: float = 1.0,
    add_gc_block: bool = False,
    add_upsample: bool = True,
) -> nn.Module:
    """Instantiate the up block named by ``up_block_type``.

    Args:
        up_block_type: One of ``"SpatialUpBlock3D"``, ``"SpatialAttnUpBlock3D"``,
            ``"TemporalUpBlock3D"``, ``"TemporalAttnUpBlock3D"`` or
            ``"SpatialTemporalUpBlock3D"``.
        in_channels, out_channels, num_layers, act_fn, norm_num_groups,
        norm_eps, dropout, output_scale_factor, add_gc_block, add_upsample:
            Forwarded unchanged to the selected block's constructor.
        num_attention_heads: Only consulted for the two attention variants,
            which receive ``attention_head_dim = out_channels // num_attention_heads``.

    Returns:
        The constructed block module.

    Raises:
        ValueError: If ``up_block_type`` is not one of the supported names.
    """
    # Keyword arguments accepted by every block type.
    block_kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "num_layers": num_layers,
        "act_fn": act_fn,
        "norm_num_groups": norm_num_groups,
        "norm_eps": norm_eps,
        "dropout": dropout,
        "output_scale_factor": output_scale_factor,
        "add_gc_block": add_gc_block,
        "add_upsample": add_upsample,
    }

    if up_block_type == "SpatialUpBlock3D":
        block_cls = SpatialUpBlock3D
    elif up_block_type == "SpatialAttnUpBlock3D":
        block_cls = SpatialAttnUpBlock3D
        # Head width is derived from the requested head count.
        block_kwargs["attention_head_dim"] = out_channels // num_attention_heads
    elif up_block_type == "TemporalUpBlock3D":
        block_cls = TemporalUpBlock3D
    elif up_block_type == "TemporalAttnUpBlock3D":
        block_cls = TemporalAttnUpBlock3D
        block_kwargs["attention_head_dim"] = out_channels // num_attention_heads
    elif up_block_type == "SpatialTemporalUpBlock3D":
        block_cls = SpatialTemporalUpBlock3D
    else:
        raise ValueError(f"Unknown up block type: {up_block_type}")

    return block_cls(**block_kwargs)
class SpatialUpBlock3D(nn.Module):
    """Residual stack followed by optional global context and spatial upsampling.

    ``forward`` applies, in order: every ``ResidualBlock3D`` in ``self.convs``,
    then ``self.gc_block`` (when enabled), then ``self.upsampler`` (when enabled).

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the residual stack.
        num_layers: Number of ``ResidualBlock3D`` layers. The first maps
            ``in_channels -> out_channels``; the rest map ``out_channels -> out_channels``.
        act_fn: Non-linearity name forwarded to ``ResidualBlock3D``.
        norm_num_groups: Group count forwarded to ``ResidualBlock3D``.
        norm_eps: Normalization epsilon forwarded to ``ResidualBlock3D``.
        dropout: Dropout probability forwarded to ``ResidualBlock3D``.
        output_scale_factor: Output scale forwarded to ``ResidualBlock3D``.
        add_gc_block: Build a ``GlobalContextBlock`` (``fusion_type="mul"``).
        add_upsample: Build a ``SpatialUpsampler3D`` (presumably upsamples the
            spatial axes only -- confirm in ``upsamplers``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()

        # gc_block and upsampler run *after* the residual stack in forward(),
        # so they must be sized for the stack's output width. Sizing them with
        # in_channels breaks whenever in_channels != out_channels. With zero
        # residual layers the input passes straight through, so in_channels is
        # the width that reaches them in that case.
        post_channels = out_channels if num_layers > 0 else in_channels

        # Built ahead of the residual stack to keep module registration (and,
        # presumably, seeded parameter-initialization) order stable.
        if add_upsample:
            self.upsampler = SpatialUpsampler3D(post_channels, post_channels)
        else:
            self.upsampler = None

        if add_gc_block:
            self.gc_block = GlobalContextBlock(post_channels, post_channels, fusion_type="mul")
        else:
            self.gc_block = None

        self.convs = nn.ModuleList([])
        for i in range(num_layers):
            # Only the first residual block consumes the block's input width.
            conv_in_channels = in_channels if i == 0 else out_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=conv_in_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Run the residual stack, then the optional gc block and upsampler.

        ``x`` is presumably laid out as (batch, channels, time, height, width)
        -- TODO confirm against ``ResidualBlock3D``.
        """
        for conv in self.convs:
            x = conv(x)

        if self.gc_block is not None:
            x = self.gc_block(x)

        if self.upsampler is not None:
            x = self.upsampler(x)

        return x
class SpatialAttnUpBlock3D(nn.Module):
    """Interleaved residual / spatial-attention stack with optional gc block and upsampler.

    Each of the ``num_layers`` stages is a ``ResidualBlock3D`` immediately
    followed by a ``SpatialAttention`` over ``out_channels`` channels. After the
    stack, ``forward`` applies ``self.gc_block`` and then ``self.upsampler`` when
    they were enabled.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by every stage.
        num_layers: Number of residual + attention stages.
        act_fn: Non-linearity name forwarded to ``ResidualBlock3D``.
        norm_num_groups: Group count for both residual and attention layers.
        norm_eps: Normalization epsilon for both residual and attention layers.
        dropout: Dropout probability forwarded to ``ResidualBlock3D``.
        attention_head_dim: Per-head width; the head count is
            ``out_channels // attention_head_dim``.
        output_scale_factor: Output scale for residual layers and the attention
            ``rescale_output_factor``.
        add_gc_block: Build a ``GlobalContextBlock`` (``fusion_type="mul"``).
        add_upsample: Build a ``SpatialUpsampler3D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])

        for layer_idx in range(num_layers):
            # Only the first stage sees the block's input width.
            stage_in = out_channels if layer_idx > 0 else in_channels
            # Residual layer is constructed before its attention partner,
            # preserving per-stage construction order.
            residual = ResidualBlock3D(
                in_channels=stage_in,
                out_channels=out_channels,
                non_linearity=act_fn,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                dropout=dropout,
                output_scale_factor=output_scale_factor,
            )
            self.convs.append(residual)
            attention = SpatialAttention(
                out_channels,
                nheads=out_channels // attention_head_dim,
                head_dim=attention_head_dim,
                bias=True,
                upcast_softmax=True,
                norm_num_groups=norm_num_groups,
                eps=norm_eps,
                rescale_output_factor=output_scale_factor,
                residual_connection=True,
            )
            self.attentions.append(attention)

        self.gc_block = (
            GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
            if add_gc_block
            else None
        )
        self.upsampler = (
            SpatialUpsampler3D(out_channels, out_channels) if add_upsample else None
        )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply each residual/attention stage, then the optional gc block and upsampler."""
        for residual, attention in zip(self.convs, self.attentions):
            x = attention(residual(x))
        for post_stage in (self.gc_block, self.upsampler):
            if post_stage is not None:
                x = post_stage(x)
        return x
class TemporalUpBlock3D(nn.Module):
    """Residual stack followed by optional global context and temporal upsampling.

    ``forward`` applies every ``ResidualBlock3D`` in ``self.convs``, then
    ``self.gc_block`` and ``self.upsampler`` when they were enabled.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the residual stack.
        num_layers: Number of ``ResidualBlock3D`` layers.
        act_fn: Non-linearity name forwarded to ``ResidualBlock3D``.
        norm_num_groups: Group count forwarded to ``ResidualBlock3D``.
        norm_eps: Normalization epsilon forwarded to ``ResidualBlock3D``.
        dropout: Dropout probability forwarded to ``ResidualBlock3D``.
        output_scale_factor: Output scale forwarded to ``ResidualBlock3D``.
        add_gc_block: Build a ``GlobalContextBlock`` (``fusion_type="mul"``).
        add_upsample: Build a ``TemporalUpsampler3D`` (presumably upsamples the
            time axis -- confirm in ``upsamplers``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        # First layer adapts in_channels -> out_channels; later ones keep the width.
        self.convs = nn.ModuleList(
            [
                ResidualBlock3D(
                    in_channels=in_channels if idx == 0 else out_channels,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
                for idx in range(num_layers)
            ]
        )

        self.gc_block = (
            GlobalContextBlock(out_channels, out_channels, fusion_type="mul")
            if add_gc_block
            else None
        )
        self.upsampler = (
            TemporalUpsampler3D(out_channels, out_channels) if add_upsample else None
        )

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Run the residual stack, then the optional gc block and upsampler."""
        for layer in self.convs:
            x = layer(x)
        for post_stage in (self.gc_block, self.upsampler):
            if post_stage is not None:
                x = post_stage(x)
        return x
class TemporalAttnUpBlock3D(nn.Module):
    """Interleaved residual / temporal-attention stack with optional gc block and upsampler.

    Each of the ``num_layers`` stages is a ``ResidualBlock3D`` immediately
    followed by a ``TemporalAttention`` over ``out_channels`` channels. After
    the stack, ``forward`` applies ``self.gc_block`` and then ``self.upsampler``
    when they were enabled.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by every stage.
        num_layers: Number of residual + attention stages.
        act_fn: Non-linearity name forwarded to ``ResidualBlock3D``.
        norm_num_groups: Group count for both residual and attention layers.
        norm_eps: Normalization epsilon for both residual and attention layers.
        dropout: Dropout probability forwarded to ``ResidualBlock3D``.
        attention_head_dim: Per-head width; the head count is
            ``out_channels // attention_head_dim``.
        output_scale_factor: Output scale for residual layers and the attention
            ``rescale_output_factor``.
        add_gc_block: Build a ``GlobalContextBlock`` (``fusion_type="mul"``).
        add_upsample: Build a ``TemporalUpsampler3D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        self.convs = nn.ModuleList([])
        self.attentions = nn.ModuleList([])

        for layer_idx in range(num_layers):
            # Only the first stage sees the block's input width.
            stage_in = out_channels if layer_idx > 0 else in_channels
            self.convs.append(
                ResidualBlock3D(
                    in_channels=stage_in,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            head_count = out_channels // attention_head_dim
            self.attentions.append(
                TemporalAttention(
                    out_channels,
                    nheads=head_count,
                    head_dim=attention_head_dim,
                    bias=True,
                    upcast_softmax=True,
                    norm_num_groups=norm_num_groups,
                    eps=norm_eps,
                    rescale_output_factor=output_scale_factor,
                    residual_connection=True,
                )
            )

        self.gc_block = None
        if add_gc_block:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")

        self.upsampler = None
        if add_upsample:
            self.upsampler = TemporalUpsampler3D(out_channels, out_channels)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply each residual/attention stage, then the optional gc block and upsampler."""
        for residual, attention in zip(self.convs, self.attentions):
            x = residual(x)
            x = attention(x)
        for post_stage in (self.gc_block, self.upsampler):
            if post_stage is not None:
                x = post_stage(x)
        return x
class SpatialTemporalUpBlock3D(nn.Module):
    """Residual stack followed by optional global context and spatio-temporal upsampling.

    ``forward`` applies every ``ResidualBlock3D`` in ``self.convs``, then
    ``self.gc_block`` and ``self.upsampler`` when they were enabled.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the residual stack.
        num_layers: Number of ``ResidualBlock3D`` layers.
        act_fn: Non-linearity name forwarded to ``ResidualBlock3D``.
        norm_num_groups: Group count forwarded to ``ResidualBlock3D``.
        norm_eps: Normalization epsilon forwarded to ``ResidualBlock3D``.
        dropout: Dropout probability forwarded to ``ResidualBlock3D``.
        output_scale_factor: Output scale forwarded to ``ResidualBlock3D``.
        add_gc_block: Build a ``GlobalContextBlock`` (``fusion_type="mul"``).
        add_upsample: Build a ``SpatialTemporalUpsampler3D`` (presumably
            upsamples both time and space -- confirm in ``upsamplers``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_gc_block: bool = False,
        add_upsample: bool = True,
    ):
        super().__init__()
        residual_layers = []
        current_width = in_channels
        for _ in range(num_layers):
            residual_layers.append(
                ResidualBlock3D(
                    in_channels=current_width,
                    out_channels=out_channels,
                    non_linearity=act_fn,
                    norm_num_groups=norm_num_groups,
                    norm_eps=norm_eps,
                    dropout=dropout,
                    output_scale_factor=output_scale_factor,
                )
            )
            # After the first layer every subsequent one is out -> out.
            current_width = out_channels
        self.convs = nn.ModuleList(residual_layers)

        if not add_gc_block:
            self.gc_block = None
        else:
            self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul")

        if not add_upsample:
            self.upsampler = None
        else:
            self.upsampler = SpatialTemporalUpsampler3D(out_channels, out_channels)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Run the residual stack, then the optional gc block and upsampler."""
        out = x
        for layer in self.convs:
            out = layer(out)
        if self.gc_block is not None:
            out = self.gc_block(out)
        if self.upsampler is not None:
            out = self.upsampler(out)
        return out