Spaces:
Runtime error
Runtime error
File size: 3,314 Bytes
ad552d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
"""
Helpers for extracting features from images.
"""
import os
import platform
import numpy as np
import torch
from torch.hub import get_dir
from .downloads_helper import check_download_url
from .inception_pytorch import InceptionV3
from .inception_torchscript import InceptionV3W
def feature_extractor(
    name="torchscript_inception",
    device=torch.device("cuda"),
    resize_inside=False,
    use_dataparallel=True,
):
    """Return a function that maps images to feature embedding vectors.

    The returned callable takes images with pixel values in [0, 255]
    (presumably batched tensors -- confirm against callers) and outputs a
    feature embedding vector per image.

    Args:
        name: which backbone to build, ``"torchscript_inception"`` or
            ``"pytorch_inception"``.
        device: device the network is moved to with ``.to(device)``.
        resize_inside: forwarded to ``InceptionV3W``; only used by the
            torchscript backbone (the pytorch one is always built with
            ``resize_input=False``).
        use_dataparallel: if true, wrap the network in
            ``torch.nn.DataParallel``.

    Returns:
        A callable ``model_fn(x)`` that runs the network on ``x``.

    Raises:
        ValueError: if ``name`` is not one of the supported backbones.
    """

    def _finalize(net):
        # Put the network in eval mode and optionally wrap it for multi-GPU use.
        net.eval()
        if use_dataparallel:
            return torch.nn.DataParallel(net)
        return net

    if name == "torchscript_inception":
        # Directory handed to InceptionV3W together with download=True;
        # presumably where the weights are fetched/cached.
        cache_dir = "./" if platform.system() == "Windows" else "/tmp"
        net = _finalize(
            InceptionV3W(cache_dir, download=True, resize_inside=resize_inside).to(
                device
            )
        )

        def model_fn(x):
            return net(x)

    elif name == "pytorch_inception":
        net = _finalize(InceptionV3(output_blocks=[3], resize_input=False).to(device))

        def model_fn(x):
            # Rescale [0, 255] -> [0, 1], take the first output, and drop the
            # two trailing singleton dims (presumably (N, C, 1, 1) pooled features).
            return net(x / 255)[0].squeeze(-1).squeeze(-1)

    else:
        raise ValueError(f"{name} feature extractor not implemented")
    return model_fn
def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
    """Build the feature extractor corresponding to a computation mode.

    Supported modes:
        ``"legacy_pytorch"``    -- ``"pytorch_inception"``, resize_inside=False
        ``"legacy_tensorflow"`` -- ``"torchscript_inception"``, resize_inside=True
        ``"clean"``             -- ``"torchscript_inception"``, resize_inside=False

    Args:
        mode: one of the modes listed above.
        device: device the network is moved to.
        use_dataparallel: if true, wrap the network in ``torch.nn.DataParallel``.

    Returns:
        The callable produced by ``feature_extractor``.

    Raises:
        ValueError: if ``mode`` is not a supported mode (previously this
            surfaced as an ``UnboundLocalError``).
    """
    # mode -> (backbone name, whether the backbone resizes its input internally)
    mode_configs = {
        "legacy_pytorch": ("pytorch_inception", False),
        "legacy_tensorflow": ("torchscript_inception", True),
        "clean": ("torchscript_inception", False),
    }
    if mode not in mode_configs:
        raise ValueError(f"{mode} mode not implemented")
    name, resize_inside = mode_configs[mode]
    return feature_extractor(
        name=name,
        resize_inside=resize_inside,
        device=device,
        use_dataparallel=use_dataparallel,
    )
def get_reference_statistics(
    name,
    res,
    mode="clean",
    model_name="inception_v3",
    seed=0,
    split="test",
    metric="FID",
):
    """Load precomputed reference statistics for commonly used datasets.

    The statistics file is fetched via ``check_download_url`` into
    ``<torch hub dir>/fid_stats`` and read with ``np.load``.

    Args:
        name: dataset name, used as the file-name prefix.
        res: resolution used in the file name; replaced by ``"na"`` when
            ``split == "custom"``.
        mode: computation mode the statistics were produced with.
        model_name: feature model; ``"inception_v3"`` adds no suffix to the
            file name, any other value adds ``"_<model_name>"``.
        seed: not used by this function; kept for interface compatibility.
        split: dataset split, used in the file name.
        metric: ``"FID"`` or ``"KID"``.

    Returns:
        ``(mu, sigma)`` for ``"FID"``; the stored ``feats`` array for ``"KID"``.

    Raises:
        ValueError: if ``metric`` is not ``"FID"`` or ``"KID"`` (previously
            the function silently returned ``None``).
    """
    # Validate before doing any network/disk work.
    if metric not in ("FID", "KID"):
        raise ValueError(f"{metric} reference statistics not implemented")

    # No trailing slash: the join below adds exactly one separator
    # (the old value produced ".../stats//<file>").
    base_url = "https://www.cs.cmu.edu/~clean-fid/stats"
    if split == "custom":
        res = "na"
    model_modifier = "" if model_name == "inception_v3" else "_" + model_name
    kid_suffix = "_kid" if metric == "KID" else ""
    rel_path = f"{name}_{mode}{model_modifier}_{split}_{res}{kid_suffix}.npz"
    url = f"{base_url}/{rel_path}"
    stats_folder = os.path.join(get_dir(), "fid_stats")
    fpath = check_download_url(local_folder=stats_folder, url=url)
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it. Indexing reads each array into memory,
    # so the returned arrays remain valid after the file is closed.
    with np.load(fpath) as stats:
        if metric == "FID":
            return stats["mu"], stats["sigma"]
        return stats["feats"]
|