"""This file contains the definition of utility functions for GANs."""
import torch
import torch.nn.functional as F
from . import OriginalNLayerDiscriminator, NLayerDiscriminatorv2


def toggle_off_gradients(model: torch.nn.Module):
    """Toggles off gradients for all parameters in a model."""
    for param in model.parameters():
        param.requires_grad = False


def toggle_on_gradients(model: torch.nn.Module):
    """Toggles on gradients for all parameters in a model."""
    for param in model.parameters():
        param.requires_grad = True
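

# Example (illustrative sketch, not part of the original module): freezing the
# generator while the discriminator is updated, assuming `generator` and
# `discriminator` are torch.nn.Module instances defined elsewhere:
#
#     toggle_off_gradients(generator)
#     d_loss = hinge_d_loss(discriminator(real), discriminator(fake.detach()))
#     d_loss.backward()
#     toggle_on_gradients(generator)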


def discriminator_weights_init(m):
    """Initialize weights for convolutions in the discriminator."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
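

# Note: this is the DCGAN-style N(0, 0.02) convolution initializer. It is meant
# to be applied recursively via Module.apply, e.g.
# `model.apply(discriminator_weights_init)`, as done in create_discriminator
# below.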


def adopt_weight(
    weight: float, global_step: int, threshold: int = 0, value: float = 0.0
) -> float:
    """If global_step is less than threshold, return value, else return weight."""
    if global_step < threshold:
        weight = value
    return weight
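

# Example (illustrative; the step count and loss names are assumptions): delaying
# the adversarial term until the discriminator has warmed up for 10k steps:
#
#     d_weight = adopt_weight(1.0, global_step, threshold=10_000)
#     loss = recon_loss + d_weight * g_loss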


def compute_lecam_loss(
    logits_real_mean: torch.Tensor,
    logits_fake_mean: torch.Tensor,
    ema_logits_real_mean: torch.Tensor,
    ema_logits_fake_mean: torch.Tensor,
) -> torch.Tensor:
    """Computes the LeCam loss for the given average real and fake logits.

    Args:
        logits_real_mean -> torch.Tensor: The average real logits.
        logits_fake_mean -> torch.Tensor: The average fake logits.
        ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
        ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.

    Returns:
        lecam_loss -> torch.Tensor: The LeCam loss.
    """
    lecam_loss = torch.mean(
        torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2)
    )
    lecam_loss += torch.mean(
        torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2)
    )
    return lecam_loss
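

# Example (illustrative sketch; the decay value, `lecam_weight`, and the EMA
# bookkeeping are assumptions, not part of this module): the EMA statistics fed
# to compute_lecam_loss are typically maintained alongside the training loop:
#
#     decay = 0.999
#     ema_real = decay * ema_real + (1 - decay) * logits_real.mean().detach()
#     ema_fake = decay * ema_fake + (1 - decay) * logits_fake.mean().detach()
#     d_loss = d_loss + lecam_weight * compute_lecam_loss(
#         logits_real.mean(), logits_fake.mean(), ema_real, ema_fake
#     )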


def hinge_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the hinge loss for the generator given the fake logits.

    Args:
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        g_loss -> torch.Tensor: The hinge loss.
    """
    g_loss = -torch.mean(logits_fake)
    return g_loss


def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the hinge loss for the discriminator given the real and fake logits.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        d_loss -> torch.Tensor: The hinge loss.
    """
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
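

# Example (illustrative; the model names and tensor shapes are assumptions): a
# hinge-GAN update pair. The discriminator sees detached fakes; the generator
# maximizes the critic score on non-detached fakes:
#
#     logits_real = discriminator(real_images)           # e.g. (B, 1, 30, 30)
#     logits_fake = discriminator(fake_images.detach())
#     d_loss = hinge_d_loss(logits_real, logits_fake)
#
#     g_loss = hinge_g_loss(discriminator(fake_images))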


def sigmoid_cross_entropy_with_logits(
    logits: torch.Tensor, label: torch.Tensor
) -> torch.Tensor:
    """Credits to Magvit.

    We use a stable formulation that is equivalent to the one used in TensorFlow.
    The following derivation shows how we arrive at the formulation:

    .. math::
        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
        = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
        = (1 - z) * x + log(1 + exp(-x))
        = x - x * z + log(1 + exp(-x))

    For x < 0, the following formula is more stable:

    .. math::
        x - x * z + log(1 + exp(-x))
        = log(exp(x)) - x * z + log(1 + exp(-x))
        = - x * z + log(1 + exp(x))

    We combine the two cases (x < 0, x >= 0) into one formula as follows:

    .. math::
        max(x, 0) - x * z + log(1 + exp(-abs(x)))
    """
    zeros = torch.zeros_like(logits)
    cond = logits >= zeros
    relu_logits = torch.where(cond, logits, zeros)
    neg_abs_logits = torch.where(cond, -logits, logits)
    loss = relu_logits - logits * label + torch.log1p(neg_abs_logits.exp())
    return loss
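

# Note (a fact about PyTorch, not from the original file): this stable
# max(x, 0) - x * z + log(1 + exp(-|x|)) form is the same one PyTorch uses
# internally, so the function above is numerically equivalent to:
#
#     loss = F.binary_cross_entropy_with_logits(logits, label, reduction="none")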


def non_saturating_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the non-saturating loss for the discriminator given the real and fake logits.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The non-saturating loss.
    """
    real_loss = torch.mean(
        sigmoid_cross_entropy_with_logits(
            logits_real, label=torch.ones_like(logits_real)
        )
    )
    fake_loss = torch.mean(
        sigmoid_cross_entropy_with_logits(
            logits_fake, label=torch.zeros_like(logits_fake)
        )
    )
    # real_loss and fake_loss are already scalar means; no further reduction needed.
    return real_loss + fake_loss


def non_saturating_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the non-saturating loss for the generator given the fake logits.

    Args:
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The non-saturating loss.
    """
    return torch.mean(
        sigmoid_cross_entropy_with_logits(
            logits_fake, label=torch.ones_like(logits_fake)
        )
    )


def vanilla_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the vanilla loss for the discriminator given the real and fake logits.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The vanilla loss.
    """
    d_loss = 0.5 * (
        torch.mean(F.softplus(-logits_real))
        + torch.mean(F.softplus(logits_fake))
    )
    return d_loss
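

# Note (a standard identity, not from the original file): softplus(-x) =
# -log(sigmoid(x)) and softplus(x) = -log(1 - sigmoid(x)), so vanilla_d_loss is
# the classic cross-entropy GAN discriminator loss, scaled by 0.5:
#
#     0.5 * (BCE(logits_real, target=1) + BCE(logits_fake, target=0))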


def create_discriminator(discriminator_config) -> torch.nn.Module:
    """Creates a discriminator based on the given config.

    Args:
        discriminator_config: The config for the discriminator.

    Returns:
        discriminator -> torch.nn.Module: The discriminator.
    """
    if discriminator_config.name == "Original":
        return OriginalNLayerDiscriminator(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
        ).apply(discriminator_weights_init)
    elif discriminator_config.name == "VQGAN+Discriminator":
        return NLayerDiscriminatorv2(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
            blur_resample=discriminator_config.blur_resample,
            blur_kernel_size=discriminator_config.get("blur_kernel_size", 4),
        )
    else:
        raise ValueError(
            f"Discriminator {discriminator_config.name} is not implemented."
        )
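

# Example (illustrative; the config backend is an assumption — any mapping with
# attribute access and a .get method works, e.g. an omegaconf.DictConfig):
#
#     from omegaconf import OmegaConf
#
#     config = OmegaConf.create(
#         {"name": "Original", "num_channels": 3, "num_stages": 3,
#          "hidden_channels": 64}
#     )
#     discriminator = create_discriminator(config)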