File size: 3,314 Bytes
ad552d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Helpers for extracting features from images.
"""
import os
import platform
import numpy as np
import torch
from torch.hub import get_dir
from .downloads_helper import check_download_url
from .inception_pytorch import InceptionV3
from .inception_torchscript import InceptionV3W


"""
Returns a function that takes images with pixel values in the range [0, 255]
and outputs a feature embedding vector.
"""


def feature_extractor(
    name="torchscript_inception",
    device=torch.device("cuda"),
    resize_inside=False,
    use_dataparallel=True,
):
    """
    Build a callable that maps a batch of images to feature embeddings.

    Args:
        name: backbone to use, either "torchscript_inception" or
            "pytorch_inception".
        device: device the model is moved to.
        resize_inside: forwarded to ``InceptionV3W`` for the torchscript
            backbone. The pytorch backbone is always built with
            ``resize_input=False`` and ignores this flag.
        use_dataparallel: if True, wrap the model in
            ``torch.nn.DataParallel``.

    Returns:
        A function ``model_fn(x)`` that runs the model (in eval mode) on
        ``x``. Inputs are expected in [0, 255]; the pytorch backbone divides
        them by 255 before the forward pass.

    Raises:
        ValueError: if ``name`` is not a known feature extractor.
    """
    if name not in ("torchscript_inception", "pytorch_inception"):
        raise ValueError(f"{name} feature extractor not implemented")

    use_torchscript = name == "torchscript_inception"
    if use_torchscript:
        # Directory handed to InceptionV3W (download=True) for its weights.
        cache_dir = "./" if platform.system() == "Windows" else "/tmp"
        net = InceptionV3W(cache_dir, download=True, resize_inside=resize_inside)
    else:
        net = InceptionV3(output_blocks=[3], resize_input=False)
    net = net.to(device)
    net.eval()
    if use_dataparallel:
        net = torch.nn.DataParallel(net)

    if use_torchscript:
        def model_fn(x):
            return net(x)
    else:
        def model_fn(x):
            # Rescale [0, 255] -> [0, 1], take the first element of the
            # output (only block 3 was requested), then squeeze the last two
            # dimensions (presumably pooled 1x1 spatial dims — confirm).
            return net(x / 255)[0].squeeze(-1).squeeze(-1)

    return model_fn


"""
Build a feature extractor for one of the supported modes: clean, legacy_tensorflow, or legacy_pytorch.
"""


def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
    """
    Build the feature extractor matching a FID computation mode.

    Args:
        mode: one of
            - "clean": torchscript Inception, no resizing inside the model;
            - "legacy_tensorflow": torchscript Inception with
              ``resize_inside=True``;
            - "legacy_pytorch": pytorch Inception, no resizing inside the
              model.
        device: device the model is moved to.
        use_dataparallel: if True, wrap the model in
            ``torch.nn.DataParallel``.

    Returns:
        The callable produced by ``feature_extractor``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes. Previously
            an unknown mode fell through and raised ``UnboundLocalError``.
    """
    # mode -> (backbone name, whether the model resizes its inputs itself)
    mode_settings = {
        "legacy_pytorch": ("pytorch_inception", False),
        "legacy_tensorflow": ("torchscript_inception", True),
        "clean": ("torchscript_inception", False),
    }
    try:
        name, resize_inside = mode_settings[mode]
    except KeyError:
        raise ValueError(
            f"{mode} mode not supported; expected one of {sorted(mode_settings)}"
        ) from None
    return feature_extractor(
        name=name,
        resize_inside=resize_inside,
        device=device,
        use_dataparallel=use_dataparallel,
    )


"""
Load precomputed reference statistics for commonly used datasets
"""


def get_reference_statistics(
    name,
    res,
    mode="clean",
    model_name="inception_v3",
    seed=0,
    split="test",
    metric="FID",
):
    """
    Download (via ``check_download_url``) and load precomputed reference stats.

    The stats archive is fetched from the clean-fid stats server into
    ``<torch hub dir>/fid_stats``.

    Args:
        name: dataset name used in the stats file name.
        res: resolution used in the file name; replaced by "na" when
            ``split == "custom"``.
        mode: feature-extraction mode used in the file name (e.g. "clean").
        model_name: "inception_v3" adds no suffix; any other value adds
            ``_<model_name>`` to the file name.
        seed: unused; kept for interface compatibility.
        split: dataset split used in the file name (e.g. "test", "custom").
        metric: "FID" or "KID".

    Returns:
        For "FID": a tuple ``(mu, sigma)`` read from the archive.
        For "KID": the ``feats`` array read from the archive.

    Raises:
        ValueError: if ``metric`` is not "FID" or "KID". Previously an
            unsupported metric silently returned ``None``.
    """
    # File-name suffix distinguishing the per-metric stats archives.
    metric_suffix = {"FID": "", "KID": "_kid"}
    if metric not in metric_suffix:
        raise ValueError(f"{metric} metric not supported; expected 'FID' or 'KID'")

    # No trailing slash: the URL is joined with "/" below, which previously
    # produced a double slash ("stats//<file>.npz").
    base_url = "https://www.cs.cmu.edu/~clean-fid/stats"
    if split == "custom":
        res = "na"
    model_modifier = "" if model_name == "inception_v3" else "_" + model_name
    rel_path = f"{name}_{mode}{model_modifier}_{split}_{res}{metric_suffix[metric]}.npz"
    url = f"{base_url}/{rel_path}"
    stats_folder = os.path.join(get_dir(), "fid_stats")
    # Presumably returns the local path of the (downloaded or cached) file.
    fpath = check_download_url(local_folder=stats_folder, url=url)
    # np.load on an .npz returns an NpzFile holding an open file handle;
    # close it once the needed arrays have been read into memory.
    with np.load(fpath) as stats:
        if metric == "FID":
            return stats["mu"], stats["sigma"]
        return stats["feats"]