Spaces:
Runtime error
Runtime error
File size: 8,005 Bytes
480bfbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import warnings
import numbers
import numpy as np
import scipy
import torch
from torch.nn import functional as F
from .. import models, utils
from ..external_models import inception
class _TruncatedDataset:
    """
    Truncates a dataset, making only its first `max_len` items accessible.

    Intended for `torch.utils.data.DataLoader`, which only requests indices
    below `len(self)`. Integer indices are additionally bounds-checked so
    that direct indexing and Python's sequence iteration protocol (which
    stops at the first `IndexError`) also respect the truncation.

    Arguments:
        dataset (indexable): The dataset to truncate. Must support `len()`.
        max_len (int): Maximum number of items to expose.
    """

    def __init__(self, dataset, max_len):
        self.dataset = dataset
        self.max_len = max_len

    def __len__(self):
        return min(len(self.dataset), self.max_len)

    def __getitem__(self, index):
        # Only integral indices are range-checked; anything else (e.g. a
        # custom key type) is forwarded to the wrapped dataset unchanged.
        if isinstance(index, numbers.Integral):
            length = len(self)
            if index < 0:
                # Negative indices count from the end of the *truncated* view.
                index += length
            if not 0 <= index < length:
                raise IndexError(
                    'index out of range for truncated dataset of length {}'.format(length))
        return self.dataset[index]
class FID:
    """
    This class evaluates the FID metric of a generator.

    Feature statistics (mean and covariance of the spatially pooled
    `fid_model` outputs) of the real data are gathered once, on
    construction. Every call to `evaluate()` gathers the same statistics
    for generated samples and returns the Frechet distance between them.

    Arguments:
        G (Generator): The generator to evaluate. If it is wrapped in
            `torch.nn.DataParallel`, the same device ids are reused for
            the evaluation (and for the default `fid_model`).
        prior_generator (PriorGenerator): Source of latents (and optional
            labels) for `G`. Its `batch_size` is used when generating
            fakes, and it must live on the evaluation device.
        dataset (indexable): Real samples. Items are either data tensors or
            tuples/lists whose first element is the data and whose second
            element, if present, is the label. When real labels are present
            and `prior_generator` yields labels, the real labels are reused
            to condition the generated samples.
        device (int, str, torch.device, optional): The device
            to use for calculations. By default, the same device
            is chosen as the parameters in `G` reside on.
        num_samples (int): Number of samples of reals and fakes
            to gather statistics for which are used for calculating
            the metric. Default value is 50 000.
        fid_model (nn.Module): A model that returns feature maps
            of shape (batch_size, features, *). Default value
            is InceptionV3.
        fid_size (int, optional): Resize any data (reals and fakes) fed to
            `fid_model` by scaling the data so that its smallest side is
            the same size as this argument.
        truncation_psi (float, optional): Truncation of the generator
            when evaluating.
        truncation_cutoff (int, optional): Cutoff for truncation when
            evaluating.
        reals_batch_size (int, optional): Batch size to use for real
            samples statistics gathering. Defaults to the batch size
            of `prior_generator`.
        reals_data_workers (int, optional): Number of workers fetching
            the real data samples. Default value is 0.
        verbose (bool): Write progress of gathering statistics for reals
            to stdout. Default value is True.
    """

    def __init__(self,
                 G,
                 prior_generator,
                 dataset,
                 device=None,
                 num_samples=50000,
                 fid_model=None,
                 fid_size=None,
                 truncation_psi=None,
                 truncation_cutoff=None,
                 reals_batch_size=None,
                 reals_data_workers=0,
                 verbose=True):
        device_ids = []
        if isinstance(G, torch.nn.DataParallel):
            device_ids = G.device_ids
        G = utils.unwrap_module(G)
        assert isinstance(G, models.Generator)
        assert isinstance(prior_generator, utils.PriorGenerator)
        if device is None:
            device = next(G.parameters()).device
        else:
            device = torch.device(device)
        # Format the prior generator's *device*; passing the generator object
        # itself to torch.device() raises and would mask this message.
        prior_device = torch.device(prior_generator.device)
        assert prior_device == device, \
            'Prior generator device ({}) '.format(prior_device) + \
            'is not the same as the specified (or inferred from the model) ' + \
            'device ({}) for the FID evaluation.'.format(device)
        G.eval().to(device)
        if device_ids:
            G = torch.nn.DataParallel(G, device_ids=device_ids)
        self.G = G
        self.prior_generator = prior_generator
        self.device = device
        self.num_samples = num_samples
        self.batch_size = self.prior_generator.batch_size
        if fid_model is None:
            warnings.warn(
                'Using default fid model metric based on Inception V3. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1], please specify another module if you want ' + \
                'to use other kinds of data formats.'
            )
            fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
            if device_ids:
                fid_model = torch.nn.DataParallel(fid_model, device_ids)
        self.fid_model = fid_model.eval().to(device)
        self.fid_size = fid_size
        self.truncation_psi = truncation_psi
        self.truncation_cutoff = truncation_cutoff
        self.mu_real, self.sigma_real = self._gather_real_statistics(
            dataset,
            batch_size=reals_batch_size or self.batch_size,
            num_workers=reals_data_workers,
            verbose=verbose,
        )

    def _gather_real_statistics(self, dataset, batch_size, num_workers, verbose):
        """
        Compute the feature mean and covariance of at most `num_samples`
        items of `dataset`. Labels found as the second element of
        tuple/list batches are stored in `self.labels`, one entry per batch.

        Returns:
            (mu, sigma): numpy mean vector and covariance matrix.
        """
        dataloader = torch.utils.data.DataLoader(
            _TruncatedDataset(dataset, self.num_samples),
            batch_size=batch_size,
            num_workers=num_workers
        )
        features = []
        self.labels = []
        progress = None
        if verbose:
            progress = utils.ProgressWriter(np.ceil(self.num_samples / batch_size))
            progress.write('FID: Gathering statistics for reals...', step=False)
        for batch in dataloader:
            data = batch
            if isinstance(batch, (tuple, list)):
                data = batch[0]
                if len(batch) > 1:
                    self.labels.append(batch[1])
            features.append(self._extract_features(data))
            # Only step when a progress writer exists (verbose=False used to
            # raise NameError here).
            if progress is not None:
                progress.step()
        if progress is not None:
            progress.write('FID: Statistics for reals gathered!', step=False)
            progress.close()
        features = torch.cat(features, dim=0).numpy()
        return np.mean(features, axis=0), np.cov(features, rowvar=False)

    def _extract_features(self, data):
        """
        Run `data` (resized by `_scale_for_fid`) through `fid_model` and
        average the output over all dimensions after the second one.

        Returns:
            features (torch.Tensor): CPU tensor of shape (batch_size, features).
        """
        data = self._scale_for_fid(data).to(self.device)
        with torch.no_grad():
            features = self.fid_model(data)
        # reshape instead of view: also works for non-contiguous outputs.
        features = features.reshape(*features.size()[:2], -1).mean(-1)
        return features.cpu()

    def _scale_for_fid(self, data):
        """
        Resize `data` so its smallest trailing (spatial) side equals
        `fid_size`. Returns `data` unchanged if `fid_size` is not set or no
        resizing is needed. Downscaling uses 'area' interpolation,
        upscaling uses 'nearest'.
        """
        if not self.fid_size:
            return data
        scale_factor = self.fid_size / min(data.size()[2:])
        if scale_factor == 1:
            return data
        mode = 'nearest'
        if scale_factor < 1:
            mode = 'area'
        return F.interpolate(data, scale_factor=scale_factor, mode=mode)

    @staticmethod
    def _frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
        """
        Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2)).

        Arguments:
            eps (float): Diagonal offset added to both covariances if the
                matrix square root is not finite.
        Returns:
            distance (float): Real part of the distance.
        """
        # `import scipy` alone does not guarantee the linalg submodule is loaded.
        from scipy import linalg
        m = np.square(mu1 - mu2).sum()
        s = linalg.sqrtm(np.dot(sigma1, sigma2))
        if not np.isfinite(s).all():
            # The product can be (near-)singular, e.g. with fewer samples
            # than feature dimensions; regularize the diagonals and retry.
            offset = np.eye(sigma1.shape[0]) * eps
            s = linalg.sqrtm(np.dot(sigma1 + offset, sigma2 + offset))
        # sqrtm may return a complex array with negligible imaginary parts.
        dist = m + np.trace(sigma1 + sigma2 - 2 * s)
        return float(np.real(dist))

    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)

    def evaluate(self, verbose=True):
        """
        Evaluate the FID.
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            fid (float): Metric value.
        """
        utils.unwrap_module(self.G).set_truncation(
            truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff)
        self.G.eval()
        # Flatten the per-batch real labels so fake batches can be paired with
        # them by *sample offset*, independently of the reals batch size.
        real_labels = torch.cat(self.labels, dim=0) if self.labels else None
        features = []
        progress = None
        if verbose:
            progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
            progress.write('FID: Gathering statistics for fakes...', step=False)
        remaining = self.num_samples
        for i in range(0, self.num_samples, self.batch_size):
            latents, latent_labels = self.prior_generator(
                batch_size=min(self.batch_size, remaining))
            if latent_labels is not None and real_labels is not None:
                latent_labels = real_labels[i:i + len(latents)].to(self.device)
                if len(latent_labels) == 0:
                    # Fewer real labels than num_samples: nothing left to
                    # condition on, so stop generating.
                    warnings.warn(
                        'FID: ran out of real labels after {} fake samples.'.format(i))
                    break
                latents = latents[:len(latent_labels)]
            with torch.no_grad():
                fakes = self.G(latents, labels=latent_labels)
            # Same resize + pooling path as the reals.
            features.append(self._extract_features(fakes))
            remaining -= len(latents)
            if progress is not None:
                progress.step()
        if progress is not None:
            progress.write('FID: Statistics for fakes gathered!', step=False)
            progress.close()
        features = torch.cat(features, dim=0).numpy()
        mu_fake = np.mean(features, axis=0)
        sigma_fake = np.cov(features, rowvar=False)
        return self._frechet_distance(mu_fake, sigma_fake, self.mu_real, self.sigma_real)
|