# medical_diffusion/metrics/torchmetrics_pr_recall.py
from typing import Optional, List
import torch
from torch import Tensor
from torch.nn import Module
from torchmetrics import Metric
import torchvision.models as models
from torchvision import transforms
from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE
if _TORCH_FIDELITY_AVAILABLE:
from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
else:
class FeatureExtractorInceptionV3(Module): # type: ignore
pass
__doctest_skip__ = ["ImprovedPrecessionRecall", "IPR"]
class NoTrainInceptionV3(FeatureExtractorInceptionV3):
def __init__(
self,
name: str,
features_list: List[str],
feature_extractor_weights_path: Optional[str] = None,
) -> None:
super().__init__(name, features_list, feature_extractor_weights_path)
# put into evaluation mode
self.eval()
def train(self, mode: bool) -> "NoTrainInceptionV3":
"""the inception network should not be able to be switched away from evaluation mode."""
return super().train(False)
    def forward(self, x: Tensor) -> Tensor:
        # torch-fidelity returns a tuple with one entry per requested feature layer;
        # keep the single requested one and flatten it to [batch_size, num_features]
        out = super().forward(x)
        return out[0].reshape(x.shape[0], -1)
# -------------------------- VGG Trans ---------------------------
# class Normalize(object):
# """Rescale the image from 0-255 (uint8) to [0,1] (float32).
# Note, this doesn't ensure that min=0 and max=1 as a min-max scale would do!"""
# def __call__(self, image):
# return image/255
# # see https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html
# VGG_Trans = transforms.Compose([
# transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
# # transforms.Resize([256, 256], interpolation=InterpolationMode.BILINEAR),
# # transforms.CenterCrop(224),
# Normalize(), # scale to [0, 1]
# transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
# ])
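# ----------------------- Optional VGG-16 feature extractor (sketch) -----------------------
# Hedged sketch only: the TODO in `ImprovedPrecessionRecall.__init__` mentions switching between
# Inception and VGG features, and the `feature` argument already accepts any `torch.nn.Module`.
# A wrapper along these lines could be passed in. The class name `VGG16FeatureExtractor` is
# illustrative and not part of the original module; like the Inception path, it assumes uint8
# images in [0, 255].
class VGG16FeatureExtractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = models.vgg16(weights='IMAGENET1K_V1').eval()
        self.resize = transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def forward(self, x: Tensor) -> Tensor:
        x = self.resize(x / 255)                        # uint8 -> [0, 1], resized to 224x224
        x = self.normalize(x)                           # ImageNet mean/std normalization
        x = torch.flatten(self.vgg.features(x), 1)      # convolutional features, flattened
        return self.vgg.classifier[:4](x)               # first four classifier layers -> 4096-d features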
class ImprovedPrecessionRecall(Metric):
    """Improved Precision and Recall for generative models (Kynkäänniemi et al., 2019).

    Precision estimates how many generated samples fall within the support of the real
    feature distribution, recall how many real samples fall within the support of the
    generated one. Both supports are approximated by k-nearest-neighbour hyperspheres
    in the feature space of a pretrained network.
    """
is_differentiable: bool = False
higher_is_better: bool = True
full_state_update: bool = False
    def __init__(self, feature=2048, knn=3, splits_real=1, splits_fake=5):
        """
        Args:
            feature: number of InceptionV3 features (64, 192, 768 or 2048) or a custom ``torch.nn.Module`` feature extractor
            knn: number of nearest neighbours used to estimate the manifold radius of each sample
            splits_real: number of chunks the real features are split into when computing pairwise distances
            splits_fake: number of chunks the fake features are split into when computing pairwise distances
        """
        super().__init__()
# ------------------------- Init Feature Extractor (VGG or Inception) ------------------------------
# Original VGG: https://github.com/kynkaat/improved-precision-and-recall-metric/blob/b0247eafdead494a5d243bd2efb1b0b124379ae9/utils.py#L40
# Compare Inception: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L574
# TODO: Add option to switch between Inception and VGG feature extractor
# self.vgg_model = models.vgg16(weights='IMAGENET1K_V1').eval()
# self.feature_extractor = transforms.Compose([
# VGG_Trans,
# self.vgg_model.features,
# transforms.Lambda(lambda x: torch.flatten(x, 1)),
# self.vgg_model.classifier[:4] # [:4] corresponds to 4096 features
# ])
if isinstance(feature, int):
if not _TORCH_FIDELITY_AVAILABLE:
raise ModuleNotFoundError(
"FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
" Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
)
valid_int_input = [64, 192, 768, 2048]
if feature not in valid_int_input:
raise ValueError(
f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
)
self.feature_extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
elif isinstance(feature, torch.nn.Module):
self.feature_extractor = feature
else:
raise TypeError("Got unknown input to argument `feature`")
# --------------------------- End Feature Extractor ---------------------------------------------------------------
self.knn = knn
self.splits_real = splits_real
self.splits_fake = splits_fake
self.add_state("real_features", [], dist_reduce_fx=None)
self.add_state("fake_features", [], dist_reduce_fx=None)
def update(self, imgs: Tensor, real: bool) -> None: # type: ignore
"""Update the state with extracted features.
Args:
            imgs: tensor with images fed to the feature extractor
real: bool indicating if ``imgs`` belong to the real or the fake distribution
"""
assert torch.is_tensor(imgs) and imgs.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8'
features = self.feature_extractor(imgs).view(imgs.shape[0], -1)
if real:
self.real_features.append(features)
else:
self.fake_features.append(features)
    def compute(self):
        """Compute precision and recall from the accumulated real and fake feature sets."""
real_features = torch.concat(self.real_features)
fake_features = torch.concat(self.fake_features)
real_distances = _compute_pairwise_distances(real_features, self.splits_real)
real_radii = _distances2radii(real_distances, self.knn)
fake_distances = _compute_pairwise_distances(fake_features, self.splits_fake)
fake_radii = _distances2radii(fake_distances, self.knn)
precision = _compute_metric(real_features, real_radii, self.splits_real, fake_features, self.splits_fake)
recall = _compute_metric(fake_features, fake_radii, self.splits_fake, real_features, self.splits_real)
return precision, recall
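# -------------------------- kNN manifold helpers ---------------------------
# A sample from `pred_features` is counted if its distance to at least one reference sample is
# smaller than that reference sample's k-nearest-neighbour radius, i.e. if it lies inside the
# estimated manifold of the reference distribution. `compute` uses this with the fake samples as
# predictions for precision and with the real samples as predictions for recall.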
def _compute_metric(ref_features, ref_radii, ref_splits, pred_features, pred_splits):
dist = _compute_pairwise_distances(ref_features, ref_splits, pred_features, pred_splits)
num_feat = pred_features.shape[0]
count = 0
for i in range(num_feat):
count += (dist[:, i] < ref_radii).any()
return count / num_feat
def _distances2radii(distances, knn):
    # the smallest entry in each row is the 0 self-distance, so take the (knn+1) smallest values;
    # the largest of these is the distance to the knn-th nearest neighbour, i.e. the manifold radius
    return torch.topk(distances, knn+1, dim=1, largest=False)[0].max(dim=1)[0]
def _compute_pairwise_distances(X, splits_x, Y=None, splits_y=None):
# X = [B, features]
# Y = [B', features]
Y = X if Y is None else Y
# X = X.double()
# Y = Y.double()
splits_y = splits_x if splits_y is None else splits_y
dist = torch.concat([
torch.concat([
(torch.sum(X_batch**2, dim=1, keepdim=True) +
torch.sum(Y_batch**2, dim=1, keepdim=True).t() -
2 * torch.einsum("bd,dn->bn", X_batch, Y_batch.t()))
for Y_batch in Y.chunk(splits_y, dim=0)], dim=1)
for X_batch in X.chunk(splits_x, dim=0)])
# dist = torch.maximum(dist, torch.zeros_like(dist))
    dist[dist<0] = 0  # clamp small negative values caused by floating-point error before the sqrt
return torch.sqrt(dist)
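# -------------------------- Usage sketch ---------------------------
# Hedged example only: the random uint8 batches below are placeholders, and batch size, image
# size and seed are arbitrary assumptions, not part of the original module. Running it requires
# `torch-fidelity` (the InceptionV3 weights are downloaded on first use).
if __name__ == "__main__":
    torch.manual_seed(0)
    ipr = ImprovedPrecessionRecall(feature=2048, knn=3)
    real_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
    fake_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
    ipr.update(real_imgs, real=True)
    ipr.update(fake_imgs, real=False)
    precision, recall = ipr.compute()
    print(f"precision={float(precision):.3f}, recall={float(recall):.3f}")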