"""This file contains the definition of the discriminator."""

import functools
import math
from typing import Tuple

import torch
import torch.nn.functional as F

from .autoencoder import Conv2dSame


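# Blurring with a small separable kernel before stride-2 subsampling gives an
# anti-aliased ("blur pool") downsampling step; the default (1, 3, 3, 1) kernel is the
# resampling filter commonly used in StyleGAN-style networks.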
class BlurBlock(torch.nn.Module):
    def __init__(self, kernel: Tuple[int, ...] = (1, 3, 3, 1)):
        """Initializes the blur block.

        Args:
            kernel -> Tuple[int, ...]: The 1D blur kernel coefficients.
        """
        super().__init__()

        self.kernel_size = len(kernel)

        # The outer product of the 1D kernel with itself yields a normalized,
        # separable 2D blur kernel.
        kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        kernel = kernel[None, :] * kernel[:, None]
        kernel /= kernel.sum()
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        self.register_buffer("kernel", kernel)

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Calculates the same padding for the BlurBlock.

        Args:
            i -> int: Input size.
            k -> int: Kernel size.
            s -> int: Stride.

        Returns:
            pad -> int: The padding.
        """
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

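    # Worked example of the padding arithmetic above: calc_same_pad(i=17, k=4, s=2)
    # gives max((9 - 1) * 2 + 4 - 17, 0) = 3, so a width-17 input is padded to 20 and
    # the stride-2 convolution in forward() then produces ceil(17 / 2) = 9 columns.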
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.

        Returns:
            out -> torch.Tensor: The output tensor.
        """
        ic, ih, iw = x.size()[-3:]
        pad_h = self.calc_same_pad(i=ih, k=self.kernel_size, s=2)
        pad_w = self.calc_same_pad(i=iw, k=self.kernel_size, s=2)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(
                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
            )

        # Depthwise convolution: every channel is blurred with the same kernel.
        weight = self.kernel.expand(ic, -1, -1, -1)

        out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
        return out


class NLayerDiscriminatorv2(torch.nn.Module):
    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 64,
        num_stages: int = 3,
        activation_fn: str = "leaky_relu",
        blur_resample: bool = False,
        blur_kernel_size: int = 4,
    ):
        """Initializes the NLayerDiscriminatorv2.

        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
            activation_fn -> str: The activation function.
            blur_resample -> bool: Whether to use blur resampling.
            blur_kernel_size -> int: The blur kernel size.
        """
        super().__init__()
        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (not blur_resample) or (
            3 <= blur_kernel_size <= 5
        ), "Blur kernel size must be in [3, 5] when blur resampling"

        in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
        init_kernel_size = 5
        if activation_fn == "leaky_relu":
            activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
        else:
            activation = torch.nn.SiLU

        self.block_in = torch.nn.Sequential(
            Conv2dSame(num_channels, hidden_channels, kernel_size=init_kernel_size),
            activation(),
        )

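        # The blur kernels below are rows of Pascal's triangle (binomial coefficients),
        # i.e. progressively better approximations of a small Gaussian filter.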
        BLUR_KERNEL_MAP = {
            3: (1, 2, 1),
            4: (1, 3, 3, 1),
            5: (1, 4, 6, 4, 1),
        }

        discriminator_blocks = []
        for i_level in range(num_stages):
            in_channels = hidden_channels * in_channel_mult[i_level]
            out_channels = hidden_channels * in_channel_mult[i_level + 1]
            block = torch.nn.Sequential(
                Conv2dSame(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                ),
                (
                    torch.nn.AvgPool2d(kernel_size=2, stride=2)
                    if not blur_resample
                    else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size])
                ),
                torch.nn.GroupNorm(32, out_channels),
                activation(),
            )
            discriminator_blocks.append(block)

        self.blocks = torch.nn.ModuleList(discriminator_blocks)

        self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))

        self.to_logits = torch.nn.Sequential(
            Conv2dSame(out_channels, out_channels, kernel_size=1),
            activation(),
            Conv2dSame(out_channels, 1, kernel_size=5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.

        Returns:
            output -> torch.Tensor: The output tensor.
        """
        hidden_states = self.block_in(x)
        for block in self.blocks:
            hidden_states = block(hidden_states)

        hidden_states = self.pool(hidden_states)

        return self.to_logits(hidden_states)


class OriginalNLayerDiscriminator(torch.nn.Module):
    """Defines a PatchGAN discriminator as in pix2pix, as used by the Taming Transformers VQGAN
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 64,
        num_stages: int = 3,
    ):
        """Initializes a PatchGAN discriminator.

        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
        """
        super().__init__()
        norm_layer = torch.nn.BatchNorm2d

        sequence = [
            torch.nn.Conv2d(
                num_channels, hidden_channels, kernel_size=4, stride=2, padding=1
            ),
            torch.nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, num_stages):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                torch.nn.Conv2d(
                    hidden_channels * nf_mult_prev,
                    hidden_channels * nf_mult,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                norm_layer(hidden_channels * nf_mult),
                torch.nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**num_stages, 8)
        sequence += [
            torch.nn.Conv2d(
                hidden_channels * nf_mult_prev,
                hidden_channels * nf_mult,
                kernel_size=4,
                stride=1,
                padding=1,
                bias=False,
            ),
            norm_layer(hidden_channels * nf_mult),
            torch.nn.LeakyReLU(0.2, True),
        ]

        sequence += [
            torch.nn.Conv2d(
                hidden_channels * nf_mult, 1, kernel_size=4, stride=1, padding=1
            )
        ]
        self.main = torch.nn.Sequential(*sequence)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.

        Returns:
            output -> torch.Tensor: The output tensor.
        """
        return self.main(x)


if __name__ == "__main__":
    patch_discriminator_v2 = NLayerDiscriminatorv2(
        num_channels=3, hidden_channels=128, num_stages=3
    )
    patch_discriminator_v2_blur = NLayerDiscriminatorv2(
        num_channels=3, hidden_channels=128, num_stages=3, blur_resample=True
    )
    original_discriminator = OriginalNLayerDiscriminator(
        num_channels=3, hidden_channels=128, num_stages=3
    )

    from torchinfo import summary

print("Original Discriminator") |
|
summary( |
|
original_discriminiator, |
|
input_size=(1, 3, 256, 256), |
|
depth=3, |
|
col_names=( |
|
"input_size", |
|
"output_size", |
|
"num_params", |
|
"params_percent", |
|
"kernel_size", |
|
"mult_adds", |
|
), |
|
) |
|
print("Patch Discriminator v2") |
|
summary( |
|
patch_discriminator_v2, |
|
input_size=(1, 3, 256, 256), |
|
depth=3, |
|
col_names=( |
|
"input_size", |
|
"output_size", |
|
"num_params", |
|
"params_percent", |
|
"kernel_size", |
|
"mult_adds", |
|
), |
|
) |
|
print("Patch Discriminator v2 (blur)") |
|
summary( |
|
patch_discriminator_v2_blur, |
|
input_size=(1, 3, 256, 256), |
|
depth=3, |
|
col_names=( |
|
"input_size", |
|
"output_size", |
|
"num_params", |
|
"params_percent", |
|
"kernel_size", |
|
"mult_adds", |
|
), |
|
) |
|
|
|
    x = torch.randn((1, 3, 256, 256)).to(next(original_discriminator.parameters()))

    out_original = original_discriminator(x)
    out_patch_v2 = patch_discriminator_v2(x)
    out_patch_v2_blur = patch_discriminator_v2_blur(x)

    print(f"Input shape: {x.shape}")
    print(f"Patch Discriminator v2 output shape: {out_patch_v2.shape}")
    print(f"Patch Discriminator v2 (blur) output shape: {out_patch_v2_blur.shape}")
    print(f"Original Discriminator output shape: {out_original.shape}")
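
    # Quick illustrative shape check for BlurBlock: with "same"-style padding and
    # stride 2, spatial dims halve (rounded up), e.g. 17 -> ceil(17 / 2) = 9.
    blur = BlurBlock()
    feat = torch.randn(1, 8, 17, 17)
    print(f"BlurBlock output shape: {tuple(blur(feat).shape)}")  # expected: (1, 8, 9, 9)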