Spaces:
Runtime error
Runtime error
from __future__ import absolute_import | |
import torch | |
import torch.nn as nn | |
from torch.autograd import Variable | |
import warnings | |
from . import pretrained_networks as pn | |
from .utils import normalize_tensor, l2, dssim, tensor2np, tensor2tensorlab, tensor2im | |
def spatial_average(in_tens, keepdim=True):
    """Average a tensor over its dimensions 2 and 3.

    Parameters
    ----------
    in_tens : torch.Tensor
        Tensor with at least four dimensions; dims 2 and 3 are reduced.
    keepdim : bool
        [True] keeps the reduced dimensions with size 1
        [False] drops them

    Returns
    -------
    torch.Tensor
        The mean of `in_tens` taken jointly over dimensions 2 and 3.
    """
    spatial_dims = (2, 3)
    return in_tens.mean(dim=spatial_dims, keepdim=keepdim)
def upsample(in_tens, out_HW=(64, 64)):
    """Bilinearly resize a tensor to a fixed spatial size.

    Parameters
    ----------
    in_tens : torch.Tensor
        Input tensor; it is resized with ``mode="bilinear"``.
    out_HW : sequence of two ints
        [(64, 64)] target (height, width) of the output.

    Returns
    -------
    torch.Tensor
        `in_tens` resampled to spatial size `out_HW` with
        ``align_corners=False``.
    """
    # Functional form of nn.Upsample(size=..., mode="bilinear",
    # align_corners=False): same result without building a throwaway
    # module on every call.
    return nn.functional.interpolate(
        in_tens, size=out_HW, mode="bilinear", align_corners=False
    )
# Learned perceptual metric
class LPIPS(nn.Module):
    """Learned perceptual metric computed from the features of a trunk network.

    Both inputs are passed through the trunk, each tapped feature map is
    normalized with ``normalize_tensor``, and the squared differences are
    reduced to one map per layer. The per-layer results are summed.
    """

    def __init__(
        self,
        pretrained=True,
        net="alex",
        version="0.1",
        lpips=True,
        spatial=False,
        pnet_rand=False,
        pnet_tune=False,
        use_dropout=True,
        model_path=None,
        eval_mode=True,
        verbose=True,
    ):
        """Initializes a perceptual loss torch.nn.Module

        Parameters (default listed first)
        ---------------------------------
        lpips : bool
            [True] use linear layers on top of base/trunk network
            [False] means no linear layers; each layer is averaged together
        pretrained : bool
            This flag controls the linear layers, which are only in effect when lpips=True above
            [True] means linear layers are calibrated with human perceptual judgments
            [False] means linear layers are randomly initialized
        pnet_rand : bool
            [False] means trunk loaded with ImageNet classification weights
            [True] means randomly initialized trunk
        net : str
            ['alex','vgg','vgg16','squeeze'] are the base/trunk networks available
        version : str
            ['0.1'] is the default and latest
            ['0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
            Any value other than '0.1' disables the input scaling layer.
        spatial : bool
            [False] each layer's result is averaged over its spatial dimensions
            [True] each layer's result is upsampled to the input's spatial size
        model_path : 'str'
            [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
        verbose : bool
            Accepted for backward compatibility; no message is printed.

        The following parameters should only be changed if training the network

        eval_mode : bool
            [True] is for test mode (default)
            [False] is for training mode
        pnet_tune
            [False] keep base/trunk frozen
            [True] tune the base/trunk network
        use_dropout : bool
            [True] to use dropout when training linear layers
            [False] for no dropout when training linear layers

        Raises
        ------
        ValueError
            If `net` is not one of the supported trunk names.
        """
        super(LPIPS, self).__init__()
        # NOTE: this silences every warning for the whole process, not just
        # for this module. Kept as-is because existing callers may rely on it.
        warnings.filterwarnings("ignore")

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()

        # Pick the trunk constructor and the channel count of each tapped layer.
        if self.pnet_type in ["vgg", "vgg16"]:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == "alex":
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == "squeeze":
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            # Previously an unknown name fell through and failed later with
            # an unbound-variable error; fail early with a clear message.
            raise ValueError(
                "Unsupported net %r; expected one of "
                "'alex', 'vgg', 'vgg16', 'squeeze'" % (net,)
            )
        self.L = len(self.chns)
        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # One 1x1-conv layer per tapped trunk layer. Each layer is also
            # exposed as self.lin<i> so that both the `lin<i>.*` and the
            # `lins.<i>.*` state-dict keys exist, as before.
            lin_layers = []
            for idx, chn in enumerate(self.chns):
                layer = NetLinLayer(chn, use_dropout=use_dropout)
                setattr(self, "lin%d" % idx, layer)
                lin_layers.append(layer)
            self.lins = nn.ModuleList(lin_layers)

            if pretrained:
                if model_path is None:
                    import inspect
                    import os

                    # Default weights live next to this source file under
                    # weights/v<version>/<net>.pth.
                    # NOTE(review): `net` is used verbatim, so net="vgg16"
                    # looks for vgg16.pth - confirm that file is shipped.
                    model_path = os.path.abspath(
                        os.path.join(
                            inspect.getfile(self.__init__),
                            "..",
                            "weights/v%s/%s.pth" % (version, net),
                        )
                    )
                # SECURITY NOTE: torch.load unpickles the file; only pass a
                # `model_path` that comes from a trusted source.
                # strict=False tolerates keys that are missing or unexpected.
                self.load_state_dict(
                    torch.load(model_path, map_location="cpu"), strict=False
                )

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Compute the distance between two image batches.

        Parameters
        ----------
        in0, in1 : torch.Tensor
            Image batches to compare. Expected in [-1, +1] unless
            `normalize` is set.
        retPerLayer : bool
            [False] return only the summed distance
            [True] return ``(summed distance, list of per-layer results)``
        normalize : bool
            [False] inputs are used as given
            [True] inputs are in [0, 1] and are rescaled to [-1, +1]

        Returns
        -------
        torch.Tensor or tuple
            The sum of the per-layer results, optionally together with the
            list of per-layer results.
        """
        if normalize:
            # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        if self.version == "0.1":
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1

        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        # Squared difference of the normalized features, one entry per layer.
        diffs = [
            (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2
            for kk in range(self.L)
        ]

        # Collapse the channel dimension: learned 1x1 conv, or a plain sum.
        if self.lpips:
            per_layer = [self.lins[kk](diffs[kk]) for kk in range(self.L)]
        else:
            per_layer = [diff.sum(dim=1, keepdim=True) for diff in diffs]

        # Collapse (or resize) the spatial dimensions.
        if self.spatial:
            res = [upsample(layer_map, out_HW=in0.shape[2:]) for layer_map in per_layer]
        else:
            res = [spatial_average(layer_map, keepdim=True) for layer_map in per_layer]

        val = sum(res)

        if retPerLayer:
            return (val, res)
        return val
class ScalingLayer(nn.Module):
    """Applies a fixed per-channel shift and scale to a 3-channel input.

    The constants are stored as buffers shaped (1, 3, 1, 1), so they follow
    the module across devices and are included in its state dict.
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        shift = torch.Tensor([-0.030, -0.088, -0.188])
        scale = torch.Tensor([0.458, 0.448, 0.450])
        # Reshape to (1, C, 1, 1) so the constants broadcast over the
        # batch and spatial dimensions.
        self.register_buffer("shift", shift.view(1, -1, 1, 1))
        self.register_buffer("scale", scale.view(1, -1, 1, 1))

    def forward(self, inp):
        """Return ``(inp - shift) / scale``."""
        centered = inp - self.shift
        return centered / self.scale
class NetLinLayer(nn.Module):
    """A single linear layer which does a 1x1 conv"""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        """Build the layer.

        Parameters
        ----------
        chn_in : int
            Number of input channels.
        chn_out : int
            [1] number of output channels.
        use_dropout : bool
            [False] only the 1x1 convolution
            [True] an nn.Dropout is placed before the convolution
        """
        super(NetLinLayer, self).__init__()
        conv = nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)
        # The dropout (when present) comes first, so the convolution sits at
        # index 1 of the Sequential; otherwise it sits at index 0.
        if use_dropout:
            stages = [nn.Dropout(), conv]
        else:
            stages = [conv]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the (optional) dropout and the 1x1 convolution to `x`."""
        return self.model(x)
class Dist2LogitLayer(nn.Module):
    """takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True)"""

    def __init__(self, chn_mid=32, use_sigmoid=True):
        """Build the stack of 1x1 convolutions.

        Parameters
        ----------
        chn_mid : int
            [32] number of channels in the two hidden layers.
        use_sigmoid : bool
            [True] append an nn.Sigmoid after the last convolution
            [False] return the raw output of the last convolution
        """
        super(Dist2LogitLayer, self).__init__()
        # 5 input channels: see the features concatenated in forward().
        stages = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if use_sigmoid:
            stages.append(nn.Sigmoid())
        self.model = nn.Sequential(*stages)

    def forward(self, d0, d1, eps=0.1):
        """Map a pair of distances to a single output.

        The network input is the channel-wise concatenation of `d0`, `d1`,
        their difference, and the two ratios; `eps` is added to each
        denominator.
        """
        difference = d0 - d1
        ratio_01 = d0 / (d1 + eps)
        ratio_10 = d1 / (d0 + eps)
        features = torch.cat((d0, d1, difference, ratio_01, ratio_10), dim=1)
        return self.model.forward(features)
class BCERankingLoss(nn.Module):
    """Binary cross-entropy loss on the output of a Dist2LogitLayer."""

    def __init__(self, chn_mid=32):
        """Create the Dist2LogitLayer (with `chn_mid` hidden channels) and the BCE loss."""
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        """Return the BCE loss between the predicted logit and the target.

        `judge` is mapped to the target with ``(judge + 1) / 2``. The network
        output is also stored on ``self.logit``.
        """
        target = (judge + 1.0) / 2.0
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, target)
# L2 and DSSIM metrics
class FakeNet(nn.Module):
    """Parameter-free base module that only records its configuration.

    Parameters
    ----------
    use_gpu : bool
        [True] stored on ``self.use_gpu`` for subclasses to read.
    colorspace : str
        ['Lab'] stored on ``self.colorspace`` for subclasses to read.
    """

    def __init__(self, use_gpu=True, colorspace="Lab"):
        super(FakeNet, self).__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu
class L2(FakeNet):
    """L2 distance between two images, in RGB or Lab color space."""

    def forward(self, in0, in1, retPerLayer=None):
        """Compute the L2 distance between `in0` and `in1`.

        Parameters
        ----------
        in0, in1 : torch.Tensor
            Image tensors with a batch size of 1.
        retPerLayer
            Unused; accepted so the signature matches the other metrics.

        Returns
        -------
        torch.Tensor
            One-element tensor holding the distance. For "RGB" this is the
            mean squared difference; for "Lab" it is the result of ``l2``
            on the Lab-converted inputs, moved to the GPU when
            ``self.use_gpu`` is set.

        Raises
        ------
        ValueError
            If ``self.colorspace`` is neither "RGB" nor "Lab" (this case used
            to return None silently).
        """
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            # Mean over channel and both spatial dims; one value per sample.
            return ((in0 - in1) ** 2).mean(dim=[1, 2, 3])

        if self.colorspace == "Lab":
            value = l2(
                tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
            ret_var = torch.Tensor((value,))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var

        raise ValueError(
            "Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,)
        )
class DSSIM(FakeNet):
    """DSSIM distance between two images, in RGB or Lab color space."""

    def forward(self, in0, in1, retPerLayer=None):
        """Compute the DSSIM distance between `in0` and `in1`.

        Parameters
        ----------
        in0, in1 : torch.Tensor
            Image tensors with a batch size of 1.
        retPerLayer
            Unused; accepted so the signature matches the other metrics.

        Returns
        -------
        torch.Tensor
            One-element tensor holding the value returned by ``dssim``,
            moved to the GPU when ``self.use_gpu`` is set.

        Raises
        ------
        ValueError
            If ``self.colorspace`` is neither "RGB" nor "Lab" (this case used
            to fail with an unbound local variable).
        """
        assert in0.size()[0] == 1  # currently only supports batchSize 1

        if self.colorspace == "RGB":
            value = dssim(
                1.0 * tensor2im(in0.data),
                1.0 * tensor2im(in1.data),
                range=255.0,
            ).astype("float")
        elif self.colorspace == "Lab":
            value = dssim(
                tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                tensor2np(tensor2tensorlab(in1.data, to_norm=False)),
                range=100.0,
            ).astype("float")
        else:
            raise ValueError(
                "Unsupported colorspace %r; expected 'RGB' or 'Lab'"
                % (self.colorspace,)
            )

        ret_var = torch.Tensor((value,))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var
def print_network(net):
    """Print a module followed by its total number of parameter elements.

    Parameters
    ----------
    net : torch.nn.Module
        Module whose ``parameters()`` are counted.
    """
    total = sum(param.numel() for param in net.parameters())
    print("Network", net)
    print("Total number of parameters: %d" % total)