import warnings
import numbers
import numpy as np
import scipy
import scipy.linalg
import torch
from torch.nn import functional as F

from .. import models, utils
from ..external_models import inception


class _TruncatedDataset:
    """
    Truncates a dataset, making only part of it accessible
    by `torch.utils.data.DataLoader`.

    Integer indices are bounds-checked against the truncated length, so
    negative indexing and plain iteration (which stops on `IndexError`)
    also respect the truncation.
    """

    def __init__(self, dataset, max_len):
        self.dataset = dataset
        self.max_len = max_len

    def __len__(self):
        return min(len(self.dataset), self.max_len)

    def __getitem__(self, index):
        if isinstance(index, numbers.Integral):
            length = len(self)
            position = index + length if index < 0 else index
            if not 0 <= position < length:
                raise IndexError(
                    'Index {} is out of range for a dataset truncated '
                    'to {} items.'.format(index, length))
            index = position
        return self.dataset[index]


def _frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """
    Compute the Frechet distance between the Gaussians N(mu1, sigma1)
    and N(mu2, sigma2):

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    If the matrix square root is not finite (e.g. for a near-singular
    product), it is recomputed with `eps` added to the diagonals of both
    covariance matrices.
    Arguments:
        mu1, mu2 (np.ndarray): Mean vectors.
        sigma1, sigma2 (np.ndarray): Covariance matrices.
        eps (float): Diagonal offset used only by the fallback.
    Returns:
        distance (float): Real part of the distance.
    """
    mean_term = np.square(mu1 - mu2).sum()
    covmean, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = scipy.linalg.sqrtm(
            np.dot(sigma1 + offset, sigma2 + offset), disp=False)
    dist = mean_term + np.trace(sigma1 + sigma2 - 2 * covmean)
    # sqrtm may return a complex matrix with negligible imaginary parts.
    return float(np.real(dist))


class FID:
    """
    This class evaluates the FID metric of a generator.
    Arguments:
        G (Generator): The generator to evaluate. If it is wrapped in
            `torch.nn.DataParallel`, the wrapping is kept (and also
            applied to the default `fid_model`).
        prior_generator (PriorGenerator): Produces `(latents, labels)`
            for the generator. Its `batch_size` is the batch size used
            for fakes. Must live on the same device as the evaluation.
        dataset (indexable): Real samples. Items are either a data
            tensor or a tuple/list of `(data, label, ...)`.
        device (int, str, torch.device, optional): The device
            to use for calculations. By default, the same device
            is chosen as the parameters in `generator` reside on.
        num_samples (int): Number of samples of reals and fakes
            to gather statistics for which are used for calculating
            the metric. Default value is 50 000.
        fid_model (nn.Module): A model that returns feature maps
            of shape (batch_size, features, *). Default value
            is InceptionV3.
        fid_size (int, optional): Resize any data (reals and fakes) fed
            to `fid_model` by scaling the data so that its smallest side
            is the same size as this argument.
        truncation_psi (float, optional): Truncation of the generator
            when evaluating.
        truncation_cutoff (int, optional): Cutoff for truncation when
            evaluating.
        reals_batch_size (int, optional): Batch size to use for real
            samples statistics gathering. Defaults to the batch size
            of `prior_generator`.
        reals_data_workers (int, optional): Number of workers fetching
            the real data samples. Default value is 0.
        verbose (bool): Write progress of gathering statistics for reals
            to stdout. Default value is True.
    """

    def __init__(self,
                 G,
                 prior_generator,
                 dataset,
                 device=None,
                 num_samples=50000,
                 fid_model=None,
                 fid_size=None,
                 truncation_psi=None,
                 truncation_cutoff=None,
                 reals_batch_size=None,
                 reals_data_workers=0,
                 verbose=True):
        device_ids = []
        if isinstance(G, torch.nn.DataParallel):
            device_ids = G.device_ids
        G = utils.unwrap_module(G)
        assert isinstance(G, models.Generator)
        assert isinstance(prior_generator, utils.PriorGenerator)
        if device is None:
            device = next(G.parameters()).device
        else:
            device = torch.device(device)
        prior_device = torch.device(prior_generator.device)
        assert prior_device == device, (
            'Prior generator device ({}) is not the same as the specified '
            '(or inferred from the model) device ({}) for the FID '
            'evaluation.'.format(prior_device, device))
        G.eval().to(device)
        if device_ids:
            G = torch.nn.DataParallel(G, device_ids=device_ids)
        self.G = G
        self.prior_generator = prior_generator
        self.device = device
        self.num_samples = num_samples
        self.batch_size = self.prior_generator.batch_size
        if fid_model is None:
            warnings.warn(
                'Using default fid model metric based on Inception V3. ' + \
                'This metric will only work on image data where values are in ' + \
                'the range [-1, 1], please specify another module if you want ' + \
                'to use other kinds of data formats.'
            )
            fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
            if device_ids:
                fid_model = torch.nn.DataParallel(fid_model, device_ids)
        self.fid_model = fid_model.eval().to(device)
        self.fid_size = fid_size

        dataset = _TruncatedDataset(dataset, self.num_samples)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=reals_batch_size or self.batch_size,
            num_workers=reals_data_workers
        )
        features = []
        # Label batches of the reals, in dataloader order; reused as
        # conditioning labels for the fakes in `evaluate`.
        self.labels = []

        # Only create a progress writer when verbose, and guard every use.
        progress = None
        if verbose:
            # len(dataloader) is the real number of batches, also when the
            # dataset holds fewer than `num_samples` items.
            progress = utils.ProgressWriter(len(dataloader))
            progress.write('FID: Gathering statistics for reals...', step=False)

        for batch in dataloader:
            data = batch
            if isinstance(batch, (tuple, list)):
                data = batch[0]
                if len(batch) > 1:
                    self.labels.append(batch[1])
            features.append(self._extract_features(data))
            if progress is not None:
                progress.step()

        if progress is not None:
            progress.write('FID: Statistics for reals gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        self.mu_real = np.mean(features, axis=0)
        self.sigma_real = np.cov(features, rowvar=False)
        self.truncation_psi = truncation_psi
        self.truncation_cutoff = truncation_cutoff

    def _scale_for_fid(self, data):
        """
        Resize `data` so that its smallest spatial side (dims 2 and up)
        equals `fid_size`, keeping the aspect ratio. Returns `data`
        unchanged when `fid_size` is not set or no resizing is needed.
        """
        if not self.fid_size:
            return data
        spatial = data.size()[2:]
        scale_factor = self.fid_size / min(spatial)
        if scale_factor == 1:
            return data
        # Average when shrinking, repeat pixels when enlarging.
        mode = 'area' if scale_factor < 1 else 'nearest'
        # An explicit output size avoids floor(side * scale) landing one
        # pixel short of `fid_size` because of floating point error.
        size = [max(1, int(round(side * scale_factor))) for side in spatial]
        return F.interpolate(data, size=size, mode=mode)

    def _extract_features(self, data):
        """
        Resize `data` for `fid_model`, move it to `self.device` and return
        its features averaged over all trailing dimensions, as a CPU
        tensor of shape (batch_size, features).
        """
        data = self._scale_for_fid(data).to(self.device)
        with torch.no_grad():
            batch_features = self.fid_model(data)
        batch_features = batch_features.reshape(
            *batch_features.size()[:2], -1).mean(-1)
        return batch_features.cpu()

    def __call__(self, *args, **kwargs):
        return self.evaluate(*args, **kwargs)

    def evaluate(self, verbose=True):
        """
        Evaluate the FID.

        When the prior generator yields labels and the reals had labels,
        the fakes are conditioned on the labels of the reals in order.
        The labels wrap around if there are fewer reals than
        `num_samples`.
        Arguments:
            verbose (bool): Write progress to stdout.
                Default value is True.
        Returns:
            fid (float): Metric value.
        """
        utils.unwrap_module(self.G).set_truncation(
            truncation_psi=self.truncation_psi,
            truncation_cutoff=self.truncation_cutoff)
        self.G.eval()

        # Flatten the label batches so they can be indexed per sample,
        # independently of the batch size used for the reals.
        real_labels = torch.cat(self.labels, dim=0) if self.labels else None

        features = []
        progress = None
        if verbose:
            progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
            progress.write('FID: Gathering statistics for fakes...', step=False)

        remaining = self.num_samples
        for _ in range(0, self.num_samples, self.batch_size):
            offset = self.num_samples - remaining
            latents, latent_labels = self.prior_generator(
                batch_size=min(self.batch_size, remaining))
            if latent_labels is not None and real_labels is not None:
                index = torch.arange(offset, offset + len(latents)) % len(real_labels)
                latent_labels = real_labels[index].to(self.device)
            with torch.no_grad():
                fakes = self.G(latents, labels=latent_labels)
            # Same resizing and pooling as the reals.
            features.append(self._extract_features(fakes))
            remaining -= len(latents)
            if progress is not None:
                progress.step()

        if progress is not None:
            progress.write('FID: Statistics for fakes gathered!', step=False)
            progress.close()

        features = torch.cat(features, dim=0).numpy()

        mu_fake = np.mean(features, axis=0)
        sigma_fake = np.cov(features, rowvar=False)

        return _frechet_distance(mu_fake, sigma_fake, self.mu_real, self.sigma_real)