"""A VGG-based perceptual loss function for PyTorch."""
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models, transforms
from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Lambda(nn.Module):
"""Wraps a callable in an :class:`nn.Module` without registering it."""
    def __init__(self, func):
        super().__init__()
        # Bypass nn.Module.__setattr__ so that func (which may itself be an
        # nn.Module) is not registered as a submodule of this wrapper.
        object.__setattr__(self, 'forward', func)
def extra_repr(self):
return getattr(self.forward, '__name__', type(self.forward).__name__) + '()'
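
# A minimal usage sketch (hypothetical names): `Lambda` lets a bare callable
# participate in module containers such as `WeightedLoss` below.
#
#     charbonnier = Lambda(lambda x, y: ((x - y) ** 2 + 1e-6).sqrt().mean())
#     loss = charbonnier(pred, target)  # callable like any nn.Module loss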
class WeightedLoss(nn.ModuleList):
"""A weighted combination of multiple loss functions."""
def __init__(self, losses, weights, verbose=False):
super().__init__()
for loss in losses:
self.append(loss if isinstance(loss, nn.Module) else Lambda(loss))
self.weights = weights
self.verbose = verbose
def _print_losses(self, losses):
for i, loss in enumerate(losses):
print(f'({i}) {type(self[i]).__name__}: {loss.item()}')
def forward(self, *args, **kwargs):
losses = []
for loss, weight in zip(self, self.weights):
losses.append(loss(*args, **kwargs) * weight)
if self.verbose:
self._print_losses(losses)
return sum(losses)
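
# A minimal usage sketch (hypothetical weights): combine an L1 pixel loss
# with the TVLoss defined below, weighting the TV term lightly.
#
#     crit = WeightedLoss([nn.L1Loss(), TVLoss(p=1)],
#                         weights=[1.0, 0.05], verbose=True)
#     total = crit(pred, target)  # TVLoss ignores the target argument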
class TVLoss(nn.Module):
"""Total variation loss (Lp penalty on image gradient magnitude).
The input must be 4D. If a target (second parameter) is passed in, it is
ignored.
    ``p=1`` yields the vectorial total variation norm. It is a generalization
    to color images of the originally proposed (isotropic) 2D total variation
    norm (see https://en.wikipedia.org/wiki/Total_variation_denoising). On
    images with a single channel it is equal to the 2D TV norm.

    ``p=2`` yields a variant that is often used for smoothing out noise in
    reconstructions of images from neural network feature maps (see Mahendran
    and Vedaldi, "Understanding Deep Image Representations by Inverting
    Them", https://arxiv.org/abs/1412.0035).
:attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'``
similarly to the loss functions in :mod:`torch.nn`. The default is
``'mean'``.
"""
def __init__(self, p, reduction='mean', eps=1e-8):
super().__init__()
if p not in {1, 2}:
raise ValueError('p must be 1 or 2')
if reduction not in {'mean', 'sum', 'none'}:
raise ValueError("reduction must be 'mean', 'sum', or 'none'")
self.p = p
self.reduction = reduction
self.eps = eps
    def forward(self, input, target=None):
        # Replication-pad by one pixel so the forward differences below
        # preserve the spatial size of the input.
        input = F.pad(input, (0, 1, 0, 1), 'replicate')
        x_diff = input[..., :-1, :-1] - input[..., :-1, 1:]
        y_diff = input[..., :-1, :-1] - input[..., 1:, :-1]
        diff = x_diff**2 + y_diff**2
        if self.p == 1:
            # eps keeps the sqrt differentiable where the gradient
            # magnitude is zero.
            diff = (diff + self.eps).mean(dim=1, keepdim=True).sqrt()
if self.reduction == 'mean':
return diff.mean()
if self.reduction == 'sum':
return diff.sum()
return diff
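
# A minimal usage sketch (hypothetical tensors): the target argument exists
# only so TVLoss can stand in wherever a two-argument loss is expected.
#
#     tv = TVLoss(p=2)
#     images = torch.rand(4, 3, 64, 64)
#     penalty = tv(images)  # scalar, since reduction defaults to 'mean'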
@LOSSFUNC.register_module(module_name="vgg_loss")
class VGGLoss(AbstractLossClass):
"""Computes the VGG perceptual loss between two batches of images.
    The input and target must be 4D tensors with three channels
    ``(B, 3, H, W)`` and must have the same shape. Pixel values should be
    normalized to the range 0–1.
The VGG perceptual loss is the mean squared difference between the features
computed for the input and target at layer :attr:`layer` (default 8, or
``relu2_2``) of the pretrained model specified by :attr:`model` (either
``'vgg16'`` (default) or ``'vgg19'``).
If :attr:`shift` is nonzero, a random shift of at most :attr:`shift`
pixels in both height and width will be applied to all images in the input
and target. The shift will only be applied when the loss function is in
training mode, and will not be applied if a precomputed feature map is
supplied as the target.
:attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'``
similarly to the loss functions in :mod:`torch.nn`. The default is
``'mean'``.
:meth:`get_features()` may be used to precompute the features for the
target, to speed up the case where inputs are compared against the same
target over and over. To use the precomputed features, pass them in as
:attr:`target` and set :attr:`target_is_features` to :code:`True`.
Instances of :class:`VGGLoss` must be manually converted to the same
device and dtype as their inputs.
"""
models = {'vgg16': models.vgg16, 'vgg19': models.vgg19}
def __init__(self, model='vgg16', layer=8, shift=0, reduction='mean'):
super().__init__()
self.instancenorm = nn.InstanceNorm2d(512, affine=False)
self.shift = shift
self.reduction = reduction
self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
        # torchvision >= 0.13 prefers the `weights=` argument; `pretrained=True`
        # still works but emits a deprecation warning.
        self.model = self.models[model](pretrained=True).features[:layer+1]
self.model.eval()
self.model.requires_grad_(False)
self.model.to(device)
def get_features(self, input):
return self.model(self.normalize(input))
    def train(self, mode=True):
        # Deliberately do not recurse into children: the VGG feature
        # extractor must stay in eval mode regardless of the training flag.
        self.training = mode
        return self
def forward(self, input, target, target_is_features=False):
if target_is_features:
input_feats = self.get_features(input)
target_feats = target
        else:
            sep = input.shape[0]
            # Run input and target through VGG as a single batch.
            batch = torch.cat([input, target])
            if self.shift and self.training:
                # Random shift: replication-pad by `shift` pixels on every
                # side, then crop back to the original size at a random
                # offset.
                padded = F.pad(batch, [self.shift] * 4, mode='replicate')
                batch = transforms.RandomCrop(batch.shape[2:])(padded)
            feats = self.get_features(batch)
            input_feats, target_feats = feats[:sep], feats[sep:]
# input_feats, target_feats = \
# self.instancenorm(input_feats), \
# self.instancenorm(target_feats)
        return F.mse_loss(input_feats, target_feats, reduction=self.reduction)
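
# A minimal usage sketch (hypothetical tensors). Precomputing the target
# features with get_features() speeds up repeated comparisons against the
# same target, as described in the class docstring.
#
#     crit = VGGLoss(model='vgg16', layer=8, shift=2)
#     pred = torch.rand(2, 3, 224, 224, device=device)
#     target = torch.rand(2, 3, 224, 224, device=device)
#     loss = crit(pred, target)
#
#     target_feats = crit.get_features(target)  # precompute once
#     loss = crit(pred, target_feats, target_is_features=True)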