from __future__ import absolute_import

import os
import warnings

import torch
import torch.nn as nn
from torch.autograd import Variable

from . import pretrained_networks as pn
from .utils import normalize_tensor, l2, dssim, tensor2np, tensor2tensorlab, tensor2im


def spatial_average(in_tens, keepdim=True):
    """Return the mean of `in_tens` over dimensions 2 and 3."""
    return in_tens.mean([2, 3], keepdim=keepdim)


def upsample(in_tens, out_HW=(64, 64)):
    """Bilinearly resize `in_tens` so its trailing two dims equal `out_HW`.

    Uses `align_corners=False`. The functional form is called directly so no
    throwaway `nn.Upsample` module is built on every call.
    """
    return nn.functional.interpolate(
        in_tens, size=out_HW, mode="bilinear", align_corners=False
    )


# Learned perceptual metric
class LPIPS(nn.Module):
    def __init__(
        self,
        pretrained=True,
        net="alex",
        version="0.1",
        lpips=True,
        spatial=False,
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        model_path=None,
        eval_mode=True,
        verbose=True,
    ):
        """Initializes a perceptual loss torch.nn.Module

        Parameters (default listed first)
        ---------------------------------
        lpips : bool
            [True] use linear layers on top of base/trunk network
            [False] means no linear layers; each layer is averaged together
        pretrained : bool
            This flag controls the linear layers, which are only in effect
            when lpips=True above
            [True] means linear layers are calibrated with human perceptual
            judgments (weights are loaded from `model_path`)
            [False] means linear layers are randomly initialized
        pnet_rand : bool
            [False] means trunk loaded with ImageNet classification weights
            [True] means randomly initialized trunk
        net : str
            ['alex', 'vgg' (alias 'vgg16'), 'squeeze'] are the base/trunk
            networks available
        version : str
            ['0.1'] is the default and latest
            ['0.0'] contained a normalization bug (inputs are not passed
            through the scaling layer); corresponds to old arxiv v1
            (https://arxiv.org/abs/1801.03924v1)
        spatial : bool
            [False] each per-layer distance map is averaged over its spatial
            dimensions
            [True] each per-layer distance map is bilinearly upsampled to the
            input resolution, so the result is a spatial map
        model_path : str
            [None] is default and loads `weights/v{version}/{net}.pth` from
            the directory containing this module
        verbose : bool
            Accepted for backward compatibility; currently has no effect.

        The following parameters should only be changed if training the
        network

        eval_mode : bool
            [True] is for test mode (default)
            [False] is for training mode
        pnet_tune : bool
            [False] keep base/trunk frozen
            [True] tune the base/trunk network
        use_dropout : bool
            [True] to use dropout when training linear layers
            [False] for no dropout when training linear layers

        Raises
        ------
        ValueError
            If `net` is not one of the supported trunk names.
        """
        super(LPIPS, self).__init__()

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ("vgg", "vgg16"):
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == "alex":
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == "squeeze":
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            # Fail early with a clear message instead of an AttributeError on
            # the missing `self.chns` a few lines later.
            raise ValueError(
                "Unsupported trunk network %r; expected one of "
                "'alex', 'vgg', 'vgg16', 'squeeze'" % (net,)
            )
        self.L = len(self.chns)

        # Suppress warnings only while building/loading the model rather than
        # installing a process-wide "ignore" filter as a side effect.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            self.net = net_type(
                pretrained=not self.pnet_rand, requires_grad=self.pnet_tune
            )

            if lpips:
                self._build_linear_layers(use_dropout)

                if pretrained:
                    if model_path is None:
                        model_path = self._default_weights_path(version, net)
                    # NOTE(security): torch.load unpickles the file; only pass
                    # a `model_path` that comes from a trusted source.
                    self.load_state_dict(
                        torch.load(model_path, map_location="cpu"), strict=False
                    )

        if eval_mode:
            self.eval()

    def _build_linear_layers(self, use_dropout):
        """Create one 1x1-conv NetLinLayer per trunk feature layer.

        Each layer is registered both as `self.lin{k}` and inside the
        `self.lins` ModuleList, so both sets of state-dict keys exist.
        """
        layers = []
        for idx, chn in enumerate(self.chns):
            layer = NetLinLayer(chn, use_dropout=use_dropout)
            setattr(self, "lin%d" % idx, layer)
            layers.append(layer)
        self.lins = nn.ModuleList(layers)

    @staticmethod
    def _default_weights_path(version, net):
        """Return the absolute path of the weights file next to this module.

        Resolved from this module's own `__file__` so that subclasses which
        override `__init__` in another file still find the weights.
        """
        # NOTE(review): for net="vgg16" this looks for "vgg16.pth" -- confirm
        # that such a file ships alongside "vgg.pth".
        return os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "weights",
            "v%s" % version,
            "%s.pth" % net,
        )

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Compute the perceptual distance between `in0` and `in1`.

        Parameters
        ----------
        in0, in1 : tensors fed to the trunk network
        retPerLayer : bool
            [False] return only the summed distance
            [True] return `(summed distance, list of per-layer distances)`
        normalize : bool
            [False] inputs are used as given
            [True] inputs are mapped with `2 * x - 1`; turn on this flag if
            input is [0, 1] so it can be adjusted to [-1, +1]
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        if self.version == "0.1":
            in0_input = self.scaling_layer(in0)
            in1_input = self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1

        outs0 = self.net.forward(in0_input)
        outs1 = self.net.forward(in1_input)

        # Squared difference of the normalized features, one entry per layer.
        # Indexed by position because the trunk's return type is defined in
        # another module.
        diffs = [
            (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2
            for kk in range(self.L)
        ]

        # Collapse the channel dimension: learned 1x1 conv, or a plain sum.
        if self.lpips:
            layer_maps = [lin(diff) for lin, diff in zip(self.lins, diffs)]
        else:
            layer_maps = [diff.sum(dim=1, keepdim=True) for diff in diffs]

        # Either keep a spatial map at input resolution or average it away.
        if self.spatial:
            res = [upsample(fmap, out_HW=in0.shape[2:]) for fmap in layer_maps]
        else:
            res = [spatial_average(fmap, keepdim=True) for fmap in layer_maps]

        val = sum(res)

        if retPerLayer:
            return (val, res)
        return val


class ScalingLayer(nn.Module):
    """Applies a fixed per-channel `(inp - shift) / scale` to a 3-channel input.

    `shift` and `scale` are registered as buffers of shape (1, 3, 1, 1).
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer(
            "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
        )
        self.register_buffer(
            "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
        )

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        """Build an optional Dropout followed by a bias-free 1x1 Conv2d."""
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout()] if use_dropout else []
        layers.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Dist2LogitLayer(nn.Module):
    """takes 2 distances, puts through fc layers, spits out value between [0,1]
    (if use_sigmoid is True)"""

    def __init__(self, chn_mid=32, use_sigmoid=True):
        """Build three 1x1 convs (5 -> chn_mid -> chn_mid -> 1) with LeakyReLU
        in between, optionally followed by a Sigmoid."""
        super(Dist2LogitLayer, self).__init__()

        layers = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, d0, d1, eps=0.1):
        """Run the network on 5 stacked features derived from `d0` and `d1`.

        The features, concatenated along dim 1, are: d0, d1, d0 - d1,
        d0 / (d1 + eps) and d1 / (d0 + eps).
        """
        features = torch.cat(
            (d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1
        )
        return self.model(features)


class BCERankingLoss(nn.Module):
    """Binary cross-entropy loss on the output of a Dist2LogitLayer."""

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        """Return the BCE loss between the predicted value and `(judge + 1) / 2`.

        The network output is also stored on `self.logit`.
        """
        per = (judge + 1.0) / 2.0
        self.logit = self.net(d0, d1)
        return self.loss(self.logit, per)


# L2, DSSIM metrics
class FakeNet(nn.Module):
    """Common base for the non-learned metrics below; stores their settings."""

    def __init__(self, use_gpu=True, colorspace="Lab"):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace = colorspace

    def _scalar_to_tensor(self, value):
        """Wrap `value` in a 1-element tensor, moved to CUDA if `use_gpu`."""
        ret_var = Variable(torch.Tensor((value,)))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var


class L2(FakeNet):
    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 distance between `in0` and `in1`.

        For colorspace "RGB" the squared difference is averaged over dims 1, 2
        and 3 in turn. For "Lab" both inputs are converted with
        `tensor2tensorlab` and compared with `l2(..., range=100.0)`.
        `retPerLayer` is accepted but unused.

        Raises
        ------
        ValueError
            If `self.colorspace` is neither "RGB" nor "Lab".
        """
        assert in0.size()[0] == 1, "L2 currently only supports batch size 1"

        if self.colorspace == "RGB":
            squared_diff = (in0 - in1) ** 2
            # Same reduction order as before: channels, then each spatial dim.
            return squared_diff.mean(dim=1).mean(dim=1).mean(dim=1)

        if self.colorspace == "Lab":
            value = l2(
                tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
            return self._scalar_to_tensor(value)

        # Previously fell through and silently returned None.
        raise ValueError(
            "Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,)
        )


class DSSIM(FakeNet):
    def forward(self, in0, in1, retPerLayer=None):
        """Return the DSSIM distance between `in0` and `in1`.

        For colorspace "RGB" the inputs are converted with `tensor2im` and
        compared with `dssim(..., range=255.0)`. For "Lab" they are converted
        with `tensor2tensorlab` and compared with `dssim(..., range=100.0)`.
        `retPerLayer` is accepted but unused.

        Raises
        ------
        ValueError
            If `self.colorspace` is neither "RGB" nor "Lab".
        """
        assert in0.size()[0] == 1, "DSSIM currently only supports batch size 1"

        if self.colorspace == "RGB":
            value = dssim(
                1.0 * tensor2im(in0.data),
                1.0 * tensor2im(in1.data),
                range=255.0,
            ).astype("float")
        elif self.colorspace == "Lab":
            value = dssim(
                tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
        else:
            # Previously crashed with UnboundLocalError on `value`.
            raise ValueError(
                "Unsupported colorspace %r; expected 'RGB' or 'Lab'"
                % (self.colorspace,)
            )

        return self._scalar_to_tensor(value)


def print_network(net):
    """Print `net` followed by its total number of parameter elements."""
    num_params = sum(param.numel() for param in net.parameters())
    print("Network", net)
    print("Total number of parameters: %d" % num_params)