from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
from torchmetrics import Metric
import torchvision.models as models
from torchvision import transforms

from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

if _TORCH_FIDELITY_AVAILABLE:
    from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
else:
    # Fallback stub so this module can still be imported when torch-fidelity is not installed.
    class FeatureExtractorInceptionV3(torch.nn.Module):  # type: ignore
        pass
    __doctest_skip__ = ["ImprovedPrecessionRecall", "IPR"]

class NoTrainInceptionV3(FeatureExtractorInceptionV3):
    """InceptionV3 feature extractor that is permanently kept in evaluation mode."""

    def __init__(
        self,
        name: str,
        features_list: List[str],
        feature_extractor_weights_path: Optional[str] = None,
    ) -> None:
        super().__init__(name, features_list, feature_extractor_weights_path)
        # put into evaluation mode
        self.eval()

    def train(self, mode: bool = True) -> "NoTrainInceptionV3":
        """Ignore the requested mode; the Inception network must never leave evaluation mode."""
        return super().train(False)

    def forward(self, x: Tensor) -> Tensor:
        # torch-fidelity returns a tuple with one entry per requested feature level;
        # keep the first one and flatten it to [batch_size, num_features]
        out = super().forward(x)
        return out[0].reshape(x.shape[0], -1)


# -------------------------- VGG Trans ---------------------------
# class Normalize(object):
#     """Rescale the image from 0-255 (uint8) to [0,1] (float32). 
#        Note, this doesn't ensure that min=0 and max=1 as a min-max scale would do!"""

#     def __call__(self, image):
#         return image/255

# # see https://pytorch.org/vision/main/models/generated/torchvision.models.vgg16.html 
# VGG_Trans = transforms.Compose([
#     transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
#     # transforms.Resize([256, 256], interpolation=InterpolationMode.BILINEAR),
#     # transforms.CenterCrop(224),
#     Normalize(), # scale to [0, 1]
#     transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
# ])        
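# A minimal, hypothetical sketch of the VGG-16 extractor described in the commented block above
# and in the TODO inside `ImprovedPrecessionRecall.__init__` below (it assumes torchvision >= 0.13
# for the `weights=` argument; the helper name `_build_vgg16_feature_extractor` is illustrative
# and not part of any library). Note that `transforms.Compose` is not an `nn.Module`, so the VGG
# weights would not follow the metric's `.to(device)` automatically.
def _build_vgg16_feature_extractor() -> transforms.Compose:
    """Map uint8 images of shape [B, 3, H, W] to 4096-dim VGG-16 classifier features."""
    vgg = models.vgg16(weights="IMAGENET1K_V1").eval()
    return transforms.Compose([
        transforms.Lambda(lambda x: x.float() / 255),  # rescale uint8 [0, 255] -> float [0, 1]
        transforms.Resize([224, 224], interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        vgg.features,                                      # convolutional features: [B, 512, 7, 7]
        transforms.Lambda(lambda x: torch.flatten(x, 1)),  # flatten to [B, 25088]
        vgg.classifier[:4],                                # first four classifier layers -> 4096 features
    ])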



class ImprovedPrecessionRecall(Metric):
    """Improved Precision and Recall for generative models (Kynkäänniemi et al., 2019).

    The real and fake feature manifolds are approximated by hyperspheres around each extracted
    feature vector, with the radius given by the distance to its k-th nearest neighbour.
    Precision is the fraction of fake samples covered by the real manifold, recall the fraction
    of real samples covered by the fake manifold.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        feature: Union[int, torch.nn.Module] = 2048,
        knn: int = 3,
        splits_real: int = 1,
        splits_fake: int = 5,
    ) -> None:
        """
        Args:
            feature: an integer selecting the InceptionV3 feature dimensionality (64, 192, 768 or 2048),
                or a ``torch.nn.Module`` used as a custom feature extractor.
            knn: which nearest neighbour defines the radius of each hypersphere.
            splits_real: number of chunks the real features are split into for the pairwise distance computation.
            splits_fake: number of chunks the fake features are split into for the pairwise distance computation.
        """
        super().__init__()

        # ------------------------- Init Feature Extractor (VGG or Inception) ------------------------------
        # Original VGG: https://github.com/kynkaat/improved-precision-and-recall-metric/blob/b0247eafdead494a5d243bd2efb1b0b124379ae9/utils.py#L40 
        # Compare Inception: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/evaluations/evaluator.py#L574 
        # TODO: Add option to switch between Inception and VGG feature extractor 
        # self.vgg_model = models.vgg16(weights='IMAGENET1K_V1').eval()
        # self.feature_extractor = transforms.Compose([
        #     VGG_Trans, 
        #     self.vgg_model.features,
        #     transforms.Lambda(lambda x: torch.flatten(x, 1)),
        #     self.vgg_model.classifier[:4] # [:4] corresponds to 4096 features 
        # ])

        if isinstance(feature, int):
            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )
            valid_int_input = [64, 192, 768, 2048]
            if feature not in valid_int_input:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of {valid_int_input}, but got {feature}."
                )

            self.feature_extractor = NoTrainInceptionV3(name="inception-v3-compat", features_list=[str(feature)])
        elif isinstance(feature, torch.nn.Module):
            self.feature_extractor = feature
        else:
            raise TypeError(
                f"Got unknown input to argument `feature`; expected an `int` or a `torch.nn.Module`,"
                f" but got {type(feature)}."
            )

        # --------------------------- End Feature Extractor ---------------------------------------------------------------

        self.knn = knn
        self.splits_real = splits_real
        self.splits_fake = splits_fake
        # lists of per-batch feature tensors; concatenated in `compute`
        self.add_state("real_features", [], dist_reduce_fx=None)
        self.add_state("fake_features", [], dist_reduce_fx=None)

    def update(self, imgs: Tensor, real: bool) -> None:  # type: ignore
        """Update the state with extracted features.

        Args:
            imgs: tensor with images feed to the feature extractor
            real: bool indicating if ``imgs`` belong to the real or the fake distribution
        """
        if not (torch.is_tensor(imgs) and imgs.dtype == torch.uint8):
            raise ValueError("Expecting `imgs` to be a `torch.Tensor` with dtype `torch.uint8`")

        features = self.feature_extractor(imgs).view(imgs.shape[0], -1)

        if real:
            self.real_features.append(features)
        else:
            self.fake_features.append(features)

    def compute(self) -> Tuple[Tensor, Tensor]:
        """Compute improved precision and recall from all accumulated real and fake features."""
        real_features = torch.concat(self.real_features)
        fake_features = torch.concat(self.fake_features)

        # the radius of the hypersphere around each feature is its distance to the k-th nearest neighbour
        real_distances = _compute_pairwise_distances(real_features, self.splits_real)
        real_radii = _distances2radii(real_distances, self.knn)

        fake_distances = _compute_pairwise_distances(fake_features, self.splits_fake)
        fake_radii = _distances2radii(fake_distances, self.knn)

        # precision: fraction of fake samples inside the real manifold; recall: the reverse
        precision = _compute_metric(real_features, real_radii, self.splits_real, fake_features, self.splits_fake)
        recall = _compute_metric(fake_features, fake_radii, self.splits_fake, real_features, self.splits_real)

        return precision, recall

def _compute_metric(ref_features, ref_radii, ref_splits, pred_features, pred_splits):
    """Fraction of `pred_features` lying inside the hypersphere of at least one `ref_features` sample."""
    dist = _compute_pairwise_distances(ref_features, ref_splits, pred_features, pred_splits)
    num_feat = pred_features.shape[0]
    count = 0
    for i in range(num_feat):
        # a predicted sample is covered if its distance to some reference sample is below that sample's radius
        count += (dist[:, i] < ref_radii).any()
    return count / num_feat
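# Worked example with hypothetical numbers: for ref_radii = [1.0, 2.0] and a distance matrix
# dist = [[1.5, 0.5],
#         [3.0, 2.5]]   (rows: reference samples, columns: predicted samples)
# the first predicted column has no entry below its row's radius, the second has 0.5 < 1.0,
# so exactly one of the two predicted samples is covered and the returned fraction is 0.5.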

def _distances2radii(distances, knn):
    """Distance of each row to its `knn`-th nearest neighbour (the smallest entry is the self-distance 0)."""
    return torch.topk(distances, knn + 1, dim=1, largest=False)[0].max(dim=1)[0]

def _compute_pairwise_distances(X, splits_x, Y=None, splits_y=None):
    """Euclidean distances between all rows of X [B, features] and Y [B', features].

    Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y> and processes X and Y in
    `splits_x` / `splits_y` chunks to keep the intermediate matrices small.
    """
    Y = X if Y is None else Y
    splits_y = splits_x if splits_y is None else splits_y
    dist = torch.concat([
        torch.concat([
            (torch.sum(X_batch**2, dim=1, keepdim=True) +
             torch.sum(Y_batch**2, dim=1, keepdim=True).t() -
             2 * torch.einsum("bd,dn->bn", X_batch, Y_batch.t()))
            for Y_batch in Y.chunk(splits_y, dim=0)], dim=1)
        for X_batch in X.chunk(splits_x, dim=0)])

    # numerical error can make some squared distances slightly negative; clamp before the sqrt
    dist[dist < 0] = 0
    return torch.sqrt(dist)
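

# A minimal, hypothetical usage sketch (it assumes `torch-fidelity` is installed and simply feeds
# random uint8 images of shape [B, 3, H, W] through the default InceptionV3 extractor):
if __name__ == "__main__":
    metric = ImprovedPrecessionRecall(feature=2048, knn=3)
    real_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
    fake_imgs = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
    metric.update(real_imgs, real=True)
    metric.update(fake_imgs, real=False)
    precision, recall = metric.compute()
    print(f"precision={float(precision):.3f}, recall={float(recall):.3f}")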