mcding
published version
ad552d8
raw
history blame
3.31 kB
"""
Helpers for extracting features from images.
"""
import os
import platform
import numpy as np
import torch
from torch.hub import get_dir
from .downloads_helper import check_download_url
from .inception_pytorch import InceptionV3
from .inception_torchscript import InceptionV3W
"""
Returns a function that takes an image in range [0,255]
and outputs a feature embedding vector.
"""
def feature_extractor(
    name="torchscript_inception",
    device=torch.device("cuda"),
    resize_inside=False,
    use_dataparallel=True,
):
    """Create a callable that maps an image batch to feature embeddings.

    Args:
        name: Which backbone to build, either "torchscript_inception"
            (InceptionV3W) or "pytorch_inception" (InceptionV3).
        device: Device the network is moved to before use.
        resize_inside: Forwarded to InceptionV3W; only used by the
            "torchscript_inception" backbone.
        use_dataparallel: When true, the network is wrapped in
            torch.nn.DataParallel.

    Returns:
        A function ``model_fn(x)``. For "torchscript_inception" the input
        is passed to the network unchanged; for "pytorch_inception" the
        input is divided by 255 first, and the first output of the network
        has its two trailing dimensions squeezed.

    Raises:
        ValueError: If ``name`` is not one of the supported backbones.
    """
    if name not in ("torchscript_inception", "pytorch_inception"):
        raise ValueError(f"{name} feature extractor not implemented")

    def _prepare(net):
        # Shared setup: place on device, switch to eval mode, optionally
        # wrap for multi-GPU execution.
        net = net.to(device)
        net.eval()
        if use_dataparallel:
            net = torch.nn.DataParallel(net)
        return net

    if name == "torchscript_inception":
        # Directory handed to InceptionV3W (presumably where the weights
        # are stored/downloaded; confirm in inception_torchscript).
        weights_dir = "./" if platform.system() == "Windows" else "/tmp"
        backbone = _prepare(
            InceptionV3W(weights_dir, download=True, resize_inside=resize_inside)
        )

        def model_fn(x):
            return backbone(x)

    else:
        backbone = _prepare(InceptionV3(output_blocks=[3], resize_input=False))

        def model_fn(x):
            # This backbone receives inputs rescaled by 1/255.
            first_output = backbone(x / 255)[0]
            return first_output.squeeze(-1).squeeze(-1)

    return model_fn
"""
Build a feature extractor for each of the modes
"""
def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True):
    """Build the feature extractor that corresponds to an evaluation mode.

    Args:
        mode: One of "legacy_pytorch", "legacy_tensorflow" or "clean".
        device: Device forwarded to ``feature_extractor``.
        use_dataparallel: Forwarded to ``feature_extractor``.

    Returns:
        The callable returned by ``feature_extractor`` for the backbone
        and ``resize_inside`` setting associated with ``mode``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
            (Previously an unknown mode fell through every branch and
            failed with UnboundLocalError on the return statement.)
    """
    # mode -> (backbone name, resize_inside)
    presets = {
        "legacy_pytorch": ("pytorch_inception", False),
        "legacy_tensorflow": ("torchscript_inception", True),
        "clean": ("torchscript_inception", False),
    }
    if mode not in presets:
        supported = ", ".join(presets)
        raise ValueError(
            f"{mode} mode not implemented; expected one of: {supported}"
        )

    backbone_name, resize_inside = presets[mode]
    return feature_extractor(
        name=backbone_name,
        resize_inside=resize_inside,
        device=device,
        use_dataparallel=use_dataparallel,
    )
"""
Load precomputed reference statistics for commonly used datasets
"""
def get_reference_statistics(
    name,
    res,
    mode="clean",
    model_name="inception_v3",
    seed=0,
    split="test",
    metric="FID",
):
    """Download (if needed) and load precomputed reference statistics.

    The statistics file name is built from ``name``, ``mode``,
    ``model_name`` (omitted for "inception_v3"), ``split`` and ``res``;
    KID files carry an extra "_kid" suffix. The file is fetched through
    ``check_download_url`` into the "fid_stats" folder under
    ``torch.hub.get_dir()``.

    Args:
        name: Dataset name used in the statistics file name.
        res: Resolution used in the file name; replaced by "na" when
            ``split`` is "custom".
        mode: Evaluation mode used in the file name.
        model_name: Feature model name; any value other than
            "inception_v3" is appended to the file name.
        seed: Accepted for interface compatibility; not used here.
        split: Dataset split used in the file name.
        metric: "FID" or "KID".

    Returns:
        For "FID", a tuple ``(mu, sigma)`` read from the file's "mu" and
        "sigma" entries. For "KID", the file's "feats" entry.

    Raises:
        ValueError: If ``metric`` is neither "FID" nor "KID" (previously
            this silently returned None).
    """
    # Validate before doing any network or filesystem work.
    if metric not in ("FID", "KID"):
        raise ValueError(f"{metric} metric not supported; expected 'FID' or 'KID'")

    base_url = "https://www.cs.cmu.edu/~clean-fid/stats/"
    if split == "custom":
        res = "na"
    model_modifier = "" if model_name == "inception_v3" else "_" + model_name
    metric_suffix = "" if metric == "FID" else "_kid"

    rel_path = f"{name}_{mode}{model_modifier}_{split}_{res}{metric_suffix}.npz"
    # NOTE: base_url already ends with "/", so this yields "stats//<file>".
    # Kept as-is so the requested URL is unchanged.
    url = f"{base_url}/{rel_path}"
    stats_folder = os.path.join(get_dir(), "fid_stats")
    fpath = check_download_url(local_folder=stats_folder, url=url)

    # Use the archive as a context manager so its file handle is closed;
    # indexing materialises the arrays before the archive is closed.
    with np.load(fpath) as stats:
        if metric == "FID":
            return stats["mu"], stats["sigma"]
        return stats["feats"]