# Origin: hr16's fork of adriansahlman's stylegan2_pytorch (revision 480bfbc).
import numpy as np
import torch
from torch.nn import functional as F
from . import utils
def _grad(input, output, retain_graph):
    """Differentiate the summed ``output`` with respect to ``input``.

    The gradient is computed with ``create_graph=True`` so that the result
    can itself be back-propagated through.

    Arguments:
        input: tensor (or sequence of tensors) to differentiate with
            respect to.
        output: tensor whose sum is differentiated.
        retain_graph: ignored; the graph is always retained (see below).

    Returns:
        The gradient belonging to the first entry of ``input``.
    """
    # HACK: always keep the graph, whatever the caller asked for. Releasing
    # it is currently not possible for regularization losses, see
    # https://discuss.pytorch.org/t/gradient-penalty-loss-with-modified-weights/64910
    retain_graph = True
    return torch.autograd.grad(
        outputs=output.sum(),
        inputs=input,
        retain_graph=retain_graph,
        create_graph=True,
        only_inputs=True,
    )[0]
def _grad_pen(input, output, gamma, constraint=1, onesided=False, retain_graph=True):
    """Penalize the per-sample L2 norm of the gradient of ``output``.

    Arguments:
        input: tensor the gradient is taken with respect to; its first
            dimension indexes the samples.
        output: tensor whose sum is differentiated.
        gamma: weight applied to the mean penalty.
        constraint: target value for the gradient norm.
        onesided: if true, only norms above ``constraint`` are penalized
            (linearly in the excess); otherwise the squared distance to
            ``constraint`` is penalized in both directions.
        retain_graph: forwarded to ``_grad``.

    Returns:
        Scalar tensor ``gamma * mean(penalty)``.
    """
    grad = _grad(input, output, retain_graph=retain_graph)
    # reshape (not view) so that non-contiguous gradients are accepted too.
    grad = grad.reshape(grad.size(0), -1)
    grad_norm = grad.norm(2, dim=1)
    if onesided:
        # torch.max has no (number, tensor) overload, so the elementwise
        # max(0, x) has to be expressed as a clamp.
        gp = torch.clamp(grad_norm - constraint, min=0)
    else:
        gp = (grad_norm - constraint) ** 2
    return gamma * gp.mean()
def _grad_reg(input, output, gamma, retain_graph=True):
    """Squared-gradient regularizer.

    Takes the gradient of ``output`` with respect to ``input`` (via
    ``_grad``), flattens it per sample and returns
    ``0.5 * gamma * mean(sum(grad ** 2))``.

    Arguments:
        input: tensor the gradient is taken with respect to; its first
            dimension indexes the samples.
        output: tensor whose sum is differentiated.
        gamma: regularization weight.
        retain_graph: forwarded to ``_grad``.

    Returns:
        Scalar tensor holding the weighted regularization term.
    """
    flat = _grad(input, output, retain_graph=retain_graph)
    flat = flat.view(flat.size(0), -1)
    # One squared L2 norm per sample.
    sq_norms = flat.pow(2).sum(1)
    return (0.5 * gamma) * sq_norms.mean()
def _pathreg(dlatents, fakes, pl_avg, pl_decay, gamma, retain_graph=True):
    """Path length regularization term.

    Projects ``fakes`` onto random normal noise, differentiates that
    projection with respect to ``dlatents`` and penalizes the squared
    deviation of the resulting per-sample length from the running
    average ``pl_avg``.

    Arguments:
        dlatents: tensor the projection is differentiated with respect
            to. The gradient is reduced over dim 2 (sum) and dim 1 (mean),
            so it is presumably (batch, layers, features) - TODO confirm.
        fakes: generated outputs; dims 2 and up set the noise scale.
        pl_avg: running average of the path length, updated IN PLACE.
        pl_decay: step size of the running-average update.
        gamma: regularization weight.
        retain_graph: ignored; the graph is always retained.

    Returns:
        Scalar tensor ``gamma * mean((pl_length - pl_avg) ** 2)``.
    """
    # The graph is kept no matter what the caller requested.
    retain_graph = True
    # Noise is divided by sqrt(number of elements in dims 2 and up).
    noise_scale = np.sqrt(np.prod(fakes.size()[2:]))
    pl_noise = torch.empty_like(fakes).normal_().div_(noise_scale)
    projection = torch.sum(pl_noise * fakes)
    pl_grad = _grad(dlatents, projection, retain_graph=retain_graph)
    pl_length = torch.sqrt(torch.mean(torch.sum(pl_grad.pow(2), dim=2), dim=1))
    with torch.no_grad():
        # Running-average update is kept out of the autograd graph.
        pl_avg.add_(pl_decay * (pl_length.mean() - pl_avg))
    deviation = pl_length - pl_avg
    return gamma * torch.mean(deviation ** 2)
#----------------------------------------------------------------------------
# Logistic loss from the paper
# "Generative Adversarial Nets", Goodfellow et al. 2014
def G_logistic(G,
               D,
               latents,
               latent_labels=None,
               *args,
               **kwargs):
    """Logistic generator loss.

    Generates fakes from ``latents``, scores them with ``D`` and returns
    the NEGATED binary cross entropy (with logits) of those scores against
    an all-zeros target.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(fakes, labels=...)`` producing scores.
        latents: input passed to ``G``.
        latent_labels: optional labels passed to both ``G`` and ``D``.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, None)``; this loss has no regularization term.
    """
    fakes = G(latents, labels=latent_labels)
    fake_scores = D(fakes, labels=latent_labels).float()
    target = torch.zeros_like(fake_scores)
    return -F.binary_cross_entropy_with_logits(fake_scores, target), None
def G_logistic_ns(G,
                  D,
                  latents,
                  latent_labels=None,
                  *args,
                  **kwargs):
    """Logistic generator loss against an all-ones target.

    Generates fakes from ``latents``, scores them with ``D`` and returns
    the binary cross entropy (with logits) of those scores against an
    all-ones target.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(fakes, labels=...)`` producing scores.
        latents: input passed to ``G``.
        latent_labels: optional labels passed to both ``G`` and ``D``.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, None)``; this loss has no regularization term.
    """
    fakes = G(latents, labels=latent_labels)
    fake_scores = D(fakes, labels=latent_labels).float()
    target = torch.ones_like(fake_scores)
    return F.binary_cross_entropy_with_logits(fake_scores, target), None
def D_logistic(G,
               D,
               latents,
               reals,
               latent_labels=None,
               real_labels=None,
               *args,
               **kwargs):
    """Logistic discriminator loss.

    Sum of the binary cross entropy (with logits) of the real scores
    against ones and of the fake scores against zeros. Fakes are generated
    under ``torch.no_grad()``.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(data, labels=...)`` producing scores.
        latents: input passed to ``G``.
        reals: real samples scored by ``D``.
        latent_labels: optional labels for the fakes.
        real_labels: optional labels for the reals; must be given exactly
            when ``latent_labels`` is given.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, None)``; this loss has no regularization term.
    """
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    bce = F.binary_cross_entropy_with_logits
    real_loss = bce(real_scores, torch.ones_like(real_scores))
    fake_loss = bce(fake_scores, torch.zeros_like(fake_scores))
    return real_loss + fake_loss, None
#----------------------------------------------------------------------------
# R1 and R2 regularizers from the paper
# "Which Training Methods for GANs do actually Converge?", Mescheder et al. 2018
def D_r1(D,
         reals,
         real_labels=None,
         gamma=10,
         *args,
         **kwargs):
    """Gradient regularizer evaluated on real samples only.

    Arguments:
        D: callable ``D(data, labels=...)`` producing scores.
        reals: real samples; ``requires_grad`` is switched on IN PLACE.
        real_labels: optional labels passed to ``D``.
        gamma: regularization weight; a falsy value disables the term.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(None, reg)`` where ``reg`` is the ``_grad_reg`` term cast to
        float, or ``(None, None)`` when ``gamma`` is falsy.
    """
    if not gamma:
        return None, None
    reals.requires_grad_(True)
    real_scores = D(reals, labels=real_labels)
    penalty = _grad_reg(
        input=reals, output=real_scores, gamma=gamma, retain_graph=False)
    return None, penalty.float()
def D_r2(D,
         G,
         latents,
         latent_labels=None,
         gamma=10,
         *args,
         **kwargs):
    """Gradient regularizer evaluated on generated samples only.

    Arguments:
        D: callable ``D(data, labels=...)`` producing scores.
        G: callable ``G(latents, labels=...)`` producing fakes.
        latents: input passed to ``G``.
        latent_labels: optional labels passed to both ``G`` and ``D``.
        gamma: regularization weight; a falsy value disables the term.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(None, reg)`` where ``reg`` is the ``_grad_reg`` term cast to
        float, or ``(None, None)`` when ``gamma`` is falsy.
    """
    if not gamma:
        return None, None
    # Generate without tracking G, then track gradients w.r.t. the fakes.
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    fakes.requires_grad_(True)
    fake_scores = D(fakes, labels=latent_labels)
    penalty = _grad_reg(
        input=fakes, output=fake_scores, gamma=gamma, retain_graph=False)
    return None, penalty.float()
def D_logistic_r1(G,
                  D,
                  latents,
                  reals,
                  latent_labels=None,
                  real_labels=None,
                  gamma=10,
                  *args,
                  **kwargs):
    """Logistic discriminator loss plus gradient regularizer on reals.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(data, labels=...)`` producing scores.
        latents: input passed to ``G`` (run under ``torch.no_grad()``).
        reals: real samples; ``requires_grad`` is switched on IN PLACE
            when ``gamma`` is truthy.
        latent_labels: optional labels for the fakes.
        real_labels: optional labels for the reals; must be given exactly
            when ``latent_labels`` is given.
        gamma: regularization weight; a falsy value disables the term.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, reg)``; ``reg`` is ``None`` when ``gamma`` is falsy.
    """
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    if gamma:
        # Must be enabled before D sees the reals.
        reals.requires_grad_(True)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    bce = F.binary_cross_entropy_with_logits
    real_loss = bce(real_scores, torch.ones_like(real_scores))
    fake_loss = bce(fake_scores, torch.zeros_like(fake_scores))
    loss = real_loss + fake_loss
    if not gamma:
        return loss, None
    reg = _grad_reg(
        input=reals, output=real_scores, gamma=gamma, retain_graph=True)
    return loss, reg.float()
def D_logistic_r2(G,
                  D,
                  latents,
                  reals,
                  latent_labels=None,
                  real_labels=None,
                  gamma=10,
                  *args,
                  **kwargs):
    """Logistic discriminator loss plus gradient regularizer on fakes.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(data, labels=...)`` producing scores.
        latents: input passed to ``G`` (run under ``torch.no_grad()``).
        reals: real samples scored by ``D``.
        latent_labels: optional labels for the fakes.
        real_labels: optional labels for the reals; must be given exactly
            when ``latent_labels`` is given.
        gamma: regularization weight; a falsy value disables the term.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, reg)``; ``reg`` is ``None`` when ``gamma`` is falsy.
    """
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    if gamma:
        # Must be enabled before D sees the fakes.
        fakes.requires_grad_(True)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    bce = F.binary_cross_entropy_with_logits
    real_loss = bce(real_scores, torch.ones_like(real_scores))
    fake_loss = bce(fake_scores, torch.zeros_like(fake_scores))
    loss = real_loss + fake_loss
    if not gamma:
        return loss, None
    reg = _grad_reg(
        input=fakes, output=fake_scores, gamma=gamma, retain_graph=True)
    return loss, reg.float()
#----------------------------------------------------------------------------
# Non-saturating logistic loss with path length regularizer from the paper
# "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 2019
def G_pathreg(G,
              latents,
              pl_avg,
              latent_labels=None,
              pl_decay=0.01,
              gamma=2,
              *args,
              **kwargs):
    """Path length regularizer for the generator, without a loss term.

    Arguments:
        G: callable invoked as ``G(latents, labels=..., return_dlatents=True,
            mapping_grad=False)`` and expected to return ``(fakes, dlatents)``.
        latents: input passed to ``G``.
        pl_avg: running path length average handed to ``_pathreg``.
        latent_labels: optional labels passed to ``G``.
        pl_decay: decay handed to ``_pathreg``.
        gamma: regularization weight; a falsy value disables the term.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(None, reg)`` with ``reg`` cast to float, or ``(None, None)``
        when ``gamma`` is falsy.
    """
    if not gamma:
        return None, None
    fakes, dlatents = G(
        latents, labels=latent_labels, return_dlatents=True, mapping_grad=False)
    penalty = _pathreg(
        dlatents=dlatents,
        fakes=fakes,
        pl_avg=pl_avg,
        pl_decay=pl_decay,
        gamma=gamma,
        retain_graph=False,
    )
    return None, penalty.float()
def G_logistic_ns_pathreg(G,
                          D,
                          latents,
                          pl_avg,
                          latent_labels=None,
                          pl_decay=0.01,
                          gamma=2,
                          *args,
                          **kwargs):
    """Logistic generator loss (all-ones target) plus path length term.

    Arguments:
        G: callable invoked as ``G(latents, labels=..., return_dlatents=True)``
            and expected to return ``(fakes, dlatents)``.
        D: callable ``D(fakes, labels=...)`` producing scores.
        latents: input passed to ``G``.
        pl_avg: running path length average handed to ``_pathreg``.
        latent_labels: optional labels passed to both ``G`` and ``D``.
        pl_decay: decay handed to ``_pathreg``.
        gamma: regularization weight; a falsy value disables the term.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, reg)``; ``reg`` is ``None`` when ``gamma`` is falsy.
    """
    fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True)
    fake_scores = D(fakes, labels=latent_labels).float()
    target = torch.ones_like(fake_scores)
    loss = F.binary_cross_entropy_with_logits(fake_scores, target)
    if not gamma:
        return loss, None
    penalty = _pathreg(
        dlatents=dlatents,
        fakes=fakes,
        pl_avg=pl_avg,
        pl_decay=pl_decay,
        gamma=gamma,
        retain_graph=True,
    )
    return loss, penalty.float()
#----------------------------------------------------------------------------
# WGAN loss from the paper
# "Wasserstein Generative Adversarial Networks", Arjovsky et al. 2017
def G_wgan(G,
           D,
           latents,
           latent_labels=None,
           *args,
           **kwargs):
    """WGAN generator loss: the negated mean score of the fakes.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(fakes, labels=...)`` producing scores.
        latents: input passed to ``G``.
        latent_labels: optional labels passed to both ``G`` and ``D``.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, None)``; this loss has no regularization term.
    """
    fakes = G(latents, labels=latent_labels)
    fake_scores = D(fakes, labels=latent_labels).float()
    return -fake_scores.mean(), None
def D_wgan(G,
           D,
           latents,
           reals,
           latent_labels=None,
           real_labels=None,
           drift_gamma=0.001,
           *args,
           **kwargs):
    """WGAN discriminator loss with an optional drift term.

    The loss is ``mean(fake_scores) - mean(real_scores)``; when
    ``drift_gamma`` is truthy, ``drift_gamma * mean(real_scores ** 2)`` is
    added on top.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(data, labels=...)`` producing scores.
        latents: input passed to ``G`` (run under ``torch.no_grad()``).
        reals: real samples scored by ``D``.
        latent_labels: optional labels for the fakes.
        real_labels: optional labels for the reals; must be given exactly
            when ``latent_labels`` is given.
        drift_gamma: weight of the drift term; falsy disables it.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, None)``; this loss has no regularization term.
    """
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    loss = fake_scores.mean() - real_scores.mean()
    if drift_gamma:
        drift = torch.mean(real_scores ** 2)
        loss = loss + drift_gamma * drift
    return loss, None
#----------------------------------------------------------------------------
# WGAN-GP loss from the paper
# "Improved Training of Wasserstein GANs", Gulrajani et al. 2017
def D_gp(G,
         D,
         latents,
         reals,
         latent_labels=None,
         real_labels=None,
         gamma=0,
         constraint=1,
         *args,
         **kwargs):
    """Gradient penalty on random interpolations between reals and fakes.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(data, labels=...)`` producing scores.
        latents: input passed to ``G`` (run under ``torch.no_grad()``).
        reals: real samples; must have the same size as the fakes.
        latent_labels: optional labels for the fakes.
        real_labels: optional labels for the reals; must be given exactly
            when ``latent_labels`` is given, and must then be equal to
            them because a single label set is used for the interpolation.
        gamma: penalty weight; a falsy value (the default) disables it.
        constraint: target gradient norm forwarded to ``_grad_pen``.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(None, reg)`` with ``reg`` cast to float, or ``(None, None)``
        when ``gamma`` is falsy.

    Raises:
        AssertionError: if the label or size preconditions are violated.
    """
    loss = None
    reg = None
    if gamma:
        assert (latent_labels is None) == (real_labels is None)
        with torch.no_grad():
            fakes = G(latents, labels=latent_labels)
        assert reals.size() == fakes.size()
        if latent_labels is not None:
            # Truth-testing or ``==``-asserting a multi-element tensor
            # raises, so tensors are compared with torch.equal instead.
            if isinstance(latent_labels, torch.Tensor) and isinstance(real_labels, torch.Tensor):
                labels_match = torch.equal(latent_labels, real_labels)
            else:
                labels_match = latent_labels == real_labels
            assert labels_match
        # One mixing coefficient per sample, created alongside ``reals`` so
        # the interpolation does not mix devices or dtypes.
        alpha = torch.empty(
            reals.size(0), dtype=reals.dtype, device=reals.device).uniform_()
        alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
        interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
        interp_scores = D(interp, labels=latent_labels)
        reg = _grad_pen(
            input=interp,
            output=interp_scores,
            gamma=gamma,
            constraint=constraint,
            retain_graph=False
        ).float()
    return loss, reg
def D_wgan_gp(G,
              D,
              latents,
              reals,
              latent_labels=None,
              real_labels=None,
              gamma=0,
              drift_gamma=0.001,
              constraint=1,
              *args,
              **kwargs):
    """WGAN discriminator loss plus gradient penalty on interpolations.

    The loss is ``mean(fake_scores) - mean(real_scores)``, plus
    ``drift_gamma * mean(real_scores ** 2)`` when ``drift_gamma`` is
    truthy. The penalty is evaluated on random per-sample interpolations
    between reals and fakes.

    Arguments:
        G: callable ``G(latents, labels=...)`` producing fakes.
        D: callable ``D(data, labels=...)`` producing scores.
        latents: input passed to ``G`` (run under ``torch.no_grad()``).
        reals: real samples; must match the fakes in size when the
            penalty is enabled.
        latent_labels: optional labels for the fakes.
        real_labels: optional labels for the reals; must be given exactly
            when ``latent_labels`` is given, and must equal them when the
            penalty is enabled.
        gamma: penalty weight; a falsy value (the default) disables it.
        drift_gamma: weight of the drift term; falsy disables it.
        constraint: target gradient norm forwarded to ``_grad_pen``.
        *args, **kwargs: accepted and ignored.

    Returns:
        ``(loss, reg)``; ``reg`` is ``None`` when ``gamma`` is falsy.

    Raises:
        AssertionError: if the label or size preconditions are violated.
    """
    assert (latent_labels is None) == (real_labels is None)
    with torch.no_grad():
        fakes = G(latents, labels=latent_labels)
    real_scores = D(reals, labels=real_labels).float()
    fake_scores = D(fakes, labels=latent_labels).float()
    loss = fake_scores.mean() - real_scores.mean()
    if drift_gamma:
        loss += drift_gamma * torch.mean(real_scores ** 2)
    reg = None
    if gamma:
        assert reals.size() == fakes.size()
        if latent_labels is not None:
            # Truth-testing or ``==``-asserting a multi-element tensor
            # raises, so tensors are compared with torch.equal instead.
            if isinstance(latent_labels, torch.Tensor) and isinstance(real_labels, torch.Tensor):
                labels_match = torch.equal(latent_labels, real_labels)
            else:
                labels_match = latent_labels == real_labels
            assert labels_match
        # One mixing coefficient per sample, created alongside ``reals`` so
        # the interpolation does not mix devices or dtypes.
        alpha = torch.empty(
            reals.size(0), dtype=reals.dtype, device=reals.device).uniform_()
        alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
        interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
        interp_scores = D(interp, labels=latent_labels)
        reg = _grad_pen(
            input=interp,
            output=interp_scores,
            gamma=gamma,
            constraint=constraint,
            retain_graph=True
        ).float()
    return loss, reg