from PIL import Image
import numpy as np
import torch
from .features import build_feature_extractor, get_reference_statistics
from .fid import get_batch_features, fid_from_feats
from .resize import build_resizer


"""
A helper class that allowing adding the images one batch at a time.
"""


class CleanFID:
    """Incremental FID computation.

    Accumulates feature embeddings for "real" and "generated" images one
    batch at a time, then computes the FID between the two accumulated sets.
    """

    def __init__(self, mode="clean", model_name="inception_v3", device="cuda"):
        """Build the feature extractor and resize function.

        Args:
            mode: resize/preprocessing mode passed to the feature builders
                (e.g. "clean").
            model_name: "inception_v3" or "clip_vit_b_32".
            device: torch device string used for feature extraction.

        Raises:
            ValueError: if `model_name` is not one of the supported models.
        """
        self.real_features = []
        self.gen_features = []
        self.mode = mode
        self.device = device
        if model_name == "inception_v3":
            self.feat_model = build_feature_extractor(mode, device)
            self.fn_resize = build_resizer(mode)
        elif model_name == "clip_vit_b_32":
            # Imported lazily so the CLIP dependency is only required
            # when this model is actually requested.
            from .clip_features import CLIP_fx, img_preprocess_clip

            self.feat_model = CLIP_fx("ViT-B/32")
            self.fn_resize = img_preprocess_clip
        else:
            # Fail fast: previously an unknown name silently left
            # `feat_model`/`fn_resize` undefined, deferring the error.
            raise ValueError(f"unsupported model_name: {model_name}")

    def compute_features(self, x):
        """Return the feature embedding vector(s) for image(s) `x`.

        Accepts a PIL.Image, an H,W,C np.ndarray, or a torch.Tensor in
        C,H,W or B,C,H,W layout. Pixel values are expected in [0, 255];
        normalization happens inside `self.feat_model`.

        Raises:
            ValueError: if `x` is none of the supported types.
        """
        if isinstance(x, Image.Image):
            x_np = np.array(x)
            x_np_resized = self.fn_resize(x_np)
            # H,W,C -> 1,C,H,W
            x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0)
            x_feat = get_batch_features(x_t, self.feat_model, self.device)
        elif isinstance(x, np.ndarray):
            x_np_resized = self.fn_resize(x)
            x_t = (
                torch.tensor(x_np_resized.transpose((2, 0, 1)))
                .unsqueeze(0)
                .to(self.device)
            )
            x_feat = get_batch_features(x_t, self.feat_model, self.device)
        elif isinstance(x, torch.Tensor):
            # Add the batch dimension if x is passed in as C,H,W.
            if len(x.shape) == 3:
                x = x.unsqueeze(0)
            b, c, h, w = x.shape
            # The resize function operates on H,W,C numpy arrays, so
            # convert each image back, resize, and restack as B,C,H,W.
            l_x_np_resized = []
            for i in range(b):
                x_np = x[i].cpu().numpy().transpose((1, 2, 0))
                l_x_np_resized.append(self.fn_resize(x_np)[None,])
            x_np_resized = np.concatenate(l_x_np_resized)
            x_t = torch.tensor(x_np_resized.transpose((0, 3, 1, 2))).to(self.device)
            x_feat = get_batch_features(x_t, self.feat_model, self.device)
        else:
            raise ValueError("image type could not be inferred")
        return x_feat

    def add_real_images(self, x):
        """Extract features from `x` and add them to the reference (real) set."""
        x_feat = self.compute_features(x)
        self.real_features.append(x_feat)

    def add_gen_images(self, x):
        """Extract features from `x` and add them to the generated set."""
        x_feat = self.compute_features(x)
        self.gen_features.append(x_feat)

    def calculate_fid(self, verbose=True):
        """Compute FID between the real and generated images added so far.

        Raises:
            ValueError: if no real or no generated images have been added.
        """
        if not self.real_features or not self.gen_features:
            raise ValueError(
                "add at least one batch of real and generated images "
                "before calling calculate_fid()"
            )
        feats1 = np.concatenate(self.real_features)
        feats2 = np.concatenate(self.gen_features)
        if verbose:
            print(f"# real images = {feats1.shape[0]}")
            print(f"# generated images = {feats2.shape[0]}")
        return fid_from_feats(feats1, feats2)

    def reset_real_features(self):
        """Remove the real image features added so far."""
        self.real_features = []

    def reset_gen_features(self):
        """Remove the generated image features added so far."""
        self.gen_features = []