from typing import Optional, List

import torch
from torch import Tensor
from torch.nn import Module
from torchmetrics import Metric
import torchvision.models as models
from torchvision import transforms
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

if _TORCH_FIDELITY_AVAILABLE:
    from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
else:
    class FeatureExtractorInceptionV3(Module):  # type: ignore
        pass

__doctest_skip__ = ["ImprovedPrecessionRecall", "IPR"]

class NoTrainInceptionV3(FeatureExtractorInceptionV3):
    def __init__(
        self,
        name: str,
        features_list: List[str],
        feature_extractor_weights_path: Optional[str] = None,
    ) -> None:
        super().__init__(name, features_list, feature_extractor_weights_path)
        # put into evaluation mode
        self.eval()

    def train(self, mode: bool) -> "NoTrainInceptionV3":
        """The inception network should not be able to be switched away from evaluation mode."""
        return super().train(False)

    def forward(self, x: Tensor) -> Tensor:
        out = super().forward(x)
        return out[0].reshape(x.shape[0], -1)
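
# Illustrative check (an assumption, not part of the original file): even an explicit
# `.train(True)` call keeps the wrapped network in evaluation mode, so dropout and
# batch-norm statistics stay frozen while the metric accumulates features, e.g.
#
#   extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"])
#   extractor.train(True)
#   assert not extractor.training
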
# -------------------------- VGG Trans ---------------------------
# class Normalize(object):
#     """Rescale the image from 0-255 (uint8) to [0, 1] (float32).
#     Note, this doesn't ensure that min=0 and max=1 as a min-max scale would do!"""
#     def __call__(self, image):
#         return image / 255
#
# # see https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html
# VGG_Trans = transforms.Compose([
#     transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
#     # transforms.Resize([256, 256], interpolation=InterpolationMode.BILINEAR),
#     # transforms.CenterCrop(224),
#     Normalize(),  # scale to [0, 1]
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])
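
# ------------------ Optional VGG-16 extractor (illustrative sketch, see TODO below) ------------------
# `ImprovedPrecessionRecall` accepts any `torch.nn.Module` for its `feature` argument, so a
# VGG-16 backbone equivalent to the commented-out transform above could be plugged in. The wrapper
# below is an assumption/example of how that might look, not part of the original metric:
#
#   metric = ImprovedPrecessionRecall(feature=_VGG16Features())
class _VGG16Features(torch.nn.Module):
    """Map raw uint8 images to 4096-dim fc7 features of an ImageNet-pretrained VGG-16."""

    def __init__(self) -> None:
        super().__init__()
        self.vgg = models.vgg16(weights="IMAGENET1K_V1").eval()
        self.trans = transforms.Compose([
            transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
            transforms.Lambda(lambda x: x / 255),  # scale to [0, 1]
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        x = self.trans(x.float())
        x = torch.flatten(self.vgg.features(x), 1)  # [B, 512 * 7 * 7]
        return self.vgg.classifier[:4](x)  # first four classifier layers -> [B, 4096]
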

class ImprovedPrecessionRecall(Metric):
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(self, feature=2048, knn=3, splits_real=1, splits_fake=5):
        super().__init__()

        # ------------------------- Init Feature Extractor (VGG or Inception) ------------------------------
        # Original VGG: https://github.com/kynkaat/improved-precision-and-recall-metric/blob/b0247eafdead494a5d243bd2efb1b0b124379ae9/utils.py#L40
        # Compare Inception: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L574
        # TODO: Add option to switch between Inception and VGG feature extractor
        # self.vgg_model = models.vgg16(weights='IMAGENET1K_V1').eval()
        # self.feature_extractor = transforms.Compose([
        #     VGG_Trans,
        #     self.vgg_model.features,
        #     transforms.Lambda(lambda x: torch.flatten(x, 1)),
        #     self.vgg_model.classifier[:4]  # [:4] corresponds to 4096 features
        # ])
        if isinstance(feature, int):
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "ImprovedPrecessionRecall metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )
            self.feature_extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, torch.nn.Module):
            self.feature_extractor = feature
        else:
            raise TypeError("Got unknown input to argument `feature`")
        # --------------------------- End Feature Extractor ---------------------------------------------------------------

        self.knn = knn
        self.splits_real = splits_real
        self.splits_fake = splits_fake
        self.add_state("real_features", [], dist_reduce_fx=None)
        self.add_state("fake_features", [], dist_reduce_fx=None)

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features.

        Args:
            imgs: tensor with images fed to the feature extractor
            real: bool indicating if ``imgs`` belong to the real or the fake distribution
        """
        assert torch.is_tensor(imgs) and imgs.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8"
        features = self.feature_extractor(imgs).view(imgs.shape[0], -1)
        if real:
            self.real_features.append(features)
        else:
            self.fake_features.append(features)

    def compute(self):
        real_features = torch.concat(self.real_features)
        fake_features = torch.concat(self.fake_features)

        real_distances = _compute_pairwise_distances(real_features, self.splits_real)
        real_radii = _distances2radii(real_distances, self.knn)

        fake_distances = _compute_pairwise_distances(fake_features, self.splits_fake)
        fake_radii = _distances2radii(fake_distances, self.knn)

        precision = _compute_metric(real_features, real_radii, self.splits_real, fake_features, self.splits_fake)
        recall = _compute_metric(fake_features, fake_radii, self.splits_fake, real_features, self.splits_real)
        return precision, recall
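
# --------------------------- Usage note (illustrative sketch, not part of the metric) ---------------------------
# Typical call pattern, assuming batches of uint8 images shaped [N, 3, H, W] (the default
# torch-fidelity Inception extractor resizes inputs internally to its expected resolution):
#
#   ipr = ImprovedPrecessionRecall(feature=2048, knn=3)
#   ipr.update(real_imgs, real=True)
#   ipr.update(fake_imgs, real=False)
#   precision, recall = ipr.compute()
#
# A runnable end-to-end check is sketched at the bottom of this file.
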
def _compute_metric(ref_features, ref_radii, ref_splits, pred_features, pred_splits):
    """Fraction of `pred_features` that fall inside the k-NN radius of at least one reference sample."""
    dist = _compute_pairwise_distances(ref_features, ref_splits, pred_features, pred_splits)
    num_feat = pred_features.shape[0]
    count = 0
    for i in range(num_feat):
        count += (dist[:, i] < ref_radii).any()
    return count / num_feat


def _distances2radii(distances, knn):
    """Radius of each sample: distance to its k-th nearest neighbour (knn + 1 because the self-distance of 0 is included)."""
    return torch.topk(distances, knn + 1, dim=1, largest=False)[0].max(dim=1)[0]


def _compute_pairwise_distances(X, splits_x, Y=None, splits_y=None):
    # X = [B, features]
    # Y = [B', features]
    Y = X if Y is None else Y
    # X = X.double()
    # Y = Y.double()
    splits_y = splits_x if splits_y is None else splits_y
    # Chunk both feature matrices to keep the [B, B'] distance matrix memory-friendly, then use
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y for each block.
    dist = torch.concat([
        torch.concat([
            (torch.sum(X_batch**2, dim=1, keepdim=True)
             + torch.sum(Y_batch**2, dim=1, keepdim=True).t()
             - 2 * torch.einsum("bd,dn->bn", X_batch, Y_batch.t()))
            for Y_batch in Y.chunk(splits_y, dim=0)], dim=1)
        for X_batch in X.chunk(splits_x, dim=0)])
    # dist = torch.maximum(dist, torch.zeros_like(dist))
    dist[dist < 0] = 0  # clamp tiny negative values caused by floating-point error
    return torch.sqrt(dist)
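

# ----------------------------- Runnable sanity check (illustrative sketch only) ------------------------------
# A minimal sketch, assuming `torch-fidelity` is installed and random uint8 images stand in for
# real/generated data; it also checks the chunked pairwise-distance helper against `torch.cdist`.
if __name__ == "__main__":
    torch.manual_seed(0)

    # The blockwise squared-distance computation should match torch.cdist up to numerical noise.
    x = torch.randn(8, 16)
    y = torch.randn(6, 16)
    assert torch.allclose(_compute_pairwise_distances(x, 2, y, 3), torch.cdist(x, y), atol=1e-4)

    # End-to-end run on random images (placeholder values, so the resulting numbers are not meaningful).
    metric = ImprovedPrecessionRecall(feature=64, knn=3)
    real_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
    fake_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
    metric.update(real_imgs, real=True)
    metric.update(fake_imgs, real=False)
    precision, recall = metric.compute()
    print(f"precision={float(precision):.3f}, recall={float(recall):.3f}")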