# Origin: commit 19fe404 ("Create Code") by bubbliiiing
import math
import torch
import torch.nn as nn
from .downsamplers import BlurPooling2D, BlurPooling3D
class DiscriminatorBlock2D(nn.Module):
    """Pre-activation residual block used by the 2D discriminator.

    Main path:
        BatchNorm2d -> LeakyReLU(0.2) -> 3x3 conv -> BatchNorm2d ->
        LeakyReLU(0.2) -> dropout -> optional BlurPooling2D -> 3x3 conv.

    The shortcut path is added to the main path, and the sum is divided by
    ``output_scale_factor``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        dropout: Dropout probability, applied just before the optional
            downsampler.
        output_scale_factor: Divisor applied to the residual sum.
        add_downsample: If True, both the main path and the shortcut pass
            through ``BlurPooling2D`` (imported from ``.downsamplers``), so the
            two stay at the same resolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.BatchNorm2d(in_channels)
        self.nonlinearity = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.downsampler = BlurPooling2D(out_channels, out_channels)
        else:
            self.downsampler = nn.Identity()

        self.norm2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.shortcut = nn.Sequential(
                BlurPooling2D(in_channels, in_channels),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        elif in_channels != out_channels:
            # An Identity shortcut cannot be added to the out_channels-wide
            # main path. Project the channels instead. This case occurs for the
            # final (non-downsampling) block of Discriminator2D whenever the
            # last two block_out_channels entries differ.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        else:
            self.shortcut = nn.Identity()

        # NOTE(review): these are set unconditionally, even when
        # add_downsample is False; confirm how consumers use them.
        self.spatial_downsample_factor = 2
        self.temporal_downsample_factor = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block.

        Args:
            x: 4D tensor (B, C, H, W) with C == in_channels. BatchNorm2d
                requires 4D input.

        Returns:
            (main_path(x) + shortcut(x)) / output_scale_factor, with
            out_channels channels.
        """
        shortcut = self.shortcut(x)

        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.downsampler(x)
        x = self.conv2(x)

        return (x + shortcut) / self.output_scale_factor
class Discriminator2D(nn.Module):
    """Image discriminator built from a stack of DiscriminatorBlock2D stages.

    Pipeline:
        3x3 ``conv_in`` ->
        one DiscriminatorBlock2D per entry of ``block_out_channels`` ->
        3x3 ``conv_out`` producing a single channel.

    Every stage except the last is built with ``add_downsample=True``. Every
    stage divides its residual sum by sqrt(2).

    Args:
        in_channels: Channels of the input tensor.
        block_out_channels: Output width of each stage. The first entry is
            also the width produced by ``conv_in``. An empty sequence fails
            with IndexError.

    NOTE(review): ``conv_norm_out`` and ``conv_act`` are registered as
    submodules (so they appear in ``state_dict``/``parameters()``), but
    ``forward`` never applies them. Confirm whether that is intended.
    """

    def __init__(
        self,
        in_channels: int = 3,
        block_out_channels=(64,),
    ):
        super().__init__()

        widths = list(block_out_channels)
        self.conv_in = nn.Conv2d(in_channels, widths[0], kernel_size=3, padding=1)

        # Stage k reads the width written by stage k-1. The first stage reads
        # conv_in's output, which already has widths[0] channels.
        stage_inputs = widths[:1] + widths[:-1]
        last = len(widths) - 1
        self.blocks = nn.ModuleList(
            [
                DiscriminatorBlock2D(
                    in_channels=c_in,
                    out_channels=c_out,
                    output_scale_factor=math.sqrt(2),
                    add_downsample=idx != last,
                )
                for idx, (c_in, c_out) in enumerate(zip(stage_inputs, widths))
            ]
        )

        self.conv_norm_out = nn.BatchNorm2d(widths[-1])
        self.conv_act = nn.LeakyReLU(0.2)
        self.conv_out = nn.Conv2d(widths[-1], 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score ``x``, expected as (B, C, H, W) with C == in_channels.

        Returns the single-channel output of ``conv_out``.
        """
        h = self.conv_in(x)
        for stage in self.blocks:
            h = stage(h)
        return self.conv_out(h)
class DiscriminatorBlock3D(nn.Module):
    """Pre-activation residual block used by the 3D discriminator.

    Main path:
        GroupNorm -> LeakyReLU(0.2) -> 3x3x3 conv -> GroupNorm ->
        LeakyReLU(0.2) -> dropout -> optional BlurPooling3D -> 3x3x3 conv.

    The shortcut is always a learned 1x1x1 projection, preceded by
    BlurPooling3D when downsampling. The residual sum is divided by
    ``output_scale_factor``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        dropout: Dropout probability, applied just before the optional
            downsampler.
        output_scale_factor: Divisor applied to the residual sum.
        add_downsample: If True, both paths pass through ``BlurPooling3D``
            (imported from ``.downsamplers``).
        norm_num_groups: Number of groups for both GroupNorm layers. Must
            divide ``in_channels`` and ``out_channels``. The default of 32
            reproduces the previous hard-coded behaviour.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        norm_num_groups: int = 32,
    ):
        super().__init__()

        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.GroupNorm(norm_num_groups, in_channels)
        self.nonlinearity = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.downsampler = BlurPooling3D(out_channels, out_channels)
        else:
            self.downsampler = nn.Identity()

        self.norm2 = nn.GroupNorm(norm_num_groups, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.shortcut = nn.Sequential(
                BlurPooling3D(in_channels, in_channels),
                nn.Conv3d(in_channels, out_channels, kernel_size=1),
            )
        else:
            # Kept as a Sequential so state_dict keys stay "shortcut.0.*".
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1),
            )

        # NOTE(review): these are set unconditionally, even when
        # add_downsample is False; confirm how consumers use them.
        self.spatial_downsample_factor = 2
        self.temporal_downsample_factor = 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block.

        Args:
            x: 5D tensor, presumably (B, C, T, H, W), with C == in_channels.

        Returns:
            (main_path(x) + shortcut(x)) / output_scale_factor, with
            out_channels channels.
        """
        shortcut = self.shortcut(x)

        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.downsampler(x)
        x = self.conv2(x)

        return (x + shortcut) / self.output_scale_factor
class Discriminator3D(nn.Module):
    """Video discriminator built from a stack of DiscriminatorBlock3D stages.

    Pipeline:
        3x3x3 ``conv_in`` with stride 2 (so every conv dimension is
        downsampled before the first stage) ->
        one DiscriminatorBlock3D per entry of ``block_out_channels`` ->
        3x3x3 ``conv_out`` producing a single channel.

    Every stage except the last is built with ``add_downsample=True``. Every
    stage divides its residual sum by sqrt(2).

    Args:
        in_channels: Channels of the input tensor.
        block_out_channels: Output width of each stage. The first entry is
            also the width produced by ``conv_in``. An empty sequence fails
            with IndexError. The final ``GroupNorm(32, ...)`` requires the
            last width to be divisible by 32.

    NOTE(review): ``conv_norm_out`` and ``conv_act`` are registered as
    submodules (so they appear in ``state_dict``/``parameters()``), but
    ``forward`` never applies them. Confirm whether that is intended.
    """

    def __init__(
        self,
        in_channels: int = 3,
        block_out_channels=(64,),
    ):
        super().__init__()

        widths = list(block_out_channels)
        self.conv_in = nn.Conv3d(
            in_channels, widths[0], kernel_size=3, padding=1, stride=2
        )

        # Stage k reads the width written by stage k-1. The first stage reads
        # conv_in's output, which already has widths[0] channels.
        stage_inputs = widths[:1] + widths[:-1]
        last = len(widths) - 1
        self.blocks = nn.ModuleList(
            [
                DiscriminatorBlock3D(
                    in_channels=c_in,
                    out_channels=c_out,
                    output_scale_factor=math.sqrt(2),
                    add_downsample=idx != last,
                )
                for idx, (c_in, c_out) in enumerate(zip(stage_inputs, widths))
            ]
        )

        self.conv_norm_out = nn.GroupNorm(32, widths[-1])
        self.conv_act = nn.LeakyReLU(0.2)
        self.conv_out = nn.Conv3d(widths[-1], 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score ``x``, expected as (B, C, T, H, W) with C == in_channels.

        Returns the single-channel output of ``conv_out``.
        """
        h = self.conv_in(x)
        for stage in self.blocks:
            h = stage(h)
        return self.conv_out(h)