""" | |
Implementation of Content loss, Style loss, LPIPS and DISTS metrics | |
References: | |
.. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias | |
(2016). A Neural Algorithm of Artistic Style} | |
Association for Research in Vision and Ophthalmology (ARVO) | |
https://arxiv.org/abs/1508.06576 | |
.. [2] Zhang, Richard and Isola, Phillip and Efros, et al. | |
(2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric | |
2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition | |
https://arxiv.org/abs/1801.03924 | |
""" | |
from typing import List, Union, Collection

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights

from .utils import _validate_input, _reduce
from .functional import similarity_map, L2Pool2d

# Map VGG layer names to the corresponding module indices in torchvision's ``features`` sequential
VGG16_LAYERS = {
    "conv1_1": '0', "relu1_1": '1',
    "conv1_2": '2', "relu1_2": '3',
    "pool1": '4',
    "conv2_1": '5', "relu2_1": '6',
    "conv2_2": '7', "relu2_2": '8',
    "pool2": '9',
    "conv3_1": '10', "relu3_1": '11',
    "conv3_2": '12', "relu3_2": '13',
    "conv3_3": '14', "relu3_3": '15',
    "pool3": '16',
    "conv4_1": '17', "relu4_1": '18',
    "conv4_2": '19', "relu4_2": '20',
    "conv4_3": '21', "relu4_3": '22',
    "pool4": '23',
    "conv5_1": '24', "relu5_1": '25',
    "conv5_2": '26', "relu5_2": '27',
    "conv5_3": '28', "relu5_3": '29',
    "pool5": '30',
}

VGG19_LAYERS = {
    "conv1_1": '0', "relu1_1": '1',
    "conv1_2": '2', "relu1_2": '3',
    "pool1": '4',
    "conv2_1": '5', "relu2_1": '6',
    "conv2_2": '7', "relu2_2": '8',
    "pool2": '9',
    "conv3_1": '10', "relu3_1": '11',
    "conv3_2": '12', "relu3_2": '13',
    "conv3_3": '14', "relu3_3": '15',
    "conv3_4": '16', "relu3_4": '17',
    "pool3": '18',
    "conv4_1": '19', "relu4_1": '20',
    "conv4_2": '21', "relu4_2": '22',
    "conv4_3": '23', "relu4_3": '24',
    "conv4_4": '25', "relu4_4": '26',
    "pool4": '27',
    "conv5_1": '28', "relu5_1": '29',
    "conv5_2": '30', "relu5_2": '31',
    "conv5_3": '32', "relu5_3": '33',
    "conv5_4": '34', "relu5_4": '35',
    "pool5": '36',
}
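# For example, VGG16_LAYERS['relu3_3'] == '15' selects the module at index 15 of
# torchvision's ``vgg16().features`` (the ReLU following conv3_3).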

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Constant used in feature normalization to avoid zero division
EPS = 1e-10


class ContentLoss(_Loss):
    r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks.

    Uses pretrained VGG models from torchvision.
    By default expects input to be in range [0, 1], which is then standardised with ImageNet statistics.
    If no normalisation is required, change ``mean`` and ``std`` values accordingly.

    Args:
        feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
        layers: List of strings with layer names. Default: ``'relu3_3'``
        weights: List of float weights to balance different layers.
        replace_pooling: Flag to replace MaxPooling layers with AveragePooling. See references for details.
        distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].
        normalize_features: If true, unit-normalize each feature in channel dimension before scaling
            and computing distance. See references for details.

    Examples:
        >>> loss = ContentLoss()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()
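
        A multi-layer configuration is also possible; the layer choice and weights below
        are purely illustrative, not recommended values:

        >>> loss = ContentLoss("vgg19", layers=("relu3_3", "relu4_3"), weights=[0.75, 0.25])
        >>> loss(x, y).backward()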

    References:
        Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
        A Neural Algorithm of Artistic Style
        Association for Research in Vision and Ophthalmology (ARVO)
        https://arxiv.org/abs/1508.06576

        Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
        IEEE/CVF Conference on Computer Vision and Pattern Recognition
        https://arxiv.org/abs/1801.03924
    """

    def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16",
                 layers: Collection[str] = ("relu3_3",),
                 weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False,
                 distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
                 std: List[float] = IMAGENET_STD, normalize_features: bool = False,
                 allow_layers_weights_mismatch: bool = False) -> None:

        assert allow_layers_weights_mismatch or len(layers) == len(weights), \
            f'Lengths of provided layers and weights mismatch ({len(weights)} weights and {len(layers)} layers), ' \
            f'which will cause incorrect results. Please provide a weight for each layer.'

        super().__init__()

        if callable(feature_extractor):
            self.model = feature_extractor
            self.layers = layers
        else:
            if feature_extractor == "vgg16":
                self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features
                self.layers = [VGG16_LAYERS[layer] for layer in layers]
            elif feature_extractor == "vgg19":
                self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features
                self.layers = [VGG19_LAYERS[layer] for layer in layers]
            else:
                raise ValueError("Unknown feature extractor")

        if replace_pooling:
            self.model = self.replace_pooling(self.model)

        # Disable gradients for the feature extractor: it is used for inference only
        for param in self.model.parameters():
            param.requires_grad_(False)

        self.distance = {
            "mse": nn.MSELoss,
            "mae": nn.L1Loss,
        }[distance](reduction='none')

        self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights]

        mean = torch.tensor(mean)
        std = torch.tensor(std)
        self.mean = mean.view(1, -1, 1, 1)
        self.std = std.view(1, -1, 1, 1)
        self.normalize_features = normalize_features
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Computation of Content loss between feature representations of prediction :math:`x` and
        target :math:`y` tensors.

        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Content loss between feature representations
        """
        _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1))

        self.model.to(x)
        x_features = self.get_features(x)
        y_features = self.get_features(y)

        distances = self.compute_distance(x_features, y_features)

        # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension
        loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1)

        return _reduce(loss, self.reduction)

    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
        r"""Take L2 or L1 distance between feature maps depending on ``distance``.

        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.

        Returns:
            Distance between feature maps
        """
        return [self.distance(x, y) for x, y in zip(x_features, y_features)]

    def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Args:
            x: Tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            List of features extracted from intermediate layers
        """
        # Normalize input
        x = (x - self.mean.to(x)) / self.std.to(x)

        features = []
        for name, module in self.model._modules.items():
            x = module(x)
            if name in self.layers:
                features.append(self.normalize(x) if self.normalize_features else x)

        return features

    @staticmethod
    def normalize(x: torch.Tensor) -> torch.Tensor:
        r"""Normalize feature maps in channel direction to unit length.

        Args:
            x: Tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Normalized input
        """
        norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
        return x / (norm_factor + EPS)

    def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""Turn all MaxPool layers into AveragePool.

        Args:
            module: Module in which to change MaxPool into AveragePool

        Returns:
            Module with AveragePool instead of MaxPool
        """
        module_output = module
        if isinstance(module, torch.nn.MaxPool2d):
            module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

        for name, child in module.named_children():
            module_output.add_module(name, self.replace_pooling(child))

        return module_output


class StyleLoss(ContentLoss):
    r"""Creates Style loss that can be used for image style transfer or as a measure in
    image to image tasks. Computes distance between Gram matrices of feature maps.

    Uses pretrained VGG models from torchvision.
    By default expects input to be in range [0, 1], which is then standardised with ImageNet statistics.
    If no normalisation is required, change ``mean`` and ``std`` values accordingly.

    Args:
        feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``.
        layers: List of strings with layer names. Default: ``'relu3_3'``
        weights: List of float weights to balance different layers.
        replace_pooling: Flag to replace MaxPooling layers with AveragePooling. See references for details.
        distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].
        normalize_features: If true, unit-normalize each feature in channel dimension before scaling
            and computing distance. See references for details.

    Examples:
        >>> loss = StyleLoss()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()
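
        Several layers are commonly combined for style transfer; the layer choice and
        equal weights below are illustrative only:

        >>> loss = StyleLoss(layers=("relu1_1", "relu2_1", "relu3_1", "relu4_1"), weights=[1., 1., 1., 1.])
        >>> loss(x, y).backward()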

    References:
        Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
        A Neural Algorithm of Artistic Style
        Association for Research in Vision and Ophthalmology (ARVO)
        https://arxiv.org/abs/1508.06576

        Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
        IEEE/CVF Conference on Computer Vision and Pattern Recognition
        https://arxiv.org/abs/1801.03924
    """

    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
        r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``.

        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.

        Returns:
            Distance between Gram matrices
        """
        x_gram = [self.gram_matrix(x) for x in x_features]
        y_gram = [self.gram_matrix(x) for x in y_features]
        return [self.distance(x, y) for x, y in zip(x_gram, y_gram)]

    @staticmethod
    def gram_matrix(x: torch.Tensor) -> torch.Tensor:
        r"""Compute Gram matrix for a batch of features: each sample is flattened to
        :math:`F` of shape :math:`(C, H \cdot W)` and the Gram matrix is :math:`G = F F^\top`.

        Args:
            x: Tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Gram matrix for given input. Shape :math:`(N, 1, C, C)`.
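
        Example (shape check only; the values depend on the input):

        >>> StyleLoss.gram_matrix(torch.ones(1, 2, 4, 4)).shape
        torch.Size([1, 1, 2, 2])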
""" | |
        B, C, H, W = x.size()
        gram = []
        for i in range(B):
            features = x[i].view(C, H * W)
            # Add fake channel dimension
            gram.append(torch.mm(features, features.t()).unsqueeze(0))

        return torch.stack(gram)


class LPIPS(ContentLoss):
    r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported.

    By default expects input to be in range [0, 1], which is then standardised with ImageNet statistics.
    If no normalisation is required, change ``mean`` and ``std`` values accordingly.

    Args:
        replace_pooling: Flag to replace MaxPooling layers with AveragePooling. See references for details.
        distance: Method to compute distance between features: ``'mse'`` | ``'mae'``.
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].

    Examples:
        >>> loss = LPIPS()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()
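
        Per-image scores can be obtained with ``reduction='none'``:

        >>> per_image = LPIPS(reduction='none')(x, y)
        >>> per_image.shape
        torch.Size([3])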

    References:
        Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016).
        A Neural Algorithm of Artistic Style
        Association for Research in Vision and Ophthalmology (ARVO)
        https://arxiv.org/abs/1508.06576

        Zhang, Richard and Isola, Phillip and Efros, et al. (2018)
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric
        IEEE/CVF Conference on Computer Vision and Pattern Recognition
        https://arxiv.org/abs/1801.03924
        https://github.com/richzhang/PerceptualSimilarity
    """
    _weights_url = "https://github.com/photosynthesis-team/" + \
                   "photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt"

    def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean",
                 mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD) -> None:
        lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
        super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights,
                         replace_pooling=replace_pooling, distance=distance,
                         reduction=reduction, mean=mean, std=std,
                         normalize_features=True)


class DISTS(ContentLoss):
    r"""Deep Image Structure and Texture Similarity metric.

    By default expects input to be in range [0, 1], which is then standardised with ImageNet statistics.
    If no normalisation is required, change ``mean`` and ``std`` values accordingly.

    Args:
        reduction: Specifies the reduction type:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        mean: List of float values used for data standardization. Default: ImageNet mean.
            If there is no need to normalize data, use [0., 0., 0.].
        std: List of float values used for data standardization. Default: ImageNet std.
            If there is no need to normalize data, use [1., 1., 1.].

    Examples:
        >>> loss = DISTS()
        >>> x = torch.rand(3, 3, 256, 256, requires_grad=True)
        >>> y = torch.rand(3, 3, 256, 256)
        >>> output = loss(x, y)
        >>> output.backward()

    References:
        Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020).
        Image Quality Assessment: Unifying Structure and Texture Similarity.
        https://arxiv.org/abs/2004.07728
        https://github.com/dingkeyan93/DISTS
    """
    _weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt"

    def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN,
                 std: List[float] = IMAGENET_STD) -> None:
        dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']
        channels = [3, 64, 128, 256, 512, 512]

        weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False)
        dists_weights = list(torch.split(weights['alpha'], channels, dim=1))
        dists_weights.extend(torch.split(weights['beta'], channels, dim=1))

        super().__init__("vgg16", layers=dists_layers, weights=dists_weights,
                         replace_pooling=True, reduction=reduction, mean=mean, std=std,
                         normalize_features=False, allow_layers_weights_mismatch=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            x: An input tensor. Shape :math:`(N, C, H, W)`.
            y: A target tensor. Shape :math:`(N, C, H, W)`.

        Returns:
            Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS``, in range [0, 1].
        """
        _, _, H, W = x.shape

        if min(H, W) > 256:
            x = torch.nn.functional.interpolate(
                x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')
            y = torch.nn.functional.interpolate(
                y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear')

        loss = super().forward(x, y)
        return 1 - loss

    def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]:
        r"""Compute structure and texture similarity between feature maps.

        Args:
            x_features: Features of the input tensor.
            y_features: Features of the target tensor.

        Returns:
            Per-layer similarity maps between feature maps
        """
        structure_distance, texture_distance = [], []
        # Small constant for numerical stability
        EPS = 1e-6

        for x, y in zip(x_features, y_features):
            x_mean = x.mean([2, 3], keepdim=True)
            y_mean = y.mean([2, 3], keepdim=True)
            # Mean-based similarity: (2 * mu_x * mu_y + eps) / (mu_x^2 + mu_y^2 + eps)
            structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS))

            x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean
            # Covariance-based similarity: (2 * cov_xy + eps) / (var_x + var_y + eps)
            texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS))

        return structure_distance + texture_distance

    def get_features(self, x: torch.Tensor) -> List[torch.Tensor]:
        r"""
        Args:
            x: Input tensor

        Returns:
            List of features extracted from input tensor
        """
        features = super().get_features(x)

        # Add input tensor as an additional feature
        features.insert(0, x)
        return features

    def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""Turn all MaxPool layers into L2Pool.

        Args:
            module: Module in which to change MaxPool into L2Pool

        Returns:
            Module with L2Pool instead of MaxPool
        """
        module_output = module
        if isinstance(module, torch.nn.MaxPool2d):
            module_output = L2Pool2d(kernel_size=3, stride=2, padding=1)

        for name, child in module.named_children():
            module_output.add_module(name, self.replace_pooling(child))

        return module_output