File size: 6,956 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""This file contains the definition of utility functions for GANs."""
import torch
import torch.nn.functional as F
from . import OriginalNLayerDiscriminator, NLayerDiscriminatorv2
def toggle_off_gradients(model: torch.nn.Module):
    """Freezes a model by disabling gradient tracking on every parameter.

    Args:
        model -> torch.nn.Module: Module whose parameters are modified in place.
    """
    for p in model.parameters():
        p.requires_grad_(False)
def toggle_on_gradients(model: torch.nn.Module):
    """Unfreezes a model by enabling gradient tracking on every parameter.

    Args:
        model -> torch.nn.Module: Module whose parameters are modified in place.
    """
    for p in model.parameters():
        p.requires_grad_(True)
def discriminator_weights_init(m):
    """Re-initializes convolution weights for the discriminator.

    Intended for use with ``torch.nn.Module.apply``. Any module whose class
    name contains the substring "Conv" (e.g. ``Conv2d``, ``ConvTranspose2d``)
    has its ``weight`` redrawn from a normal distribution with mean 0.0 and
    std 0.02. Biases and all other modules are left untouched.

    Args:
        m: The module to (possibly) re-initialize.
    """
    if "Conv" not in m.__class__.__name__:
        return
    torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
def adopt_weight(
    weight: float, global_step: int, threshold: int = 0, value: float = 0.0
) -> float:
    """Selects a loss weight based on the current training step.

    Args:
        weight -> float: Weight to use once ``global_step`` reaches ``threshold``.
        global_step -> int: The current training step.
        threshold -> int: First step at which ``weight`` becomes active.
        value -> float: Weight returned while ``global_step < threshold``.
    Returns:
        float: ``value`` before the threshold, ``weight`` from the threshold on.
    """
    return value if global_step < threshold else weight
def compute_lecam_loss(
    logits_real_mean: torch.Tensor,
    logits_fake_mean: torch.Tensor,
    ema_logits_real_mean: torch.Tensor,
    ema_logits_fake_mean: torch.Tensor,
) -> torch.Tensor:
    """Computes the LeCam regularization loss from (EMA-)averaged logits.

    The loss is the sum of two squared-hinge terms:
      * ``relu(real - ema_fake) ** 2`` is nonzero only when the real logits
        exceed the EMA of the fake logits;
      * ``relu(ema_real - fake) ** 2`` is nonzero only when the fake logits
        fall below the EMA of the real logits.
    Each term is averaged before summation.

    Args:
        logits_real_mean -> torch.Tensor: The average real logits.
        logits_fake_mean -> torch.Tensor: The average fake logits.
        ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
        ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.
    Returns:
        lecam_loss -> torch.Tensor: The LeCam loss.
    """
    real_above_ema_fake = F.relu(logits_real_mean - ema_logits_fake_mean)
    fake_below_ema_real = F.relu(ema_logits_real_mean - logits_fake_mean)
    lecam_loss = real_above_ema_fake.pow(2).mean()
    lecam_loss += fake_below_ema_real.pow(2).mean()
    return lecam_loss
def hinge_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the generator hinge loss: the negated mean of the fake logits.

    Args:
        logits_fake -> torch.Tensor: The fake logits.
    Returns:
        g_loss -> torch.Tensor: The hinge loss (lower when fake logits are higher).
    """
    return logits_fake.mean().neg()
def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the discriminator hinge loss.

    Real logits are penalized below +1 and fake logits above -1. The two mean
    penalties are averaged.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.
    Returns:
        d_loss -> torch.Tensor: The hinge loss.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def sigmoid_cross_entropy_with_logits(
    logits: torch.Tensor, label: torch.Tensor
) -> torch.Tensor:
    """Computes element-wise sigmoid cross-entropy from logits, stably.

    Credits to Magvit.
    We use a stable formulation that is equivalent to the one used in TensorFlow
    (``tf.nn.sigmoid_cross_entropy_with_logits``). With ``x = logits`` and
    ``z = label``, the following derivation shows how we arrive at it:

    .. math::
        z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
        = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
        = (1 - z) * x + log(1 + exp(-x))
        = x - x * z + log(1 + exp(-x))

    For x < 0, exp(-x) can overflow, so the following form is more stable:

    .. math::
        x - x * z + log(1 + exp(-x))
        = log(exp(x)) - x * z + log(1 + exp(-x))
        = - x * z + log(1 + exp(x))

    We combine the two cases (x < 0, x >= 0) into one formula as follows:

    .. math::
        max(x, 0) - x * z + log(1 + exp(-abs(x)))

    Args:
        logits -> torch.Tensor: The raw (pre-sigmoid) logits ``x``.
        label -> torch.Tensor: The targets ``z`` (typically in [0, 1]),
            broadcastable against ``logits``.
    Returns:
        loss -> torch.Tensor: The unreduced per-element loss (no mean is taken).
    """
    zeros = torch.zeros_like(logits)
    cond = logits >= zeros
    # max(x, 0)
    relu_logits = torch.where(cond, logits, zeros)
    # -abs(x): the exponent below is always <= 0, so exp() cannot overflow.
    neg_abs_logits = torch.where(cond, -logits, logits)
    loss = relu_logits - logits * label + torch.log1p(neg_abs_logits.exp())
    return loss
def non_saturating_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the non-saturating (sigmoid cross-entropy) discriminator loss.

    Real logits are scored against a target of 1 and fake logits against a
    target of 0. The mean loss of each side is computed, and the two means are
    summed (not averaged).

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.
    Returns:
        loss -> torch.Tensor: The non-saturating loss.
    """
    real_targets = torch.ones_like(logits_real)
    fake_targets = torch.zeros_like(logits_fake)
    real_term = sigmoid_cross_entropy_with_logits(logits_real, label=real_targets).mean()
    fake_term = sigmoid_cross_entropy_with_logits(logits_fake, label=fake_targets).mean()
    return real_term + fake_term
def non_saturating_g_loss(logits_fake: torch.Tensor) -> torch.Tensor:
    """Computes the non-saturating (sigmoid cross-entropy) generator loss.

    Fake logits are scored against a target of 1, and the loss is averaged
    over all elements.

    Args:
        logits_fake -> torch.Tensor: The fake logits.
    Returns:
        loss -> torch.Tensor: The non-saturating loss.
    """
    targets = torch.ones_like(logits_fake)
    per_element = sigmoid_cross_entropy_with_logits(logits_fake, label=targets)
    return per_element.mean()
def vanilla_d_loss(
    logits_real: torch.Tensor, logits_fake: torch.Tensor
) -> torch.Tensor:
    """Computes the vanilla (softplus) GAN loss for the discriminator.

    Uses ``softplus(-real)`` for the real logits and ``softplus(fake)`` for the
    fake logits. The mean of each term is taken, and the two means are averaged.

    Args:
        logits_real -> torch.Tensor: The real logits.
        logits_fake -> torch.Tensor: The fake logits.
    Returns:
        loss -> torch.Tensor: The vanilla loss.
    """
    real_term = F.softplus(-logits_real).mean()
    fake_term = F.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def create_discriminator(discriminator_config) -> torch.nn.Module:
    """Builds the discriminator selected by ``discriminator_config.name``.

    Supported names:
      * "Original": ``OriginalNLayerDiscriminator``, with convolution weights
        re-initialized via ``discriminator_weights_init``.
      * "VQGAN+Discriminator": ``NLayerDiscriminatorv2``. No extra weight init
        is applied. ``blur_kernel_size`` is read with ``.get`` and defaults
        to 4 (presumably a dict-like config such as an OmegaConf node; confirm
        against callers).

    Args:
        discriminator_config: Config providing ``name``, ``num_channels``,
            ``num_stages`` and ``hidden_channels`` (plus ``blur_resample`` for
            the v2 discriminator).
    Returns:
        discriminator -> torch.nn.Module: The discriminator.
    Raises:
        ValueError: If ``name`` is not one of the supported discriminators.
    """
    disc_name = discriminator_config.name

    if disc_name == "Original":
        original = OriginalNLayerDiscriminator(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
        )
        return original.apply(discriminator_weights_init)

    if disc_name == "VQGAN+Discriminator":
        return NLayerDiscriminatorv2(
            num_channels=discriminator_config.num_channels,
            num_stages=discriminator_config.num_stages,
            hidden_channels=discriminator_config.hidden_channels,
            blur_resample=discriminator_config.blur_resample,
            blur_kernel_size=discriminator_config.get("blur_kernel_size", 4),
        )

    raise ValueError(
        f"Discriminator {disc_name} is not implemented."
    )
|