|
import argparse |
|
import pickle |
|
|
|
import torch |
|
from torch import nn |
|
import numpy as np |
|
from scipy import linalg |
|
from tqdm import tqdm |
|
|
|
from torchvision import transforms |
|
from torchvision.datasets import ImageFolder |
|
from torch.utils.data import DataLoader |
|
|
|
from calc_inception import load_patched_inception_v3 |
|
import os |
|
|
|
@torch.no_grad()
def extract_features(loader, inception, device):
    """Run every batch of ``loader`` through ``inception`` and gather features.

    Gradient tracking is disabled for the whole call by the decorator.

    Args:
        loader: Iterable yielding ``(images, _)`` pairs; the second item
            of each pair is ignored.
        inception: Callable applied to each image batch. Only the first
            element of what it returns is used.
        device: Device each image batch is moved to before the forward
            pass.

    Returns:
        A CPU tensor with one row per image, built by concatenating the
        per-batch features along dimension 0.
    """
    collected = []

    # tqdm wraps the loader only to show progress while iterating.
    for batch, _ in tqdm(loader):
        batch = batch.to(device)
        activations = inception(batch)[0]
        # Flatten everything after the batch dimension, then move the
        # result to the CPU so features do not pile up on `device`.
        flat = activations.view(batch.shape[0], -1)
        collected.append(flat.to('cpu'))

    return torch.cat(collected, 0)
|
|
|
|
|
def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6):
    """Compute the Frechet distance between two Gaussian distributions.

    The returned value is::

        ||sample_mean - real_mean||^2
            + Tr(sample_cov) + Tr(real_cov)
            - 2 * Tr(sqrtm(sample_cov @ real_cov))

    Args:
        sample_mean: Mean vector of the first distribution (a scalar is
            accepted and treated as a length-1 vector).
        sample_cov: Square covariance matrix of the first distribution
            (a scalar / 0-d value is treated as a 1x1 matrix).
        real_mean: Mean vector of the second distribution.
        real_cov: Square covariance matrix of the second distribution.
        eps: Value added to the diagonal of both covariance matrices when
            the first matrix square root contains non-finite entries.

    Returns:
        The distance as a NumPy floating-point scalar.

    Raises:
        ValueError: If the diagonal of the matrix square root has an
            imaginary part that is not within 1e-3 of zero.
    """
    # Promote scalars / 0-d arrays (np.cov returns a 0-d array for
    # single-feature data) so `@`, `.shape[0]` and np.trace all work.
    sample_mean = np.atleast_1d(sample_mean)
    real_mean = np.atleast_1d(real_mean)
    sample_cov = np.atleast_2d(sample_cov)
    real_cov = np.atleast_2d(real_cov)

    # Called without the `disp` argument (deprecated in recent SciPy
    # releases), so the result is the matrix itself rather than a
    # (matrix, error_estimate) tuple.
    cov_sqrt = linalg.sqrtm(sample_cov @ real_cov)

    if not np.isfinite(cov_sqrt).all():
        print('product of cov matrices is singular')

        # Retry with a small ridge on both diagonals.
        offset = np.eye(sample_cov.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset))

    if np.iscomplexobj(cov_sqrt):
        # A small imaginary part is treated as numerical noise and
        # dropped; a larger one on the diagonal is reported as an error.
        if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
            m = np.max(np.abs(cov_sqrt.imag))

            raise ValueError(f'Imaginary component {m}')

        cov_sqrt = cov_sqrt.real

    mean_diff = sample_mean - real_mean
    mean_norm = mean_diff @ mean_diff

    trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt)

    fid = mean_norm + trace

    return fid
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: extract features for the images under --path_a
    # once, then print the FID of every existing `eval_<N>` sub-folder of
    # --path_b against them.

    # Prefer the GPU, but fall back to the CPU so the script still runs
    # on machines without CUDA instead of crashing when the model is moved.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()

    parser.add_argument('--batch', type=int, default=64,
                        help='batch size used by both data loaders')
    parser.add_argument('--size', type=int, default=256,
                        help='images are resized to size x size')
    # Both paths are mandatory: without them the dataset construction
    # below would fail with an opaque error about a None path.
    parser.add_argument('--path_a', type=str, required=True,
                        help='image folder providing the reference statistics')
    parser.add_argument('--path_b', type=str, required=True,
                        help='directory containing the eval_<N> folders')
    parser.add_argument('--iter', type=int, default=3,
                        help='first folder index (folder is eval_<index*10000>)')
    parser.add_argument('--end', type=int, default=13,
                        help='last folder index, inclusive')

    args = parser.parse_args()

    inception = load_patched_inception_v3().eval().to(device)

    # Same preprocessing for both image sets: resize to a square, convert
    # to a tensor, then normalize each channel with mean 0.5 and std 0.5.
    transform = transforms.Compose(
        [
            transforms.Resize((args.size, args.size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dset_a = ImageFolder(args.path_a, transform)
    loader_a = DataLoader(dset_a, batch_size=args.batch, num_workers=4)

    features_a = extract_features(loader_a, inception, device).numpy()
    print(f'extracted {features_a.shape[0]} features')

    # Statistics of set A are computed once and reused for every folder.
    real_mean = np.mean(features_a, 0)
    real_cov = np.cov(features_a, rowvar=False)

    for step in range(args.iter, args.end + 1):
        folder = 'eval_%d' % (step * 10000)
        folder_path = os.path.join(args.path_b, folder)

        # Folders that do not exist are skipped silently.
        if not os.path.exists(folder_path):
            continue

        print(folder)
        dset_b = ImageFolder(folder_path, transform)
        loader_b = DataLoader(dset_b, batch_size=args.batch, num_workers=4)

        features_b = extract_features(loader_b, inception, device).numpy()
        print(f'extracted {features_b.shape[0]} features')

        sample_mean = np.mean(features_b, 0)
        sample_cov = np.cov(features_b, rowvar=False)

        fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov)

        print(folder, ' fid:', fid)
|
|