import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
####################################################################################################
# Adversarial losses for different GAN modes
####################################################################################################
class GANLoss(nn.Module):
    """Define different GAN objectives.
    The GANLoss class abstracts away the need to create a target label tensor
    that has the same size as the input.
    """
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.
        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, hinge, wgangp, and nonsaturating.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label for a fake image
        Note: Do not use a sigmoid as the last layer of the discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            self.loss = nn.ReLU()
        elif gan_mode in ['wgangp', 'nonsaturating']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.
        Parameters:
            prediction (tensor) - - typically the prediction from a discriminator
            target_is_real (bool) - - if the ground truth label is for real examples or fake examples
        Returns:
            A label tensor filled with the ground truth label, with the size of the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)
    def calculate_loss(self, prediction, target_is_real, is_dis=False):
        """Calculate loss given the discriminator's output and ground truth labels.
        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - if the ground truth label is for real examples or fake examples
            is_dis (bool) - - whether the loss is for the discriminator (True) or the generator (False)
        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
            if self.gan_mode == 'lsgan':
                loss = loss * 0.5
        else:
            if is_dis:
                if target_is_real:
                    prediction = -prediction
                if self.gan_mode == 'wgangp':
                    loss = prediction.mean()
                elif self.gan_mode == 'nonsaturating':
                    loss = F.softplus(prediction).mean()
                elif self.gan_mode == 'hinge':
                    loss = self.loss(1 + prediction).mean()
            else:
                if self.gan_mode == 'nonsaturating':
                    loss = F.softplus(-prediction).mean()
                else:
                    loss = -prediction.mean()
        return loss
    def __call__(self, predictions, target_is_real, is_dis=False):
        """Calculate loss for multi-scale GANs, where the discriminator returns a list of predictions."""
        if isinstance(predictions, list):
            losses = []
            for prediction in predictions:
                losses.append(self.calculate_loss(prediction, target_is_real, is_dis))
            loss = sum(losses)
        else:
            loss = self.calculate_loss(predictions, target_is_real, is_dis)
        return loss
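####################################################################################################
# A minimal GANLoss usage sketch, assuming a hypothetical discriminator `netD` and image
# batches `real` / `fake`; any GAN mode from __init__ can be substituted.
####################################################################################################
# criterion = GANLoss('nonsaturating')
# # Discriminator step: push real predictions up and fake predictions down.
# d_loss = criterion(netD(real), True, is_dis=True) + criterion(netD(fake.detach()), False, is_dis=True)
# # Generator step: make fake samples look real to the discriminator.
# g_loss = criterion(netD(fake), True)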
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in the WGAN-GP paper https://arxiv.org/abs/1704.00028
    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real examples
        fake_data (tensor array) -- generated examples from the generator
        device (str)             -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)               -- whether we mix real and fake data or not [real | fake | mixed]
        constant (float)         -- the constant used in the formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss
    Returns the gradient penalty loss and the flattened per-sample gradients.
    """
    if lambda_gp > 0.0:
        if type == 'real':  # either use real examples, fake examples, or a linear interpolation of the two
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1, device=device)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        if isinstance(disc_interpolates, list):
            gradients = 0
            for disc_interpolate in disc_interpolates:
                gradients += torch.autograd.grad(outputs=disc_interpolate, inputs=interpolatesv,
                                                 grad_outputs=torch.ones(disc_interpolate.size()).to(device),
                                                 create_graph=True, retain_graph=True, only_inputs=True)[0]
        else:
            gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                            grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                            create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.view(real_data.size(0), -1)  # flatten the gradients per sample
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # eps added for numerical stability
        return gradient_penalty, gradients
    else:
        return 0.0, None
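####################################################################################################
# A minimal WGAN-GP discriminator-step sketch, assuming a hypothetical discriminator `netD`,
# batches `real` / `fake`, and a `device`; the penalty is computed on interpolated samples.
####################################################################################################
# criterion = GANLoss('wgangp')
# d_loss = criterion(netD(real), True, is_dis=True) + criterion(netD(fake.detach()), False, is_dis=True)
# gp, _ = cal_gradient_penalty(netD, real, fake.detach(), device, type='mixed')
# (d_loss + gp).backward()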
####################################################################################################
# Pretrained LPIPS loss
####################################################################################################
def normalize_tensor(x, eps=1e-10):
    """Normalize feature maps to unit length along the channel dimension."""
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
    """Average over the spatial dimensions (H, W)."""
    return x.mean([2, 3], keepdim=keepdim)
class NetLinLayer(nn.Module):
    """ A single linear layer implemented as a 1x1 convolution """
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()
        layers = [nn.Dropout(), ] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)
class LPIPSLoss(nn.Module):
    """
    Learned perceptual metric (LPIPS)
    https://github.com/richzhang/PerceptualSimilarity
    """
    def __init__(self, use_dropout=True, ckpt_path=None):
        super(LPIPSLoss, self).__init__()
        self.path = ckpt_path
        self.net = VGG16()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False
    def load_from_pretrained(self):
        if self.path is None:
            raise ValueError('LPIPSLoss needs ckpt_path pointing to trained linear-layer weights')
        self.load_state_dict(torch.load(self.path, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(self.path))
    def _get_features(self, vgg_f):
        names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        return [vgg_f[name] for name in names]
    def forward(self, x, y):
        x_vgg, y_vgg = self._get_features(self.net(x)), self._get_features(self.net(y))
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        loss = 0
        for i in range(len(self.chns)):
            # Unit-normalize the features, take squared differences, weight them with the
            # learned 1x1 conv, and average spatially.
            x_feats, y_feats = normalize_tensor(x_vgg[i]), normalize_tensor(y_vgg[i])
            diffs = (x_feats - y_feats) ** 2
            loss += spatial_average(lins[i].model(diffs))
        return loss
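####################################################################################################
# A minimal LPIPSLoss usage sketch; the checkpoint path below is hypothetical: the class expects
# a file holding the trained `lin*` weights (see the richzhang/PerceptualSimilarity repo).
####################################################################################################
# lpips = LPIPSLoss(ckpt_path='weights/lpips_vgg.pth').eval()
# with torch.no_grad():
#     dist = lpips(img0, img1)  # img0, img1: (B, 3, H, W) batches; smaller means more similar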
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """
    def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 0.0]):
        super(PerceptualLoss, self).__init__()
        self.add_module('vgg', VGG16())
        self.criterion = nn.L1Loss()
        self.weights = weights
    def __call__(self, x, y):
        # Compute VGG features and accumulate the weighted L1 distances between them.
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        for weight, layer in zip(self.weights, layers):
            if weight > 0:
                content_loss += weight * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss
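####################################################################################################
# A minimal PerceptualLoss usage sketch; `pred` and `target` are hypothetical (B, 3, H, W)
# image batches. With the default weights, relu5_3 is skipped (its weight is 0).
####################################################################################################
# perc = PerceptualLoss()
# loss = perc(pred, target)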
class Normalization(nn.Module):
    def __init__(self, device):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they work directly with
        # an image tensor of shape [B x C x H x W], where B is the batch size, C the number
        # of channels, H the height, and W the width.
        mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
        std = torch.tensor([0.229, 0.224, 0.225]).to(device)
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)
    def forward(self, img):
        # normalize img with ImageNet statistics
        return (img - self.mean) / self.std
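####################################################################################################
# A minimal sketch: torchvision's ImageNet-pretrained VGG16 expects inputs normalized with the
# statistics above, so images in [0, 1] would typically pass through Normalization first.
# `img` and `device` are hypothetical placeholders.
####################################################################################################
# norm = Normalization(device)
# feats = VGG16().to(device)(norm(img))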
class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        features = models.vgg16(pretrained=True).features
        # Slice torchvision's vgg16 `features` into blocks, each ending at the named ReLU.
        slices = [('relu1_1', 0, 2), ('relu1_2', 2, 4),
                  ('relu2_1', 4, 7), ('relu2_2', 7, 9),
                  ('relu3_1', 9, 12), ('relu3_2', 12, 14), ('relu3_3', 14, 16),
                  ('relu4_1', 16, 19), ('relu4_2', 19, 21), ('relu4_3', 21, 23),
                  ('relu5_1', 23, 26), ('relu5_2', 26, 28), ('relu5_3', 28, 30)]
        for name, start, end in slices:
            block = torch.nn.Sequential()
            for x in range(start, end):
                block.add_module(str(x), features[x])
            setattr(self, name, block)
        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False
    def forward(self, x):
        out = {}
        for name in ['relu1_1', 'relu1_2', 'relu2_1', 'relu2_2',
                     'relu3_1', 'relu3_2', 'relu3_3',
                     'relu4_1', 'relu4_2', 'relu4_3',
                     'relu5_1', 'relu5_2', 'relu5_3']:
            # Each block consumes the previous block's output, so features build up sequentially.
            x = getattr(self, name)(x)
            out[name] = x
        return out
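if __name__ == '__main__':
    # A small smoke-test sketch: exercise the losses on random inputs. LPIPSLoss is skipped
    # here because it requires a trained checkpoint; VGG16 downloads pretrained weights on
    # first use.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.rand(2, 3, 64, 64, device=device)
    y = torch.rand(2, 3, 64, 64, device=device)
    norm = Normalization(device)
    perc = PerceptualLoss().to(device)
    print('perceptual loss:', perc(norm(x), norm(y)).item())
    gan = GANLoss('lsgan').to(device)
    pred = torch.randn(2, 1, 8, 8, device=device)
    print('lsgan D-real loss:', gan(pred, True, is_dis=True).item())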