File size: 15,283 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 |
from typing import Mapping, Text, Tuple
import torch
import torch.nn.functional as F
from .lpips import LPIPS
from .perceptual_loss import PerceptualLoss
from . import gan_utils
def create_perception_loss(
    perception_loss: str, compute_on_logits: bool = True
) -> torch.nn.Module:
    """Builds the perceptual-loss module selected by name, switched to eval mode.

    Args:
        perception_loss: Name of the loss: "lpips", "resnet50" or "convnext_s".
        compute_on_logits: Forwarded to ``PerceptualLoss`` as
            ``compute_perceptual_loss_on_logits``; ignored for "lpips".

    Returns:
        The constructed loss module, already in eval mode.

    Raises:
        ValueError: If ``perception_loss`` is not one of the supported names.
    """
    if perception_loss == "lpips":
        module = LPIPS()
    elif perception_loss in ("resnet50", "convnext_s"):
        module = PerceptualLoss(
            model_name=perception_loss,
            compute_perceptual_loss_on_logits=compute_on_logits,
        )
    else:
        raise ValueError(f"Perception loss {perception_loss} is not supported.")
    # Both supported modules are only used as fixed feature extractors here.
    return module.eval()
class VQGANLoss(torch.nn.Module):
    """VQGAN-style training loss with a generator phase and a discriminator phase.

    ``forward(..., mode="gen")`` returns reconstruction + perceptual + quantizer
    + (weighted) GAN generator loss; ``forward(..., mode="disc")`` returns the
    discriminator loss, optionally with LeCam regularisation.
    """

    def __init__(
        self,
        discriminator_config,
        loss_config,
    ):
        """Initializes the VQGAN loss.

        Args:
            discriminator_config: The configuration of the discriminator, passed
                to ``gan_utils.create_discriminator``.
            loss_config: The configuration of the loss. Required keys are read as
                attributes; ``reconstruction_weight``, ``perceptual_loss_on_logits``,
                ``ema_decay``, ``entropy_annealing_steps`` and
                ``entropy_annealing_factor`` are optional and read via ``.get``.
        """
        super().__init__()
        assert loss_config.discriminator_loss in ("hinge", "vanilla", "non-saturating")
        assert loss_config.reconstruction_loss in ("l2", "l1")
        assert loss_config.discriminator_gradient_penalty in ("none", "adopt_weight")

        self.discriminator = gan_utils.create_discriminator(discriminator_config)

        self.reconstruction_loss = loss_config.reconstruction_loss
        self.reconstruction_weight = loss_config.get("reconstruction_weight", 1.0)
        self.quantizer_weight = loss_config.quantizer_weight
        self.perceptual_loss = create_perception_loss(
            loss_config.perceptual_loss,
            loss_config.get("perceptual_loss_on_logits", True),
        )
        self.perceptual_weight = loss_config.perceptual_weight
        self.lecam_regularization_weight = loss_config.lecam_regularization_weight
        self.ema_decay = loss_config.get("ema_decay", 0.999)
        self.entropy_annealing_steps = loss_config.get("entropy_annealing_steps", 2000)
        self.entropy_annealing_factor = loss_config.get("entropy_annealing_factor", 0.0)
        self.discriminator_iter_start = loss_config.discriminator_start

        # Select the discriminator/generator loss pair. The asserts above are
        # stripped under ``python -O``, so keep an explicit error path.
        gan_loss = loss_config.discriminator_loss
        if gan_loss == "hinge":
            self.discriminator_loss = gan_utils.hinge_d_loss
            self.generator_loss = gan_utils.hinge_g_loss
        elif gan_loss == "vanilla":
            self.discriminator_loss = gan_utils.vanilla_d_loss
            # Kept as in the original code: "vanilla" pairs with hinge_g_loss on
            # the generator side. NOTE(review): confirm this pairing is intended.
            self.generator_loss = gan_utils.hinge_g_loss
        elif gan_loss == "non-saturating":
            self.discriminator_loss = gan_utils.non_saturating_d_loss
            self.generator_loss = gan_utils.non_saturating_g_loss
        else:
            raise ValueError(f"Unknown GAN loss '{gan_loss}'.")

        self.discriminator_factor = loss_config.discriminator_factor
        self.discriminator_weight = loss_config.discriminator_weight
        # "none" is normalised to "" so only "adopt_weight" enables the adaptive
        # weight in the generator pass.
        self.discriminator_gradient_penalty = (
            ""
            if loss_config.discriminator_gradient_penalty == "none"
            else loss_config.discriminator_gradient_penalty
        )
        self.discriminator_penalty_cost = loss_config.discriminator_penalty_cost

        if self.lecam_regularization_weight > 0.0:
            # Running (EMA) means of real/fake discriminator logits for LeCam.
            self.register_buffer("ema_real_logits_mean", torch.zeros((1)))
            self.register_buffer("ema_fake_logits_mean", torch.zeros((1)))

    def calculate_adaptive_weight(
        self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer
    ) -> torch.Tensor:
        """Calculates the adaptive weight for the GAN generator loss.

        The weight is ``||d nll / d last_layer|| / (||d g / d last_layer|| + 1e-4)``,
        clamped to ``[0, 1e4]`` and detached so it acts as a constant.

        Args:
            nll_loss: The reconstruction (+ perceptual) loss.
            g_loss: The generator loss.
            last_layer: The tensor to differentiate against, typically the
                decoder's last-layer weight.

        Returns:
            The detached adaptive weight.
        """
        # retain_graph=True: the same graph is backpropagated again later.
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight

    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        extra_result_dict: Mapping[Text, torch.Tensor],
        global_step: int,
        last_layer,
        mode: str = "gen",
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Computes the VQGAN loss for the generator or discriminator.

        Args:
            inputs: The input images.
            reconstructions: The reconstructed images.
            extra_result_dict: Extra tensors produced by the quantizer.
            global_step: The global step.
            last_layer: The last layer of the model (generator mode only).
            mode: Either "gen" or "disc".

        Returns:
            A ``(loss, loss_dict)`` tuple; ``loss_dict`` holds the individual
            terms for logging.
        """
        assert mode in ("gen", "disc")
        if mode == "gen":
            return self._forward_generator(
                inputs, reconstructions, extra_result_dict, global_step, last_layer
            )
        elif mode == "disc":
            return self._forward_discriminator(
                inputs, reconstructions, extra_result_dict, global_step
            )

    def should_discriminator_be_trained(self, global_step: int):
        """Returns True once ``global_step`` reaches the configured discriminator start."""
        return global_step >= self.discriminator_iter_start

    def _forward_generator(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        extra_result_dict: Mapping[Text, torch.Tensor],
        global_step: int,
        last_layer,
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Computes the VQGAN loss for the generator.

        total = w_rec * rec + w_perc * perc + w_q * quantizer
                + d_weight * disc_factor * gan

        Args:
            inputs: The input images.
            reconstructions: The reconstructed images.
            extra_result_dict: Must contain "quantizer_loss", "commitment_loss",
                "entropy_loss", "per_sample_entropy" and "avg_entropy"; may
                contain "codebook_loss". It is not modified.
            global_step: The global step.
            last_layer: The last layer of the model (for the adaptive weight).

        Returns:
            A ``(total_loss, loss_dict)`` tuple.
        """
        inputs = inputs.contiguous()
        reconstructions = reconstructions.contiguous()
        if self.reconstruction_loss == "l1":
            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
        else:
            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
        reconstruction_loss = reconstruction_loss * self.reconstruction_weight

        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

        generator_loss = torch.zeros((), device=inputs.device)
        extra_generator_loss = torch.zeros((), device=inputs.device)
        discriminator_factor = gan_utils.adopt_weight(
            self.discriminator_factor,
            global_step,
            threshold=self.discriminator_iter_start,
        )

        d_weight = 1.0
        if discriminator_factor > 0.0:
            # Disable discriminator gradients so the GAN term only trains the generator.
            gan_utils.toggle_off_gradients(self.discriminator)
            logits_fake = self.discriminator(reconstructions)
            generator_loss = self.generator_loss(logits_fake)
            if self.discriminator_gradient_penalty == "adopt_weight":
                d_weight *= self.calculate_adaptive_weight(
                    reconstruction_loss + self.perceptual_weight * perceptual_loss,
                    generator_loss,
                    last_layer=last_layer,
                )
        d_weight *= self.discriminator_weight

        quantizer_loss = self._annealed_quantizer_loss(extra_result_dict, global_step)

        weighted_gan_loss = (
            d_weight * discriminator_factor * (generator_loss + extra_generator_loss)
        )
        total_loss = (
            reconstruction_loss
            + self.perceptual_weight * perceptual_loss
            + self.quantizer_weight * quantizer_loss
            + weighted_gan_loss
        )

        loss_dict = dict(
            total_loss=total_loss.clone().detach(),
            reconstruction_loss=reconstruction_loss.detach(),
            perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
            quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
            weighted_gan_loss=weighted_gan_loss.detach(),
            discriminator_factor=torch.tensor(discriminator_factor),
            commitment_loss=extra_result_dict["commitment_loss"].detach(),
            entropy_loss=extra_result_dict["entropy_loss"].detach(),
            per_sample_entropy=extra_result_dict["per_sample_entropy"],
            avg_entropy=extra_result_dict["avg_entropy"],
            d_weight=d_weight,
            gan_loss=generator_loss.detach(),
        )
        if "codebook_loss" in extra_result_dict:
            loss_dict["codebook_loss"] = extra_result_dict["codebook_loss"].detach()

        return total_loss, loss_dict

    def _annealed_quantizer_loss(
        self, extra_result_dict: Mapping[Text, torch.Tensor], global_step: int
    ) -> torch.Tensor:
        """Returns the quantizer loss plus the linearly-annealed entropy term.

        The entropy term is scaled by
        ``max(0, 1 - global_step / entropy_annealing_steps) * entropy_annealing_factor``
        and is skipped entirely when the factor is not positive. The sum is built
        out of place, so the tensor stored in ``extra_result_dict`` is never
        mutated (an in-place add would also break autograd for producers that
        save their output for backward).
        """
        quantizer_loss = extra_result_dict["quantizer_loss"]
        if self.entropy_annealing_factor > 0.0:
            annealing = max(0.0, 1 - global_step / self.entropy_annealing_steps)
            quantizer_loss = quantizer_loss + (
                annealing
                * self.entropy_annealing_factor
                * extra_result_dict["entropy_loss"]
            )
        return quantizer_loss

    def _forward_discriminator(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        extra_result_dict: Mapping[Text, torch.Tensor],
        global_step: int,
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Computes the VQGAN loss for the discriminator.

        Args:
            inputs: The input images.
            reconstructions: The reconstructed images (detached here).
            extra_result_dict: Unused in this phase.
            global_step: The global step.

        Returns:
            A ``(discriminator_loss, loss_dict)`` tuple.
        """
        discriminator_factor = gan_utils.adopt_weight(
            self.discriminator_factor,
            global_step,
            threshold=self.discriminator_iter_start,
        )

        # Re-enable discriminator gradients (switched off in the generator pass).
        gan_utils.toggle_on_gradients(self.discriminator)

        # NOTE(review): requires_grad on the real images is not consumed by any
        # penalty in this method; presumably left for a gradient penalty.
        real_images = inputs.detach().requires_grad_(True)
        logits_real = self.discriminator(real_images)
        logits_fake = self.discriminator(reconstructions.detach())

        discriminator_loss = discriminator_factor * self.discriminator_loss(
            logits_real=logits_real, logits_fake=logits_fake
        )

        lecam_loss = torch.zeros((), device=inputs.device)
        if self.lecam_regularization_weight > 0.0:
            mean_real = torch.mean(logits_real)
            mean_fake = torch.mean(logits_fake)
            lecam_loss = (
                gan_utils.compute_lecam_loss(
                    mean_real,
                    mean_fake,
                    self.ema_real_logits_mean,
                    self.ema_fake_logits_mean,
                )
                * self.lecam_regularization_weight
            )
            # Update the EMAs only after the loss used the previous values.
            self.ema_real_logits_mean = (
                self.ema_real_logits_mean * self.ema_decay
                + mean_real.detach() * (1 - self.ema_decay)
            )
            self.ema_fake_logits_mean = (
                self.ema_fake_logits_mean * self.ema_decay
                + mean_fake.detach() * (1 - self.ema_decay)
            )

        discriminator_loss = discriminator_loss + lecam_loss

        loss_dict = dict(
            discriminator_loss=discriminator_loss.detach(),
            logits_real=logits_real.detach().mean(),
            logits_fake=logits_fake.detach().mean(),
            lecam_loss=lecam_loss.detach(),
        )
        return discriminator_loss, loss_dict
class MLMLoss(torch.nn.Module):
    """Cross-entropy loss with label smoothing over split token logits, plus accuracy metrics."""

    def __init__(self, label_smoothing: float = 0.1, sum_splits: bool = False):
        """Initializes the MLM loss, which is essentially a CrossEntropy loss with label smoothing.

        Args:
            label_smoothing: The label smoothing factor passed to CrossEntropyLoss.
            sum_splits: If True, the returned losses are multiplied by the number
                of splits ``m`` (i.e. summed rather than averaged over splits).
        """
        super().__init__()
        self.label_smoothing = label_smoothing
        self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
        self.sum_splits = sum_splits

    def forward(
        self, inputs: torch.Tensor, targets: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Computes the MLM loss.

        Args:
            inputs: Logits of shape ``(b, n, m, codebook_size)``.
            targets: Target token ids with the same leading shape ``(b, n, m)``;
                may be non-contiguous.
            masks: Boolean mask of shape ``(b, n, m)`` selecting the masked tokens.

        Returns:
            ``(loss, loss_dict)``. ``loss`` is the cross entropy over all tokens.
            ``loss_dict`` also reports token accuracy and the masked-token
            loss/accuracy (computed on detached logits, logging only). The masked
            metrics are NaN when ``masks`` selects no tokens.
        """
        _, _, m, codebook_size = inputs.shape
        # reshape (not view) so non-contiguous targets, e.g. after a transpose, work.
        loss = self.criterion(inputs.reshape(-1, codebook_size), targets.reshape(-1))
        # Mean per-token accuracy raised to the power m (number of splits).
        correct_tokens = (
            torch.argmax(inputs.detach(), dim=-1) == targets
        ).float().mean() ** m

        masked_input = inputs[masks, :].detach()
        masked_targets = targets[masks]
        masked_loss = self.criterion(masked_input, masked_targets)
        masked_correct_tokens = (
            torch.argmax(masked_input, dim=-1) == masked_targets
        ).float().mean() ** m

        if self.sum_splits:
            loss = loss * m
            masked_loss = masked_loss * m

        loss_dict = {
            "mlm_loss": loss,
            "correct_tokens": correct_tokens,
            "masked_token_loss": masked_loss,
            "masked_correct_tokens": masked_correct_tokens,
        }
        return loss, loss_dict
if __name__ == "__main__":
    # Smoke test: run MLMLoss on random logits/targets/masks and print everything.
    criterion = MLMLoss()
    batch_size, seq_len, num_splits, codebook_size = 2, 3, 1, 4
    demo_logits = torch.rand((batch_size, seq_len, num_splits, codebook_size))
    demo_targets = torch.randint(0, codebook_size, (batch_size, seq_len, num_splits))
    demo_masks = torch.randint(
        0, 2, (batch_size, seq_len, num_splits), dtype=torch.bool
    )
    demo_loss, demo_loss_dict = criterion(demo_logits, demo_targets, demo_masks)
    for tensor in (demo_logits, demo_targets, demo_masks):
        print(tensor)
    print(demo_loss, demo_loss_dict)
|