"""This file contains the definition of utility functions for GANs."""

import torch
import torch.nn.functional as F

from . import OriginalNLayerDiscriminator, NLayerDiscriminatorv2


def toggle_off_gradients(model: torch.nn.Module):
    """Toggles off gradients for all parameters in a model.

    Args:
        model -> torch.nn.Module: The model whose parameters are frozen in place.
    """
    for param in model.parameters():
        param.requires_grad = False


def toggle_on_gradients(model: torch.nn.Module):
    """Toggles on gradients for all parameters in a model.

    Args:
        model -> torch.nn.Module: The model whose parameters are unfrozen in place.
    """
    for param in model.parameters():
        param.requires_grad = True


def discriminator_weights_init(m: torch.nn.Module) -> None:
    """Initialize weights for convolutions in the discriminator.

    Meant to be passed to ``torch.nn.Module.apply``, which calls it on every
    submodule. Any module whose class name contains "Conv" and which owns a
    ``weight`` tensor gets that weight re-drawn in place from a normal
    distribution with mean 0.0 and std 0.02. Biases and all other modules are
    left untouched.

    Args:
        m -> torch.nn.Module: A (sub)module visited by ``Module.apply``.
    """
    classname = m.__class__.__name__
    # Match by class name so that conv variants with custom class names are
    # initialized too. Skip "Conv"-named modules without a weight of their own
    # (e.g. containers); ``Module.apply`` visits their children separately.
    if "Conv" in classname and getattr(m, "weight", None) is not None:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)


def adopt_weight(
    weight: float, global_step: int, threshold: int = 0, value: float = 0.0
) -> float:
    """If global_step is less than threshold, return value, else return weight.

    This lets a caller hold a weight at ``value`` (0.0 by default) until
    ``global_step`` reaches ``threshold``.

    Args:
        weight -> float: The weight to use once ``threshold`` is reached.
        global_step -> int: The current step.
        threshold -> int: The first step at which ``weight`` is returned.
        value -> float: The weight to use before ``threshold``.

    Returns:
        float: ``value`` if ``global_step < threshold``, otherwise ``weight``.
    """
    if global_step < threshold:
        return value
    return weight


def compute_lecam_loss(
    logits_real_mean: torch.Tensor,
    logits_fake_mean: torch.Tensor,
    ema_logits_real_mean: torch.Tensor,
    ema_logits_fake_mean: torch.Tensor,
) -> torch.Tensor:
    """Computes the LeCam loss for the given average real and fake logits.

    Each term is a one-sided squared penalty (via ``relu``):
      * the current real mean is penalized only when it exceeds the EMA of the
        fake mean;
      * the current fake mean is penalized only when it falls below the EMA of
        the real mean.

    Args:
        logits_real_mean -> torch.Tensor: The average real logits.
        logits_fake_mean -> torch.Tensor: The average fake logits.
        ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
        ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.

    Returns:
        lecam_loss -> torch.Tensor: The LeCam loss (sum of the two mean terms).
    """
    lecam_loss = torch.mean(
        torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2)
    )
    lecam_loss += torch.mean(
        torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2)
    )
    return lecam_loss


def hinge_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the hinge loss for the generator given the fake logits.

    Args:
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        g_loss -> torch.Tensor: The hinge loss, i.e. the negated mean of the
            fake logits.
    """
    g_loss = -torch.mean(logits_fake)
    return g_loss


def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the hinge loss for the discriminator given the real and fake logits.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        d_loss -> torch.Tensor: The hinge loss, averaged over the real and fake
            terms (0.5 factor).
    """
    # Real logits are pushed above +1 and fake logits below -1; anything
    # already past the margin contributes zero.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def sigmoid_cross_entropy_with_logits(
    logits: torch.Tensor, label: torch.Tensor
) -> torch.Tensor:
    """Credits to Magvit.
    We use a stable formulation that is equivalent to the one used in TensorFlow.
    The following derivation shows how we arrive at the formulation:

    .. math::
        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
        = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
        = (1 - z) * x + log(1 + exp(-x))
        = x - x * z + log(1 + exp(-x))

    For x < 0, the following formula is more stable:
    .. math::
        x - x * z + log(1 + exp(-x))
        = log(exp(x)) - x * z + log(1 + exp(-x))
        = - x * z + log(1 + exp(x))

    We combine the two cases (x<0, x>=0) into one formula as follows:
    .. math::
        max(x, 0) - x * z + log(1 + exp(-abs(x)))

    Args:
        logits -> torch.Tensor: The logits ``x``.
        label -> torch.Tensor: The targets ``z``, broadcastable against ``logits``.

    Returns:
        loss -> torch.Tensor: The element-wise loss (no reduction is applied).
    """
    zeros = torch.zeros_like(logits)
    cond = logits >= zeros
    # max(x, 0)
    relu_logits = torch.where(cond, logits, zeros)
    # -abs(x); the exponent below is therefore never positive, so exp() cannot
    # overflow.
    neg_abs_logits = torch.where(cond, -logits, logits)
    loss = relu_logits - logits * label + torch.log1p(neg_abs_logits.exp())
    return loss


def non_saturating_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the non-saturating loss for the discriminator given the real and fake logits.

    Note that, unlike ``hinge_d_loss`` and ``vanilla_d_loss``, the real and
    fake terms are summed rather than averaged (there is no 0.5 factor).

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The non-saturating loss.
    """
    real_loss = torch.mean(
        sigmoid_cross_entropy_with_logits(
            logits_real, label=torch.ones_like(logits_real)
        )
    )
    fake_loss = torch.mean(
        sigmoid_cross_entropy_with_logits(
            logits_fake, label=torch.zeros_like(logits_fake)
        )
    )
    # Both terms are already scalar means; no further reduction is needed.
    return real_loss + fake_loss


def non_saturating_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the non-saturating loss for the generator given the fake logits.

    Args:
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The non-saturating loss, i.e. the mean sigmoid
            cross-entropy of the fake logits against a target of 1.
    """
    return torch.mean(
        sigmoid_cross_entropy_with_logits(
            logits_fake, label=torch.ones_like(logits_fake)
        )
    )


def vanilla_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the vanilla loss for the discriminator given the real and fake logits.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.

    Returns:
        loss -> torch.Tensor: The vanilla loss, averaged over the real and fake
            terms (0.5 factor).
    """
    # softplus(-x) == -log(sigmoid(x)) and softplus(x) == -log(1 - sigmoid(x)).
    d_loss = 0.5 * (
        torch.mean(F.softplus(-logits_real))
        + torch.mean(F.softplus(logits_fake))
    )
    return d_loss


def create_discriminator(discriminator_config) -> torch.nn.Module:
    """Creates a discriminator based on the given config.

    Args:
        discriminator_config: The config for the discriminator. It is read via
            attribute access (``name``, ``num_channels``, ``num_stages``,
            ``hidden_channels`` and, for "VQGAN+Discriminator",
            ``blur_resample``) and via ``.get("blur_kernel_size", 4)``.

    Returns:
        discriminator -> torch.nn.Module: The discriminator.

    Raises:
        ValueError: If ``discriminator_config.name`` is not a supported name.
    """
    if discriminator_config.name == "Original":
        return OriginalNLayerDiscriminator(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
        ).apply(discriminator_weights_init)
    elif discriminator_config.name == "VQGAN+Discriminator":
        # NOTE(review): unlike the "Original" branch, this branch does not
        # apply discriminator_weights_init; confirm this is intentional.
        return NLayerDiscriminatorv2(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
            blur_resample=discriminator_config.blur_resample,
            blur_kernel_size=discriminator_config.get("blur_kernel_size", 4),
        )
    else:
        raise ValueError(
            f"Discriminator {discriminator_config.name} is not implemented."
        )