from __future__ import absolute_import import os import torch import torch.nn as nn from . import pretrained_networks as pretrained_torch_models def spatial_average(x, keepdim=True): return x.mean([2, 3], keepdim=keepdim) def upsample(x): return nn.Upsample(size=x.shape[2:], mode='bilinear', align_corners=False)(x) def normalize_tensor(in_feat, eps=1e-10): norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) return in_feat / (norm_factor + eps) # Learned perceptual metric class LPIPS(nn.Module): def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): """ Initializes a perceptual loss torch.nn.Module Parameters (default listed first) --------------------------------- lpips : bool [True] use linear layers on top of base/trunk network [False] means no linear layers; each layer is averaged together pretrained : bool This flag controls the linear layers, which are only in effect when lpips=True above [True] means linear layers are calibrated with human perceptual judgments [False] means linear layers are randomly initialized pnet_rand : bool [False] means trunk loaded with ImageNet classification weights [True] means randomly initialized trunk net : str ['alex','vgg','squeeze'] are the base/trunk networks available version : str ['v0.1'] is the default and latest ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1) model_path : 'str' [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1 The following parameters should only be changed if training the network: eval_mode : bool [True] is for test mode (default) [False] is for training mode pnet_tune [False] keep base/trunk frozen [True] tune the base/trunk network use_dropout : bool [True] to use dropout when training linear layers [False] for no dropout when training linear layers """ super(LPIPS, self).__init__() if verbose: print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) self.pnet_type = net self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial self.lpips = lpips # false means baseline of just averaging all layers self.version = version self.scaling_layer = ScalingLayer() if self.pnet_type in ['vgg', 'vgg16']: net_type = pretrained_torch_models.vgg16 self.chns = [64, 128, 256, 512, 512] elif self.pnet_type == 'alex': net_type = pretrained_torch_models.alexnet self.chns = [64, 192, 384, 256, 256] elif self.pnet_type == 'squeeze': net_type = pretrained_torch_models.squeezenet self.chns = [64, 128, 256, 384, 384, 512, 512] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) if lpips: self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] if self.pnet_type == 'squeeze': # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) self.lins += [self.lin5, self.lin6] self.lins = nn.ModuleList(self.lins) if pretrained: if model_path is None: model_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), f"weights/v{version}/{net}.pth" ) if verbose: print('Loading model from: %s' % model_path) self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) if eval_mode: self.eval() def forward(self, in0, in1, return_per_layer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, 1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 # Noting: v0.0 - original release had a bug, where input was not scaled if self.version == '0.1': in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) else: in0_input, in1_input = in0, in1 # model forward outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 if self.lpips: if self.spatial: res = [upsample(self.lins[kk](diffs[kk])) for kk in range(self.L)] else: res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] else: if self.spatial: res = [upsample(diffs[kk].sum(dim=1, keepdim=True)) for kk in range(self.L)] else: res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] loss = sum(res) if return_per_layer: return loss, res else: return loss class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): """A single linear layer which does a 1x1 conv""" def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(), ] if (use_dropout) else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x)