reach-vb's picture
reach-vb HF staff
Stereo demo update (#60)
5325fcc
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility module to handle adversarial losses without requiring to mess up the main training loop.
"""
import typing as tp
import flashy
import torch
import torch.nn as nn
import torch.nn.functional as F
# Names of the adversarial loss variants accepted by the `get_*_criterion` helpers below.
ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
# An adversarial loss: a module or callable annotated as mapping one tensor (logits) to a loss tensor.
AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
# A feature matching loss: a module or callable annotated as taking two inputs (fake, real) and returning a tensor.
FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
class AdversarialLoss(nn.Module):
"""Adversary training wrapper.
Args:
adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
where the first item is a list of logits and the second item is a list of feature maps.
optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
loss (AdvLossType): Loss function for generator training.
loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
loss_feat (FeatLossType): Feature matching loss function for generator training.
normalize (bool): Whether to normalize by number of sub-discriminators.
Example of usage:
adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
for real in loader:
noise = torch.randn(...)
fake = model(noise)
adv_loss.train_adv(fake, real)
loss, _ = adv_loss(fake, real)
loss.backward()
"""
def __init__(self,
adversary: nn.Module,
optimizer: torch.optim.Optimizer,
loss: AdvLossType,
loss_real: AdvLossType,
loss_fake: AdvLossType,
loss_feat: tp.Optional[FeatLossType] = None,
normalize: bool = True):
super().__init__()
self.adversary: nn.Module = adversary
flashy.distrib.broadcast_model(self.adversary)
self.optimizer = optimizer
self.loss = loss
self.loss_real = loss_real
self.loss_fake = loss_fake
self.loss_feat = loss_feat
self.normalize = normalize
def _save_to_state_dict(self, destination, prefix, keep_vars):
# Add the optimizer state dict inside our own.
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'optimizer'] = self.optimizer.state_dict()
return destination
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# Load optimizer state.
self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
def get_adversary_pred(self, x):
"""Run adversary model, validating expected output format."""
logits, fmaps = self.adversary(x)
assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
f'Expecting a list of tensors as logits but {type(logits)} found.'
assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
for fmap in fmaps:
assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
return logits, fmaps
def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
"""Train the adversary with the given fake and real example.
We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
The first item being the logits and second item being a list of feature maps for each sub-discriminator.
This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
and call the optimizer.
"""
loss = torch.tensor(0., device=fake.device)
all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
n_sub_adversaries = len(all_logits_fake_is_fake)
for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
if self.normalize:
loss /= n_sub_adversaries
self.optimizer.zero_grad()
with flashy.distrib.eager_sync_model(self.adversary):
loss.backward()
self.optimizer.step()
return loss
def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
"""Return the loss for the generator, i.e. trying to fool the adversary,
and feature matching loss if provided.
"""
adv = torch.tensor(0., device=fake.device)
feat = torch.tensor(0., device=fake.device)
with flashy.utils.readonly(self.adversary):
all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
n_sub_adversaries = len(all_logits_fake_is_fake)
for logit_fake_is_fake in all_logits_fake_is_fake:
adv += self.loss(logit_fake_is_fake)
if self.loss_feat:
for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
feat += self.loss_feat(fmap_fake, fmap_real)
if self.normalize:
adv /= n_sub_adversaries
feat /= n_sub_adversaries
return adv, feat
def get_adv_criterion(loss_type: str) -> tp.Callable:
    """Look up the generator-side adversarial criterion for ``loss_type``.

    Args:
        loss_type (str): One of the names listed in ``ADVERSARIAL_LOSSES``.
    Returns:
        The matching loss function (``mse_loss``, ``hinge_loss`` or ``hinge2_loss``).
    Raises:
        AssertionError: if ``loss_type`` is not in ``ADVERSARIAL_LOSSES``.
        ValueError: if no criterion matches ``loss_type``.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    # Ordered (name, criterion) pairs, scanned with plain equality checks.
    known_criteria = (
        ('mse', mse_loss),
        ('hinge', hinge_loss),
        ('hinge2', hinge2_loss),
    )
    for name, criterion in known_criteria:
        if loss_type == name:
            return criterion
    raise ValueError('Unsupported loss')
def get_fake_criterion(loss_type: str) -> tp.Callable:
    """Look up the adversary-side criterion applied to logits of fake samples.

    Args:
        loss_type (str): One of the names listed in ``ADVERSARIAL_LOSSES``.
    Returns:
        ``mse_fake_loss`` for ``'mse'``, ``hinge_fake_loss`` for ``'hinge'`` and ``'hinge2'``.
    Raises:
        AssertionError: if ``loss_type`` is not in ``ADVERSARIAL_LOSSES``.
        ValueError: if no criterion matches ``loss_type``.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_fake_loss
    # Both hinge variants share the same criterion on the adversary side.
    uses_hinge = loss_type in ('hinge', 'hinge2')
    if uses_hinge:
        return hinge_fake_loss
    raise ValueError('Unsupported loss')
def get_real_criterion(loss_type: str) -> tp.Callable:
    """Look up the adversary-side criterion applied to logits of real samples.

    Args:
        loss_type (str): One of the names listed in ``ADVERSARIAL_LOSSES``.
    Returns:
        ``mse_real_loss`` for ``'mse'``, ``hinge_real_loss`` for ``'hinge'`` and ``'hinge2'``.
    Raises:
        AssertionError: if ``loss_type`` is not in ``ADVERSARIAL_LOSSES``.
        ValueError: if no criterion matches ``loss_type``.
    """
    assert loss_type in ADVERSARIAL_LOSSES
    if loss_type == 'mse':
        return mse_real_loss
    # Both hinge variants share the same criterion on the adversary side.
    uses_hinge = loss_type in ('hinge', 'hinge2')
    if uses_hinge:
        return hinge_real_loss
    raise ValueError('Unsupported loss')
def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between ``x`` and a target of ones (adversary loss on real samples)."""
    target = torch.tensor(1., device=x.device).expand_as(x)
    return F.mse_loss(x, target)
def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Mean squared error between ``x`` and a target of zeros (adversary loss on fake samples)."""
    target = torch.tensor(0., device=x.device).expand_as(x)
    return F.mse_loss(x, target)
def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss for the adversary on real samples: ``-mean(min(x - 1, 0))``."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    # Only logits below the +1 margin contribute to the loss.
    below_margin = torch.min(x - 1, zero)
    return -torch.mean(below_margin)
def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
    """Hinge loss for the adversary on fake samples: ``-mean(min(-x - 1, 0))``."""
    zero = torch.tensor(0., device=x.device).expand_as(x)
    # Only logits above the -1 margin contribute to the loss.
    above_margin = torch.min(-x - 1, zero)
    return -torch.mean(above_margin)
def mse_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side MSE loss: mean squared error between ``x`` and a target of ones.

    An input with no elements yields a one-element zero tensor on ``x``'s device.
    """
    if x.numel() != 0:
        target = torch.tensor(1., device=x.device).expand_as(x)
        return F.mse_loss(x, target)
    return torch.tensor([0.0], device=x.device)
def hinge_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side hinge loss: the negated mean of ``x``.

    An input with no elements yields a one-element zero tensor on ``x``'s device.
    """
    if x.numel() != 0:
        return -torch.mean(x)
    return torch.tensor([0.0], device=x.device)
def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
    """Generator-side clipped hinge loss: ``-mean(min(x - 1, 0))``.

    Args:
        x (torch.Tensor): Adversary logits for generated samples.
    Returns:
        torch.Tensor: The scalar loss, or a one-element zero tensor when ``x`` has no elements.
    """
    if x.numel() == 0:
        # Create the placeholder on the same device as the input so that it can be
        # accumulated with losses computed on that device.
        return torch.tensor([0.0], device=x.device)
    return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
class FeatureMatchingLoss(nn.Module):
    """Feature matching loss for adversarial training.

    Args:
        loss (nn.Module): Loss to use for feature matching. When omitted, a new
            ``torch.nn.L1Loss`` is created for this instance.
        normalize (bool): Whether to normalize the loss by number of feature maps.
    """
    def __init__(self, loss: tp.Optional[nn.Module] = None, normalize: bool = True):
        super().__init__()
        # Build the default per instance rather than sharing one module object
        # created at function definition time across every instance.
        self.loss = loss if loss is not None else torch.nn.L1Loss()
        self.normalize = normalize

    def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
        """Accumulate ``self.loss`` over paired feature maps.

        Args:
            fmap_fake (list of torch.Tensor): Feature maps computed on generated samples.
            fmap_real (list of torch.Tensor): Feature maps computed on real samples,
                in the same order and with the same shapes as ``fmap_fake``.
        Returns:
            torch.Tensor: Sum of the per-feature-map losses, divided by the number of
            feature maps when ``normalize`` is set.
        Raises:
            AssertionError: if the lists are empty, differ in length, or a pair differs in shape.
        """
        assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
        feat_loss = torch.tensor(0., device=fmap_fake[0].device)
        for feat_fake, feat_real in zip(fmap_fake, fmap_real):
            assert feat_fake.shape == feat_real.shape
            feat_loss += self.loss(feat_fake, feat_real)

        if self.normalize:
            # Lengths were checked to be equal above, so this is the number of pairs visited.
            feat_loss /= len(fmap_fake)

        return feat_loss