Spaces:
Build error
Build error
File size: 6,381 Bytes
81170fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import numpy as np
import os
import functools
import argparse
import scipy
from tqdm import tqdm
import logging
from . import inception
from . import utils
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class FID:
    """Frechet Inception Distance (FID) evaluator for a generator against a real dataset."""

    def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
        """
        Evaluates the FID score for a given generator and a given dataset.

        Implementation mostly taken from https://github.com/matthias-wright/jax-fid
        Reference: https://arxiv.org/abs/1706.08500

        Args:
            generator (nn.Module): Generator network.
            dataset (tf.data.Dataset): Dataset containing the real images.
            config (argparse.Namespace): Configuration. Reads num_fid_images, batch_size,
                c_dim, z_dim and resolution.
            use_cache (bool): If True, only compute the activation stats once for the real images and store them.
            truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
        """
        self.num_images = config.num_fid_images
        self.batch_size = config.batch_size
        self.c_dim = config.c_dim
        self.z_dim = config.z_dim
        self.dataset = dataset
        self.num_devices = jax.device_count()
        self.num_local_devices = jax.local_device_count()
        self.use_cache = use_cache
        # Always create the cache so attribute access is safe even when caching
        # is disabled; it is only written to when use_cache is True.
        self.cache = {}

        rng = jax.random.PRNGKey(0)
        inception_net = inception.InceptionV3(pretrained=True)
        self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
        # Replicate parameters across local devices for use inside pmap.
        self.inception_params = flax.jax_utils.replicate(self.inception_params)
        self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')
        self.generator_apply = jax.pmap(
            functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'),
            axis_name='batch',
        )

    def compute_fid(self, generator_params, seed_offset=0):
        """Compute the FID between the dataset and samples from the generator.

        Args:
            generator_params: Un-replicated generator parameters.
            seed_offset (int): Added to the per-batch PRNG seed used for sampling.

        Returns:
            The FID score.
        """
        generator_params = flax.jax_utils.replicate(generator_params)
        mu_real, sigma_real = self.compute_stats_for_dataset()
        mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
        fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)
        return fid_score

    def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
        """Squared Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

        Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py

        Args:
            mu1, mu2: Mean activation vectors.
            sigma1, sigma2: Activation covariance matrices.
            eps (float): Added to both diagonals if the matrix square root is not finite.

        Returns:
            The squared Frechet distance as a scalar.

        Raises:
            ValueError: If the matrix square root has a significant imaginary component.
        """
        # Imported explicitly: `import scipy` alone does not load the linalg
        # submodule on older SciPy releases.
        import scipy.linalg

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)
        sigma1 = np.atleast_1d(sigma1)
        sigma2 = np.atleast_1d(sigma2)
        assert mu1.shape == mu2.shape, 'mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, 'covariances have different dimensions'

        diff = mu1 - mu2
        covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            logger.info(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)
        return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

    @staticmethod
    def _flatten_device_axis(act):
        """Merge the leading (device, per-device batch) axes of a pmap output into one row axis."""
        return jnp.reshape(act, (act.shape[0] * act.shape[1], -1))

    @staticmethod
    def _stats_from_activations(activations, num_images):
        """Return (mean, covariance) of the first ``num_images`` activation rows.

        Args:
            activations (list): 2-D arrays of shape (rows, features), one per batch.
            num_images (int): Maximum number of rows to include.

        Returns:
            Tuple (mu, sigma) of float64 numpy arrays.

        Raises:
            ValueError: If no activations were collected.
        """
        if not activations:
            raise ValueError('No activations collected; the input produced no batches.')
        # Gather on the host in float64 so that the statistics over many
        # samples do not lose precision.
        acts = np.concatenate([np.asarray(a, dtype=np.float64) for a in activations], axis=0)
        acts = acts[:num_images]
        mu = np.mean(acts, axis=0)
        sigma = np.cov(acts, rowvar=False)
        return mu, sigma

    def compute_stats_for_dataset(self):
        """Compute (or fetch cached) Inception activation statistics for the real dataset.

        Returns:
            Tuple (mu, sigma) over at most ``self.num_images`` images.
        """
        if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
            logger.info('Use cached statistics for dataset...')
            return self.cache['mu'], self.cache['sigma']

        print()
        logger.info('Compute statistics for dataset...')
        image_count = 0
        activations = []
        for batch in utils.prefetch(self.dataset, n_prefetch=2):
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
            act = self._flatten_device_axis(act)
            activations.append(act)
            # Count the rows actually produced, so a short final batch is
            # neither over-counted nor rejected by a fixed-size reshape.
            image_count += act.shape[0]
            if image_count >= self.num_images:
                break

        mu, sigma = self._stats_from_activations(activations, self.num_images)
        if self.use_cache:
            self.cache['mu'] = mu
            self.cache['sigma'] = sigma
        return mu, sigma

    def compute_stats_for_generator(self, generator_params, seed_offset):
        """Compute Inception activation statistics for ``self.num_images`` generated images.

        Args:
            generator_params: Generator parameters already replicated across local devices.
            seed_offset (int): Added to the per-batch PRNG seed.

        Returns:
            Tuple (mu, sigma).
        """
        print()
        logger.info('Compute statistics for generator...')
        images_per_step = self.batch_size * self.num_local_devices
        num_batches = int(np.ceil(self.num_images / images_per_step))
        activations = []
        for i in range(num_batches):
            rng = jax.random.PRNGKey(seed_offset + i)
            # Independent keys for latents and labels: reusing one key for two
            # random draws makes the draws statistically dependent.
            z_rng, label_rng = jax.random.split(rng)
            z_latent = jax.random.normal(z_rng, shape=(self.num_local_devices, self.batch_size, self.z_dim))
            labels = None
            if self.c_dim > 0:
                labels = jax.random.randint(label_rng, shape=(images_per_step,), minval=0, maxval=self.c_dim)
                labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
                labels = jnp.reshape(labels, (self.num_local_devices, self.batch_size, self.c_dim))
            image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
            # Min-max rescale to [-1, 1] over the whole (all-device) batch.
            # NOTE(review): this stretches each batch to the full range regardless
            # of the generator's nominal output range; confirm this is intended.
            image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
            image = 2 * image - 1
            act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
            activations.append(self._flatten_device_axis(act))

        return self._stats_from_activations(activations, self.num_images)
|