from typing import List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module
from torchmetrics import Metric
import torchvision.models as models
from torchvision import transforms
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

if _TORCH_FIDELITY_AVAILABLE:
    from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
else:
    class FeatureExtractorInceptionV3(Module):  # type: ignore
        pass

__doctest_skip__ = ["ImprovedPrecisionRecall", "IPR"]


class NoTrainInceptionV3(FeatureExtractorInceptionV3):
    """Inception-v3 feature extractor that is permanently locked in evaluation mode."""

    def __init__(
        self,
        name: str,
        features_list: List[str],
        feature_extractor_weights_path: Optional[str] = None,
    ) -> None:
        super().__init__(name, features_list, feature_extractor_weights_path)
        # put into evaluation mode
        self.eval()

    def train(self, mode: bool) -> "NoTrainInceptionV3":
        """The inception network should not be able to be switched away from evaluation mode."""
        return super().train(False)

    def forward(self, x: Tensor) -> Tensor:
        out = super().forward(x)
        return out[0].reshape(x.shape[0], -1)


# -------------------------- VGG Trans --------------------------- #
# class Normalize(object):
#     """Rescale the image from 0-255 (uint8) to [0, 1] (float32).
#     Note, this doesn't ensure that min=0 and max=1 as a min-max scale would do!"""
#     def __call__(self, image):
#         return image / 255
#
# # see https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html
# VGG_Trans = transforms.Compose([
#     transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
#     # transforms.Resize([256, 256], interpolation=InterpolationMode.BILINEAR),
#     # transforms.CenterCrop(224),
#     Normalize(),  # scale to [0, 1]
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])


class ImprovedPrecisionRecall(Metric):
    """Improved Precision and Recall between a real and a generated (fake) image distribution.

    Precision is the fraction of fake samples that fall within the estimated support of the real
    distribution; recall is the fraction of real samples that fall within the estimated support of
    the fake distribution. The support is estimated via k-nearest-neighbour radii in feature space.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        feature: Union[int, torch.nn.Module] = 2048,
        knn: int = 3,
        splits_real: int = 1,
        splits_fake: int = 5,
    ) -> None:
        super().__init__()
        # ------------------------- Init Feature Extractor (VGG or Inception) ------------------------------
        # Original VGG: https://github.com/kynkaat/improved-precision-and-recall-metric/blob/b0247eafdead494a5d243bd2efb1b0b124379ae9/utils.py#L40
        # Compare Inception: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L574
        # TODO: Add option to switch between Inception and VGG feature extractor
        # self.vgg_model = models.vgg16(weights='IMAGENET1K_V1').eval()
        # self.feature_extractor = transforms.Compose([
        #     VGG_Trans,
        #     self.vgg_model.features,
        #     transforms.Lambda(lambda x: torch.flatten(x, 1)),
        #     self.vgg_model.classifier[:4],  # [:4] corresponds to 4096 features
        # ])
        if isinstance(feature, int):
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "ImprovedPrecisionRecall metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )
            self.feature_extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, torch.nn.Module):
            self.feature_extractor = feature
        else:
            raise TypeError("Got unknown input to argument `feature`")
        # --------------------------- End Feature Extractor ---------------------------------------------------------------

        self.knn = knn
        self.splits_real = splits_real
        self.splits_fake = splits_fake
        self.add_state("real_features", [], dist_reduce_fx=None)
        self.add_state("fake_features", [], dist_reduce_fx=None)

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features.

        Args:
            imgs: tensor with images fed to the feature extractor
            real: bool indicating if ``imgs`` belong to the real or the fake distribution
        """
        assert torch.is_tensor(imgs) and imgs.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8"
        features = self.feature_extractor(imgs).view(imgs.shape[0], -1)

        if real:
            self.real_features.append(features)
        else:
            self.fake_features.append(features)

    def compute(self):
        """Compute precision and recall from the accumulated real and fake features."""
        real_features = torch.concat(self.real_features)
        fake_features = torch.concat(self.fake_features)

        real_distances = _compute_pairwise_distances(real_features, self.splits_real)
        real_radii = _distances2radii(real_distances, self.knn)

        fake_distances = _compute_pairwise_distances(fake_features, self.splits_fake)
        fake_radii = _distances2radii(fake_distances, self.knn)

        # precision: fraction of fake samples within the k-NN radius of at least one real sample
        precision = _compute_metric(real_features, real_radii, self.splits_real, fake_features, self.splits_fake)
        # recall: fraction of real samples within the k-NN radius of at least one fake sample
        recall = _compute_metric(fake_features, fake_radii, self.splits_fake, real_features, self.splits_real)
        return precision, recall


def _compute_metric(ref_features, ref_radii, ref_splits, pred_features, pred_splits):
    """Fraction of ``pred_features`` that lie within the radius of at least one reference feature."""
    dist = _compute_pairwise_distances(ref_features, ref_splits, pred_features, pred_splits)
    num_feat = pred_features.shape[0]

    count = 0
    for i in range(num_feat):
        count += (dist[:, i] < ref_radii).any()
    return count / num_feat


def _distances2radii(distances, knn):
    """Radius per sample: distance to its ``knn``-th nearest neighbour (``knn + 1`` skips the zero self-distance)."""
    return torch.topk(distances, knn + 1, dim=1, largest=False)[0].max(dim=1)[0]


def _compute_pairwise_distances(X, splits_x, Y=None, splits_y=None):
    """Pairwise Euclidean distances between the rows of X and Y, computed block-wise to limit memory.

    X = [B, features]
    Y = [B', features]
    """
    Y = X if Y is None else Y
    # X = X.double()
    # Y = Y.double()
    splits_y = splits_x if splits_y is None else splits_y
    dist = torch.concat([
        torch.concat([
            (torch.sum(X_batch**2, dim=1, keepdim=True)
             + torch.sum(Y_batch**2, dim=1, keepdim=True).t()
             - 2 * torch.einsum("bd,dn->bn", X_batch, Y_batch.t()))
            for Y_batch in Y.chunk(splits_y, dim=0)], dim=1)
        for X_batch in X.chunk(splits_x, dim=0)])
    # clamp small negative values caused by floating-point error before taking the square root
    # dist = torch.maximum(dist, torch.zeros_like(dist))
    dist[dist < 0] = 0
    return torch.sqrt(dist)
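

# ------------------------------- Usage sketch ------------------------------- #
# A minimal, illustrative example of driving the metric, assuming `torch-fidelity`
# is installed (the Inception weights are downloaded on first use) and that inputs
# are uint8 image batches of shape [N, 3, H, W]. The random tensors below are a
# stand-in for real/generated images and do not yield meaningful scores.
if __name__ == "__main__":
    _ = torch.manual_seed(42)
    metric = ImprovedPrecisionRecall(feature=2048, knn=3)

    real_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
    fake_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)

    metric.update(real_imgs, real=True)
    metric.update(fake_imgs, real=False)

    precision, recall = metric.compute()
    print(f"precision={float(precision):.4f}, recall={float(recall):.4f}")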