# VTBench/src/vqvaes/maskbit/modules/discriminator.py
# Page header residue: huaweilin, commit "update" (14ce5a9).
"""This file contains the definition of the discriminator."""
import functools
import math
from typing import Tuple
import torch
import torch.nn.functional as F
from .autoencoder import Conv2dSame
class BlurBlock(torch.nn.Module):
    """Blur-and-downsample layer with a fixed (non-learned) kernel.

    The 1D filter taps are turned into a normalized 2D kernel via an outer
    product. In the forward pass the same kernel is applied to every channel
    independently (grouped convolution with one group per channel) using a
    stride of 2, after zero-padding the input "same"-style.
    """

    def __init__(self, kernel: Tuple[int, ...] = (1, 3, 3, 1)):
        """Initializes the blur block.

        Args:
            kernel -> Tuple[int, ...]: The 1D filter taps. Their count is the
                spatial kernel size; the 2D kernel is their outer product,
                normalized to sum to one.
        """
        super().__init__()
        self.kernel_size = len(kernel)

        kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
        # Outer product of the taps gives the separable 2D kernel.
        kernel = kernel[None, :] * kernel[:, None]
        kernel /= kernel.sum()
        # Shape (1, 1, k, k) so it can be expanded per channel in forward().
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        # A buffer (not a parameter): follows .to()/.cuda() and is saved in the
        # state dict, but is never updated by the optimizer.
        self.register_buffer("kernel", kernel)

    def calc_same_pad(self, i: int, k: int, s: int) -> int:
        """Calculates the same padding for the BlurBlock.

        Args:
            i -> int: Input size.
            k -> int: Kernel size.
            s -> int: Stride.

        Returns:
            pad -> int: The total padding (both sides combined) along one
                dimension, never negative.
        """
        return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x -> torch.Tensor: The input tensor; its last three dimensions are
                read as (channels, height, width).

        Returns:
            out -> torch.Tensor: The blurred tensor, downsampled with stride 2.
        """
        ic, ih, iw = x.size()[-3:]
        pad_h = self.calc_same_pad(i=ih, k=self.kernel_size, s=2)
        pad_w = self.calc_same_pad(i=iw, k=self.kernel_size, s=2)

        if pad_h > 0 or pad_w > 0:
            # F.pad takes (left, right, top, bottom); when the total is odd the
            # extra pixel goes to the right/bottom.
            x = F.pad(
                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
            )

        weight = self.kernel.expand(ic, -1, -1, -1)
        # Use the channel count taken from the last three dims so that the
        # group count always matches the expanded weight, including for an
        # unbatched (C, H, W) input where x.shape[1] would be the height.
        out = F.conv2d(input=x, weight=weight, stride=2, groups=ic)
        return out
class NLayerDiscriminatorv2(torch.nn.Module):
    """Multi-stage convolutional discriminator with optional blur resampling.

    Layout: an input convolution, `num_stages` blocks of
    (conv -> downsample -> GroupNorm -> activation), an adaptive max pool to
    16x16, and a two-convolution head producing a single-channel logit map.
    """

    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 64,
        num_stages: int = 3,
        activation_fn: str = "leaky_relu",
        blur_resample: bool = False,
        blur_kernel_size: int = 4,
    ):
        """Initializes the NLayerDiscriminatorv2.

        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels. Every stage
                width is passed to GroupNorm with 32 groups.
            num_stages -> int: The number of stages (must be positive).
            activation_fn -> str: "leaky_relu" selects LeakyReLU with slope
                0.1; any other value selects SiLU.
            blur_resample -> bool: Whether to downsample with BlurBlock
                instead of 2x2 average pooling.
            blur_kernel_size -> int: The blur kernel size (3, 4 or 5); only
                checked and used when blur_resample is set.
        """
        super().__init__()

        assert num_stages > 0, "Discriminator cannot have 0 stages"
        assert (
            not blur_resample or 3 <= blur_kernel_size <= 5
        ), "Blur kernel size must be in [3,5] when sampling]"

        # Channel width entering the first stage, followed by the width
        # leaving each stage: h, h, 2h, 4h, ...
        stage_widths = [hidden_channels] + [
            hidden_channels * 2**stage for stage in range(num_stages)
        ]

        make_activation = (
            functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
            if activation_fn == "leaky_relu"
            else torch.nn.SiLU
        )

        stem_conv = Conv2dSame(num_channels, hidden_channels, kernel_size=5)
        self.block_in = torch.nn.Sequential(stem_conv, make_activation())

        # 1D filter taps for each supported blur kernel size.
        blur_taps = {
            3: (1, 2, 1),
            4: (1, 3, 3, 1),
            5: (1, 4, 6, 4, 1),
        }

        stages = []
        for width_in, width_out in zip(stage_widths[:-1], stage_widths[1:]):
            conv = Conv2dSame(width_in, width_out, kernel_size=3)
            if blur_resample:
                downsample = BlurBlock(blur_taps[blur_kernel_size])
            else:
                downsample = torch.nn.AvgPool2d(kernel_size=2, stride=2)
            stages.append(
                torch.nn.Sequential(
                    conv,
                    downsample,
                    torch.nn.GroupNorm(32, width_out),
                    make_activation(),
                )
            )
        self.blocks = torch.nn.ModuleList(stages)

        self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))

        head_width = stage_widths[-1]
        self.to_logits = torch.nn.Sequential(
            Conv2dSame(head_width, head_width, 1),
            make_activation(),
            Conv2dSame(head_width, 1, kernel_size=5),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.

        Returns:
            output -> torch.Tensor: The single-channel logit map computed
                from the 16x16 pooled features.
        """
        features = self.block_in(x)
        for stage in self.blocks:
            features = stage(features)
        return self.to_logits(self.pool(features))
class OriginalNLayerDiscriminator(torch.nn.Module):
    """PatchGAN discriminator in the style of Pix2Pix, as used by Taming VQGAN.

    Reference implementation:
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(
        self,
        num_channels: int = 3,
        hidden_channels: int = 64,
        num_stages: int = 3,
    ):
        """Initializes a PatchGAN discriminator.

        Args:
            num_channels -> int: The number of input channels.
            hidden_channels -> int: The number of hidden channels.
            num_stages -> int: The number of stages.
        """
        super().__init__()

        def conv_norm_act(width_in: int, width_out: int, stride: int) -> list:
            # Bias-free 4x4 convolution followed by BatchNorm and LeakyReLU.
            return [
                torch.nn.Conv2d(
                    width_in,
                    width_out,
                    kernel_size=4,
                    stride=stride,
                    padding=1,
                    bias=False,
                ),
                torch.nn.BatchNorm2d(width_out),
                torch.nn.LeakyReLU(0.2, True),
            ]

        # Input layer: strided convolution with bias and no normalization.
        layers = [
            torch.nn.Conv2d(
                num_channels, hidden_channels, kernel_size=4, stride=2, padding=1
            ),
            torch.nn.LeakyReLU(0.2, True),
        ]

        # Width multipliers double per stage and are capped at 8.
        multipliers = [1] + [min(2**n, 8) for n in range(1, num_stages)]
        for prev_mult, cur_mult in zip(multipliers, multipliers[1:]):
            layers.extend(
                conv_norm_act(
                    hidden_channels * prev_mult, hidden_channels * cur_mult, stride=2
                )
            )

        # One more widening block, this time with stride 1.
        last_mult = min(2**num_stages, 8)
        layers.extend(
            conv_norm_act(
                hidden_channels * multipliers[-1], hidden_channels * last_mult, stride=1
            )
        )

        # Final projection to a 1-channel prediction map.
        layers.append(
            torch.nn.Conv2d(
                hidden_channels * last_mult, 1, kernel_size=4, stride=1, padding=1
            )
        )

        self.main = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x -> torch.Tensor: The input tensor.

        Returns:
            output -> torch.Tensor: The output tensor.
        """
        return self.main(x)
if __name__ == "__main__":
    # Smoke test: build each discriminator, print a torchinfo summary, and
    # report the output shapes for a random 256x256 RGB input.
    shared_kwargs = dict(num_channels=3, hidden_channels=128, num_stages=3)

    patch_v2 = NLayerDiscriminatorv2(**shared_kwargs)
    patch_v2_blur = NLayerDiscriminatorv2(blur_resample=True, **shared_kwargs)
    original = OriginalNLayerDiscriminator(**shared_kwargs)

    # Imported here so torchinfo is only needed when running this file directly.
    from torchinfo import summary

    report_columns = (
        "input_size",
        "output_size",
        "num_params",
        "params_percent",
        "kernel_size",
        "mult_adds",
    )
    titled_models = (
        ("Original Discriminator", original),
        ("Patch Discriminator v2", patch_v2),
        ("Patch Discriminator v2 (blur)", patch_v2_blur),
    )
    for title, model in titled_models:
        print(title)
        summary(
            model,
            input_size=(1, 3, 256, 256),
            depth=3,
            col_names=report_columns,
        )

    # Match the dtype/device of the original discriminator's parameters.
    reference_param = next(original.parameters())
    x = torch.randn((1, 3, 256, 256)).to(reference_param)

    out_original = original(x)
    out_patch_v2 = patch_v2(x)
    out_patch_v2_blur = patch_v2_blur(x)

    print(f"Input shape: {x.shape}")
    print(f"Patch Discriminator v2 output shape: {out_patch_v2.shape}")
    print(f"Patch Discriminator v2 (blur) output shape: {out_patch_v2_blur.shape}")
    print(f"Original Discriminator output shape: {out_original.shape}")