Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from .attention import SpatialAttention, TemporalAttention | |
from .common import ResidualBlock3D | |
from .gc_block import GlobalContextBlock | |
from .upsamplers import (SpatialTemporalUpsampler3D, SpatialUpsampler3D, | |
TemporalUpsampler3D) | |
def get_up_block( | |
up_block_type: str, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int, | |
act_fn: str, | |
norm_num_groups: int = 32, | |
norm_eps: float = 1e-6, | |
dropout: float = 0.0, | |
num_attention_heads: int = 1, | |
output_scale_factor: float = 1.0, | |
add_gc_block: bool = False, | |
add_upsample: bool = True, | |
) -> nn.Module: | |
if up_block_type == "SpatialUpBlock3D": | |
return SpatialUpBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
num_layers=num_layers, | |
act_fn=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
add_gc_block=add_gc_block, | |
add_upsample=add_upsample, | |
) | |
elif up_block_type == "SpatialAttnUpBlock3D": | |
return SpatialAttnUpBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
num_layers=num_layers, | |
act_fn=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
attention_head_dim=out_channels // num_attention_heads, | |
output_scale_factor=output_scale_factor, | |
add_gc_block=add_gc_block, | |
add_upsample=add_upsample, | |
) | |
elif up_block_type == "TemporalUpBlock3D": | |
return TemporalUpBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
num_layers=num_layers, | |
act_fn=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
add_gc_block=add_gc_block, | |
add_upsample=add_upsample, | |
) | |
elif up_block_type == "TemporalAttnUpBlock3D": | |
return TemporalAttnUpBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
num_layers=num_layers, | |
act_fn=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
attention_head_dim=out_channels // num_attention_heads, | |
output_scale_factor=output_scale_factor, | |
add_gc_block=add_gc_block, | |
add_upsample=add_upsample, | |
) | |
elif up_block_type == "SpatialTemporalUpBlock3D": | |
return SpatialTemporalUpBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
num_layers=num_layers, | |
act_fn=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
add_gc_block=add_gc_block, | |
add_upsample=add_upsample, | |
) | |
else: | |
raise ValueError(f"Unknown up block type: {up_block_type}") | |
class SpatialUpBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int = 1, | |
act_fn: str = "silu", | |
norm_num_groups: int = 32, | |
norm_eps: float = 1e-6, | |
dropout: float = 0.0, | |
output_scale_factor: float = 1.0, | |
add_gc_block: bool = False, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
if add_upsample: | |
self.upsampler = SpatialUpsampler3D(in_channels, in_channels) | |
else: | |
self.upsampler = None | |
if add_gc_block: | |
self.gc_block = GlobalContextBlock(in_channels, in_channels, fusion_type="mul") | |
else: | |
self.gc_block = None | |
self.convs = nn.ModuleList([]) | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
self.convs.append( | |
ResidualBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
non_linearity=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
) | |
) | |
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: | |
for conv in self.convs: | |
x = conv(x) | |
if self.gc_block is not None: | |
x = self.gc_block(x) | |
if self.upsampler is not None: | |
x = self.upsampler(x) | |
return x | |
class SpatialAttnUpBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int = 1, | |
act_fn: str = "silu", | |
norm_num_groups: int = 32, | |
norm_eps: float = 1e-6, | |
dropout: float = 0.0, | |
attention_head_dim: int = 1, | |
output_scale_factor: float = 1.0, | |
add_gc_block: bool = False, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
self.convs = nn.ModuleList([]) | |
self.attentions = nn.ModuleList([]) | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
self.convs.append( | |
ResidualBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
non_linearity=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
) | |
) | |
self.attentions.append( | |
SpatialAttention( | |
out_channels, | |
nheads=out_channels // attention_head_dim, | |
head_dim=attention_head_dim, | |
bias=True, | |
upcast_softmax=True, | |
norm_num_groups=norm_num_groups, | |
eps=norm_eps, | |
rescale_output_factor=output_scale_factor, | |
residual_connection=True, | |
) | |
) | |
if add_gc_block: | |
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul") | |
else: | |
self.gc_block = None | |
if add_upsample: | |
self.upsampler = SpatialUpsampler3D(out_channels, out_channels) | |
else: | |
self.upsampler = None | |
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: | |
for conv, attn in zip(self.convs, self.attentions): | |
x = conv(x) | |
x = attn(x) | |
if self.gc_block is not None: | |
x = self.gc_block(x) | |
if self.upsampler is not None: | |
x = self.upsampler(x) | |
return x | |
class TemporalUpBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int = 1, | |
act_fn: str = "silu", | |
norm_num_groups: int = 32, | |
norm_eps: float = 1e-6, | |
dropout: float = 0.0, | |
output_scale_factor: float = 1.0, | |
add_gc_block: bool = False, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
self.convs = nn.ModuleList([]) | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
self.convs.append( | |
ResidualBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
non_linearity=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
) | |
) | |
if add_gc_block: | |
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul") | |
else: | |
self.gc_block = None | |
if add_upsample: | |
self.upsampler = TemporalUpsampler3D(out_channels, out_channels) | |
else: | |
self.upsampler = None | |
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: | |
for conv in self.convs: | |
x = conv(x) | |
if self.gc_block is not None: | |
x = self.gc_block(x) | |
if self.upsampler is not None: | |
x = self.upsampler(x) | |
return x | |
class TemporalAttnUpBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int = 1, | |
act_fn: str = "silu", | |
norm_num_groups: int = 32, | |
norm_eps: float = 1e-6, | |
dropout: float = 0.0, | |
attention_head_dim: int = 1, | |
output_scale_factor: float = 1.0, | |
add_gc_block: bool = False, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
self.convs = nn.ModuleList([]) | |
self.attentions = nn.ModuleList([]) | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
self.convs.append( | |
ResidualBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
non_linearity=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
) | |
) | |
self.attentions.append( | |
TemporalAttention( | |
out_channels, | |
nheads=out_channels // attention_head_dim, | |
head_dim=attention_head_dim, | |
bias=True, | |
upcast_softmax=True, | |
norm_num_groups=norm_num_groups, | |
eps=norm_eps, | |
rescale_output_factor=output_scale_factor, | |
residual_connection=True, | |
) | |
) | |
if add_gc_block: | |
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul") | |
else: | |
self.gc_block = None | |
if add_upsample: | |
self.upsampler = TemporalUpsampler3D(out_channels, out_channels) | |
else: | |
self.upsampler = None | |
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: | |
for conv, attn in zip(self.convs, self.attentions): | |
x = conv(x) | |
x = attn(x) | |
if self.gc_block is not None: | |
x = self.gc_block(x) | |
if self.upsampler is not None: | |
x = self.upsampler(x) | |
return x | |
class SpatialTemporalUpBlock3D(nn.Module): | |
def __init__( | |
self, | |
in_channels: int, | |
out_channels: int, | |
num_layers: int = 1, | |
act_fn: str = "silu", | |
norm_num_groups: int = 32, | |
norm_eps: float = 1e-6, | |
dropout: float = 0.0, | |
output_scale_factor: float = 1.0, | |
add_gc_block: bool = False, | |
add_upsample: bool = True, | |
): | |
super().__init__() | |
self.convs = nn.ModuleList([]) | |
for i in range(num_layers): | |
in_channels = in_channels if i == 0 else out_channels | |
self.convs.append( | |
ResidualBlock3D( | |
in_channels=in_channels, | |
out_channels=out_channels, | |
non_linearity=act_fn, | |
norm_num_groups=norm_num_groups, | |
norm_eps=norm_eps, | |
dropout=dropout, | |
output_scale_factor=output_scale_factor, | |
) | |
) | |
if add_gc_block: | |
self.gc_block = GlobalContextBlock(out_channels, out_channels, fusion_type="mul") | |
else: | |
self.gc_block = None | |
if add_upsample: | |
self.upsampler = SpatialTemporalUpsampler3D(out_channels, out_channels) | |
else: | |
self.upsampler = None | |
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: | |
for conv in self.convs: | |
x = conv(x) | |
if self.gc_block is not None: | |
x = self.gc_block(x) | |
if self.upsampler is not None: | |
x = self.upsampler(x) | |
return x | |