# Discriminator networks (2D and 3D) built on blur-pooling downsamplers.
import math

import torch
import torch.nn as nn

from .downsamplers import BlurPooling2D, BlurPooling3D
class DiscriminatorBlock2D(nn.Module):
    """Pre-activation residual discriminator block for images (B, C, H, W).

    Two ``norm -> LeakyReLU -> conv`` stages with an optional blur-pooling
    downsample (factor 2 in H and W) applied to both the main path and the
    shortcut.  The residual sum is divided by ``output_scale_factor``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        dropout: dropout probability applied after the second activation.
        output_scale_factor: divisor applied to the residual sum.
        add_downsample: when True, blur-pool both paths by a factor of 2.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
    ):
        super().__init__()
        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.BatchNorm2d(in_channels)
        self.nonlinearity = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.downsampler = BlurPooling2D(out_channels, out_channels)
        else:
            self.downsampler = nn.Identity()

        self.norm2 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.shortcut = nn.Sequential(
                BlurPooling2D(in_channels, in_channels),
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
            )
        elif in_channels != out_channels:
            # Fix: the original used nn.Identity() unconditionally here, which
            # makes the residual addition below fail whenever in_channels
            # differs from out_channels.  Project with a 1x1 conv instead,
            # matching how DiscriminatorBlock3D handles this case.
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

        # NOTE(review): these factors read 2/1 even when add_downsample=False
        # (no downsampling actually happens then) — kept as-is for backward
        # compatibility; confirm against callers that read these attributes.
        self.spatial_downsample_factor = 2
        self.temporal_downsample_factor = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.shortcut(x)

        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.dropout(x)
        x = self.downsampler(x)
        x = self.conv2(x)

        return (x + shortcut) / self.output_scale_factor
class Discriminator2D(nn.Module):
    """Fully convolutional 2D discriminator emitting per-location logits.

    Stacks ``DiscriminatorBlock2D`` stages (all but the last downsample by 2),
    then a final norm/activation/1-channel conv head.

    Args:
        in_channels: channels of the input image (default 3).
        block_out_channels: output channels of each residual stage.
    """

    def __init__(
        self,
        in_channels: int = 3,
        block_out_channels=(64,),
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])
        output_channels = block_out_channels[0]
        for i, out_channels in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = out_channels
            is_final_block = i == len(block_out_channels) - 1
            self.blocks.append(
                DiscriminatorBlock2D(
                    in_channels=input_channels,
                    out_channels=output_channels,
                    # sqrt(2) keeps the variance of the residual sum roughly unit.
                    output_scale_factor=math.sqrt(2),
                    add_downsample=not is_final_block,
                )
            )

        self.conv_norm_out = nn.BatchNorm2d(block_out_channels[-1])
        self.conv_act = nn.LeakyReLU(0.2)
        self.conv_out = nn.Conv2d(block_out_channels[-1], 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) -> per-location logits (B, 1, H', W')
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        # Fix: conv_norm_out / conv_act were constructed but never applied in
        # the original forward, leaving them as dead (yet trainable) modules.
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x
class DiscriminatorBlock3D(nn.Module):
    """Pre-activation residual discriminator block for video (B, C, T, H, W).

    Two ``GroupNorm -> LeakyReLU -> Conv3d`` stages; when ``add_downsample``
    is set, both the main path and the shortcut are blur-pooled by a factor
    of 2 in T, H and W.  The shortcut always projects channels with a 1x1x1
    convolution.  The residual sum is divided by ``output_scale_factor``.

    Args:
        in_channels: channels of the input tensor (must be divisible by 32
            for the GroupNorm).
        out_channels: channels of the output tensor.
        dropout: dropout probability applied after the second activation.
        output_scale_factor: divisor applied to the residual sum.
        add_downsample: when True, blur-pool both paths by a factor of 2.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
    ):
        super().__init__()
        self.output_scale_factor = output_scale_factor

        self.norm1 = nn.GroupNorm(32, in_channels)
        self.nonlinearity = nn.LeakyReLU(0.2)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        if add_downsample:
            self.downsampler = BlurPooling3D(out_channels, out_channels)
        else:
            self.downsampler = nn.Identity()

        # Shortcut: optional blur-pool followed by a channel projection.
        projection = [nn.Conv3d(in_channels, out_channels, kernel_size=1)]
        if add_downsample:
            projection.insert(0, BlurPooling3D(in_channels, in_channels))
        self.shortcut = nn.Sequential(*projection)

        self.spatial_downsample_factor = 2
        self.temporal_downsample_factor = 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)
        h = self.conv1(self.nonlinearity(self.norm1(x)))
        h = self.dropout(self.nonlinearity(self.norm2(h)))
        h = self.conv2(self.downsampler(h))
        return (h + residual) / self.output_scale_factor
class Discriminator3D(nn.Module):
    """Fully convolutional 3D discriminator emitting per-location logits.

    A stride-2 input conv, then stacked ``DiscriminatorBlock3D`` stages (all
    but the last downsample by 2), then a final norm/activation/1-channel
    conv head.

    Args:
        in_channels: channels of the input video (default 3).
        block_out_channels: output channels of each residual stage; each must
            be divisible by 32 for the GroupNorms.
    """

    def __init__(
        self,
        in_channels: int = 3,
        block_out_channels=(64,),
    ):
        super().__init__()
        # Unlike the 2D variant, the input conv itself downsamples (stride=2).
        self.conv_in = nn.Conv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1, stride=2)

        self.blocks = nn.ModuleList([])
        output_channels = block_out_channels[0]
        for i, out_channels in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = out_channels
            is_final_block = i == len(block_out_channels) - 1
            self.blocks.append(
                DiscriminatorBlock3D(
                    in_channels=input_channels,
                    out_channels=output_channels,
                    # sqrt(2) keeps the variance of the residual sum roughly unit.
                    output_scale_factor=math.sqrt(2),
                    add_downsample=not is_final_block,
                )
            )

        self.conv_norm_out = nn.GroupNorm(32, block_out_channels[-1])
        self.conv_act = nn.LeakyReLU(0.2)
        self.conv_out = nn.Conv3d(block_out_channels[-1], 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W) -> per-location logits (B, 1, T', H', W')
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        # Fix: conv_norm_out / conv_act were constructed but never applied in
        # the original forward, leaving them as dead (yet trainable) modules.
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x