"""
Implementation of Content loss, Style loss, LPIPS and DISTS metrics
References:
    .. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias
       (2016). A Neural Algorithm of Artistic Style
       Association for Research in Vision and Ophthalmology (ARVO)
       https://arxiv.org/abs/1508.06576
    .. [2] Zhang, Richard and Isola, Phillip and Efros, Alexei A., et al.
       (2018). The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
       IEEE/CVF Conference on Computer Vision and Pattern Recognition
       https://arxiv.org/abs/1801.03924
"""
from typing import List, Union, Collection
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights
from .utils import _validate_input, _reduce
from .functional import similarity_map, L2Pool2d
# Map VGG names to corresponding number in torchvision layer
VGG16_LAYERS = {
"conv1_1": '0', "relu1_1": '1',
"conv1_2": '2', "relu1_2": '3',
"pool1": '4',
"conv2_1": '5', "relu2_1": '6',
"conv2_2": '7', "relu2_2": '8',
"pool2": '9',
"conv3_1": '10', "relu3_1": '11',
"conv3_2": '12', "relu3_2": '13',
"conv3_3": '14', "relu3_3": '15',
"pool3": '16',
"conv4_1": '17', "relu4_1": '18',
"conv4_2": '19', "relu4_2": '20',
"conv4_3": '21', "relu4_3": '22',
"pool4": '23',
"conv5_1": '24', "relu5_1": '25',
"conv5_2": '26', "relu5_2": '27',
"conv5_3": '28', "relu5_3": '29',
"pool5": '30',
}
VGG19_LAYERS = {
"conv1_1": '0', "relu1_1": '1',
"conv1_2": '2', "relu1_2": '3',
"pool1": '4',
"conv2_1": '5', "relu2_1": '6',
"conv2_2": '7', "relu2_2": '8',
"pool2": '9',
"conv3_1": '10', "relu3_1": '11',
"conv3_2": '12', "relu3_2": '13',
"conv3_3": '14', "relu3_3": '15',
"conv3_4": '16', "relu3_4": '17',
"pool3": '18',
"conv4_1": '19', "relu4_1": '20',
"conv4_2": '21', "relu4_2": '22',
"conv4_3": '23', "relu4_3": '24',
"conv4_4": '25', "relu4_4": '26',
"pool4": '27',
"conv5_1": '28', "relu5_1": '29',
"conv5_2": '30', "relu5_2": '31',
"conv5_3": '32', "relu5_3": '33',
"conv5_4": '34', "relu5_4": '35',
"pool5": '36',
}
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Constant used in feature normalization to avoid zero division
EPS = 1e-10
class ContentLoss(_Loss):
r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks.
Uses pretrained VGG models from torchvision.
    By default expects input to be in range [0, 1], which is then normalized with ImageNet statistics.
    If no normalisation is required, change `mean` and `std` values accordingly.
Args:
feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
layers: List of strings with layer names. Default: ``'relu3_3'``
        weights: List of float weights to balance different layers
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
mean: List of float values used for data standardization. Default: ImageNet mean.
If there is no need to normalize data, use [0., 0., 0.].
std: List of float values used for data standardization. Default: ImageNet std.
If there is no need to normalize data, use [1., 1., 1.].
normalize_features: If true, unit-normalize each feature in channel dimension before scaling
and computing distance. See references for details.
Examples:
>>> loss = ContentLoss()
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(3, 3, 256, 256)
>>> output = loss(x, y)
>>> output.backward()
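        Multiple layers with illustrative (not tuned) weights:
        >>> loss = ContentLoss(feature_extractor='vgg19', layers=('relu3_3', 'relu4_3'), weights=[0.5, 0.5])
        >>> output = loss(x, y)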
References:
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
A Neural Algorithm of Artistic Style
Association for Research in Vision and Ophthalmology (ARVO)
https://arxiv.org/abs/1508.06576
        Zhang, Richard and Isola, Phillip and Efros, Alexei A., et al. (2018)
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
IEEE/CVF Conference on Computer Vision and Pattern Recognition
https://arxiv.org/abs/1801.03924
"""
def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",),
weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False,
distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
std: List[float] = IMAGENET_STD, normalize_features: bool = False,
allow_layers_weights_mismatch: bool = False) -> None:
        assert allow_layers_weights_mismatch or len(layers) == len(weights), \
            f'Lengths of provided layers and weights mismatch ({len(weights)} weights and {len(layers)} layers), ' \
            f'which will cause incorrect results. Please provide a weight for each layer.'
super().__init__()
if callable(feature_extractor):
self.model = feature_extractor
self.layers = layers
else:
if feature_extractor == "vgg16":
# self.model = vgg16(pretrained=True, progress=False).features
self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
self.layers = [VGG16_LAYERS[l] for l in layers]
elif feature_extractor == "vgg19":
# self.model = vgg19(pretrained=True, progress=False).features
self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features
self.layers = [VGG19_LAYERS[l] for l in layers]
else:
raise ValueError("Unknown feature extractor")
if replace_pooling:
self.model = self.replace_pooling(self.model)
# Disable gradients
for param in self.model.parameters():
param.requires_grad_(False)
self.distance = {
"mse": nn.MSELoss,
"mae": nn.L1Loss,
}[distance](reduction='none')
self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights]
mean = torch.tensor(mean)
std = torch.tensor(std)
self.mean = mean.view(1, -1, 1, 1)
self.std = std.view(1, -1, 1, 1)
self.normalize_features = normalize_features
self.reduction = reduction
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r"""Computation of Content loss between feature representations of prediction :math:`x` and
target :math:`y` tensors.
Args:
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.
Returns:
Content loss between feature representations
"""
_validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))
self.model.to(x)
x_features = self.get_features(x)
y_features = self.get_features(y)
distances = self.compute_distance(x_features, y_features)
        # Scale distances by layer weights, average over spatial dimensions, then concatenate and sum over channels
loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)
return _reduce(loss, self.reduction)
def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
r"""Take L2 or L1 distance between feature maps depending on ``distance``.
Args:
x_features: Features of the input tensor.
y_features: Features of the target tensor.
Returns:
Distance between feature maps
"""
return [self.distance(x, y) for x, y in zip(x_features, y_features)]
def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
Args:
x: Tensor. Shape :math:`(N, C, H, W)`.
Returns:
List of features extracted from intermediate layers
"""
# Normalize input
x = (x - self.mean.to(x)) / self.std.to(x)
features = []
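        # Run the input through the backbone sequentially, caching activations at the requested layers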
for name, module in self.model._modules.items():
x = module(x)
if name in self.layers:
features.append(self.normalize(x) if self.normalize_features else x)
return features
@staticmethod
def normalize(x: torch.Tensor) -> torch.Tensor:
r"""Normalize feature maps in channel direction to unit length.
Args:
x: Tensor. Shape :math:`(N, C, H, W)`.
Returns:
Normalized input
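        Examples:
            >>> x = torch.ones(1, 4, 8, 8)
            >>> x_unit = ContentLoss.normalize(x)  # each spatial position now has unit L2 norm over channels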
"""
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
return x / (norm_factor + EPS)
def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
r"""Turn All MaxPool layers into AveragePool
Args:
module: Module to change MaxPool int AveragePool
Returns:
Module with AveragePool instead MaxPool
"""
module_output = module
if isinstance(module, torch.nn.MaxPool2d):
module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
for name, child in module.named_children():
module_output.add_module(name, self.replace_pooling(child))
return module_output
class StyleLoss(ContentLoss):
r"""Creates Style loss that can be used for image style transfer or as a measure in
image to image tasks. Computes distance between Gram matrices of feature maps.
Uses pretrained VGG models from torchvision.
    By default expects input to be in range [0, 1], which is then normalized with ImageNet statistics.
If no normalisation is required, change `mean` and `std` values accordingly.
Args:
feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
layers: List of strings with layer names. Default: ``'relu3_3'``
        weights: List of float weights to balance different layers
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
mean: List of float values used for data standardization. Default: ImageNet mean.
If there is no need to normalize data, use [0., 0., 0.].
std: List of float values used for data standardization. Default: ImageNet std.
If there is no need to normalize data, use [1., 1., 1.].
normalize_features: If true, unit-normalize each feature in channel dimension before scaling
and computing distance. See references for details.
Examples:
>>> loss = StyleLoss()
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(3, 3, 256, 256)
>>> output = loss(x, y)
>>> output.backward()
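        With average pooling instead of max pooling, as suggested by Gatys et al.:
        >>> loss = StyleLoss(replace_pooling=True)
        >>> output = loss(x, y)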
References:
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
A Neural Algorithm of Artistic Style
Association for Research in Vision and Ophthalmology (ARVO)
https://arxiv.org/abs/1508.06576
        Zhang, Richard and Isola, Phillip and Efros, Alexei A., et al. (2018)
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
IEEE/CVF Conference on Computer Vision and Pattern Recognition
https://arxiv.org/abs/1801.03924
"""
    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``.
Args:
x_features: Features of the input tensor.
y_features: Features of the target tensor.
Returns:
Distance between Gram matrices
"""
        x_gram = [self.gram_matrix(x) for x in x_features]
        y_gram = [self.gram_matrix(y) for y in y_features]
        return [self.distance(x, y) for x, y in zip(x_gram, y_gram)]
@staticmethod
def gram_matrix(x: torch.Tensor) -> torch.Tensor:
r"""Compute Gram matrix for batch of features.
Args:
x: Tensor. Shape :math:`(N, C, H, W)`.
Returns:
Gram matrix for given input
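        Examples:
            >>> x = torch.rand(2, 16, 8, 8)
            >>> gram = StyleLoss.gram_matrix(x)  # shape (2, 1, 16, 16)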
"""
B, C, H, W = x.size()
gram = []
for i in range(B):
features = x[i].view(C, H * W)
            # Add a fake channel dimension so the stacked result has shape (B, 1, C, C)
gram.append(torch.mm(features, features.t()).unsqueeze(0))
return torch.stack(gram)
class LPIPS(ContentLoss):
r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.
    By default expects input to be in range [0, 1], which is then normalized with ImageNet statistics.
If no normalisation is required, change `mean` and `std` values accordingly.
Args:
replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details.
distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
mean: List of float values used for data standardization. Default: ImageNet mean.
If there is no need to normalize data, use [0., 0., 0.].
std: List of float values used for data standardization. Default: ImageNet std.
If there is no need to normalize data, use [1., 1., 1.].
Examples:
>>> loss = LPIPS()
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(3, 3, 256, 256)
>>> output = loss(x, y)
>>> output.backward()
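        Per-image scores, using ``reduction='none'``:
        >>> loss = LPIPS(reduction='none')
        >>> scores = loss(x, y)  # shape (3,)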
References:
Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
A Neural Algorithm of Artistic Style
Association for Research in Vision and Ophthalmology (ARVO)
https://arxiv.org/abs/1508.06576
        Zhang, Richard and Isola, Phillip and Efros, Alexei A., et al. (2018)
The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
IEEE/CVF Conference on Computer Vision and Pattern Recognition
https://arxiv.org/abs/1801.03924
https://github.com/richzhang/PerceptualSimilarity
"""
_weights_url = "https://github.com/photosynthesis-team/" + \
"photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"
    def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
                 mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD) -> None:
lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
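        # The checkpoint is expected to hold one learned per-channel weight tensor for each layer above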
lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights,
replace_pooling=replace_pooling, distance=distance,
reduction=reduction, mean=mean, std=std,
normalize_features=True)
class DISTS(ContentLoss):
r"""Deep Image Structure and Texture Similarity metric.
    By default expects input to be in range [0, 1], which is then normalized with ImageNet statistics.
If no normalisation is required, change `mean` and `std` values accordingly.
Args:
reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
mean: List of float values used for data standardization. Default: ImageNet mean.
If there is no need to normalize data, use [0., 0., 0.].
std: List of float values used for data standardization. Default: ImageNet std.
If there is no need to normalize data, use [1., 1., 1.].
Examples:
>>> loss = DISTS()
>>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
>>> y = torch.rand(3, 3, 256, 256)
>>> output = loss(x, y)
>>> output.backward()
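        Per-image scores, using ``reduction='none'``:
        >>> loss = DISTS(reduction='none')
        >>> scores = loss(x, y)  # shape (3,)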
References:
Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
Image Quality Assessment: Unifying Structure and Texture Similarity.
https://arxiv.org/abs/2004.07728
https://github.com/dingkeyan93/DISTS
"""
_weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"
def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
std: List[float] = IMAGENET_STD) -> None:
dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
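        # Channel widths of the raw input image (3, prepended by get_features) and of each selected VGG16 stage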
channels = [3, 64, 128, 256, 512, 512]
weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
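        # ``alpha`` weights the mean-based (texture) terms and ``beta`` the covariance-based (structure)
        # terms returned by ``compute_distance``, in that order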
dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
dists_weights.extend(torch.split(weights['beta'], channels, dim=1))
super().__init__("vgg16", layers=dists_layers, weights=dists_weights,
replace_pooling=True, reduction=reduction, mean=mean, std=std,
normalize_features=False, allow_layers_weights_mismatch=True)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r"""
Args:
x: An input tensor. Shape :math:`(N, C, H, W)`.
y: A target tensor. Shape :math:`(N, C, H, W)`.
Returns:
            Deep Image Structure and Texture Similarity loss, computed as ``1 - similarity``,
            in range [0, 1] (0 for identical inputs).
"""
_, _, H, W = x.shape
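        # Downscale so that the smaller side is 256 px, matching the preprocessing of the reference implementation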
        if min(H, W) > 256:
            x = torch.nn.functional.interpolate(
                x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear', align_corners=False)
            y = torch.nn.functional.interpolate(
                y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear', align_corners=False)
loss = super().forward(x, y)
return 1 - loss
    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
        r"""Compute texture and structure similarity maps between feature maps.
        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.
        Returns:
            Mean-based (texture) similarity maps followed by covariance-based (structure) similarity maps,
            in the order expected by the ``alpha`` and ``beta`` weights.
        """
        texture_distance, structure_distance = [], []
        # Small constant for numerical stability; intentionally larger than the module-level EPS
        EPS = 1e-6
        for x, y in zip(x_features, y_features):
            x_mean = x.mean([2, 3], keepdim=True)
            y_mean = y.mean([2, 3], keepdim=True)
            # Texture term: (2 * mu_x * mu_y + c) / (mu_x^2 + mu_y^2 + c)
            texture_distance.append(similarity_map(x_mean, y_mean, constant=EPS))
            x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
            # Structure term: (2 * cov_xy + c) / (var_x + var_y + c)
            structure_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))
        return texture_distance + structure_distance
def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
r"""
Args:
x: Input tensor
Returns:
List of features extracted from input tensor
"""
features = super().get_features(x)
# Add input tensor as an additional feature
features.insert(0, x)
return features
def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
r"""Turn All MaxPool layers into L2Pool
Args:
module: Module to change MaxPool into L2Pool
Returns:
Module with L2Pool instead of MaxPool
"""
module_output = module
if isinstance(module, torch.nn.MaxPool2d):
module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)
for name, child in module.named_children():
module_output.add_module(name, self.replace_pooling(child))
return module_output