Spaces:
Runtime error
Runtime error
""" | |
Helpers for extracting features from images.
""" | |
import os | |
import platform | |
import numpy as np | |
import torch | |
from torch.hub import get_dir | |
from .downloads_helper import check_download_url | |
from .inception_pytorch import InceptionV3 | |
from .inception_torchscript import InceptionV3W | |
""" | |
Returns a function that takes an image with values in the range [0, 255]
and outputs a feature embedding vector.
""" | |
def feature_extractor(
    name="torchscript_inception",
    device=torch.device("cuda"),
    resize_inside=False,
    use_dataparallel=True,
):
    """Build a callable that maps an image batch to feature embeddings.

    Args:
        name: which backbone to build; either "torchscript_inception"
            or "pytorch_inception".
        device: device the network is moved to.
        resize_inside: forwarded to InceptionV3W (only used by the
            "torchscript_inception" backbone).
        use_dataparallel: when true, the network is wrapped in
            torch.nn.DataParallel.

    Returns:
        A function ``model_fn(x)`` that runs the network on ``x``. Inputs
        are expected in the range [0, 255]; the "pytorch_inception" variant
        divides the input by 255 before the forward pass.

    Raises:
        ValueError: if ``name`` is not one of the two supported backbones.
    """
    # Reject unknown backbones before building anything.
    if name not in ("torchscript_inception", "pytorch_inception"):
        raise ValueError(f"{name} feature extractor not implemented")

    is_torchscript = name == "torchscript_inception"
    if is_torchscript:
        # Folder handed to InceptionV3W together with download=True
        # (presumably where the checkpoint is stored -- confirm in
        # inception_torchscript).
        ckpt_dir = "./" if platform.system() == "Windows" else "/tmp"
        net = InceptionV3W(ckpt_dir, download=True, resize_inside=resize_inside)
    else:
        net = InceptionV3(output_blocks=[3], resize_input=False)

    net = net.to(device)
    net.eval()
    if use_dataparallel:
        net = torch.nn.DataParallel(net)

    if is_torchscript:

        def model_fn(x):
            return net(x)

    else:

        def model_fn(x):
            # Rescale to [0, 1], take the first returned block and drop the
            # two trailing singleton dimensions.
            return net(x / 255)[0].squeeze(-1).squeeze(-1)

    return model_fn
""" | |
Build a feature extractor for each of the modes | |
""" | |
def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
    """Build the feature extractor that corresponds to an evaluation mode.

    Args:
        mode: one of "legacy_pytorch", "legacy_tensorflow" or "clean".
        device: device forwarded to ``feature_extractor``.
        use_dataparallel: forwarded to ``feature_extractor``.

    Returns:
        The callable produced by ``feature_extractor`` for the chosen mode.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes
            (previously this surfaced as an UnboundLocalError).
    """
    # Each mode differs only in the backbone name and the resize_inside flag.
    if mode == "legacy_pytorch":
        extractor_name, resize_inside = "pytorch_inception", False
    elif mode == "legacy_tensorflow":
        extractor_name, resize_inside = "torchscript_inception", True
    elif mode == "clean":
        extractor_name, resize_inside = "torchscript_inception", False
    else:
        raise ValueError(
            f"{mode} mode not implemented; expected one of "
            "'legacy_pytorch', 'legacy_tensorflow', 'clean'"
        )

    return feature_extractor(
        name=extractor_name,
        resize_inside=resize_inside,
        device=device,
        use_dataparallel=use_dataparallel,
    )
""" | |
Load precomputed reference statistics for commonly used datasets | |
""" | |
def get_reference_statistics(
    name,
    res,
    mode="clean",
    model_name="inception_v3",
    seed=0,
    split="test",
    metric="FID",
):
    """Load precomputed reference statistics for a commonly used dataset.

    The statistics file name is assembled from the arguments, resolved to a
    local path through ``check_download_url`` (presumably downloading into
    the ``fid_stats`` folder under the torch hub directory when missing --
    confirm in downloads_helper), and read with ``np.load``.

    Args:
        name: dataset tag used as the prefix of the file name.
        res: resolution tag used in the file name; replaced by "na" when
            ``split == "custom"``.
        mode: mode tag used in the file name.
        model_name: feature-model tag; "inception_v3" adds nothing to the
            file name, any other value is appended as ``_<model_name>``.
        seed: accepted for interface compatibility; not used here.
        split: split tag used in the file name.
        metric: "FID" or "KID".

    Returns:
        For "FID", the tuple ``(mu, sigma)`` taken from the "mu" and "sigma"
        entries of the archive. For "KID", the "feats" entry.

    Raises:
        ValueError: if ``metric`` is neither "FID" nor "KID" (previously the
            function silently returned None).
    """
    # Validate the metric first so a typo fails before any download happens.
    if metric == "FID":
        metric_suffix = ""
    elif metric == "KID":
        metric_suffix = "_kid"
    else:
        raise ValueError(f"{metric} metric not implemented; expected 'FID' or 'KID'")

    base_url = "https://www.cs.cmu.edu/~clean-fid/stats/"
    if split == "custom":
        res = "na"
    model_modifier = "" if model_name == "inception_v3" else "_" + model_name

    rel_path = f"{name}_{mode}{model_modifier}_{split}_{res}{metric_suffix}.npz"
    # base_url already ends with "/", so this yields a double slash; the URL
    # is intentionally kept identical to what earlier versions requested.
    url = f"{base_url}/{rel_path}"
    stats_folder = os.path.join(get_dir(), "fid_stats")
    fpath = check_download_url(local_folder=stats_folder, url=url)

    # Close the archive once the needed arrays have been read into memory.
    with np.load(fpath) as stats:
        if metric == "FID":
            return stats["mu"], stats["sigma"]
        return stats["feats"]