#!/usr/bin/env python3 # # Copyright 2017 Martin Heusel # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Adapted from the original implementation by Martin Heusel. # Source https://github.com/bioinf-jku/TTUR/blob/master/fid.py ''' Calculates the Frechet Inception Distance (FID) to evalulate GANs. The FID metric calculates the distance between two distributions of images. Typically, we have summary statistics (mean & covariance matrix) of one of these distributions, while the 2nd distribution is given by a GAN. When run as a stand-alone program, it compares the distribution of images that are stored as PNG/JPEG at a specified location with a distribution given by summary statistics (in pickle format). The FID is calculated by assuming that X_1 and X_2 are the activations of the pool_3 layer of the inception net for generated samples and real world samples respectivly. See --help to see further details. ''' from __future__ import absolute_import, division, print_function import numpy as np import scipy as sp import os import gzip, pickle import tensorflow as tf from scipy.misc import imread import pathlib import urllib class InvalidFIDException(Exception): pass def create_inception_graph(pth): """Creates a graph from saved GraphDef file.""" # Creates graph from saved graph_def.pb. with tf.gfile.FastGFile( pth, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString( f.read()) _ = tf.import_graph_def( graph_def, name='FID_Inception_Net') #------------------------------------------------------------------------------- # code for handling inception net derived from # https://github.com/openai/improved-gan/blob/master/inception_score/model.py def _get_inception_layer(sess): """Prepares inception net for batched usage and returns pool_3 layer. """ layername = 'FID_Inception_Net/pool_3:0' pool3 = sess.graph.get_tensor_by_name(layername) ops = pool3.graph.get_operations() for op_idx, op in enumerate(ops): for o in op.outputs: shape = o.get_shape() if shape._dims is not None: shape = [s.value for s in shape] new_shape = [] for j, s in enumerate(shape): if s == 1 and j == 0: new_shape.append(None) else: new_shape.append(s) try: o._shape = tf.TensorShape(new_shape) except ValueError: o._shape_val = tf.TensorShape(new_shape) # EDIT: added for compatibility with tensorflow 1.6.0 return pool3 #------------------------------------------------------------------------------- def get_activations(images, sess, batch_size=50, verbose=False): """Calculates the activations of the pool_3 layer for all images. Params: -- images : Numpy array of dimension (n_images, hi, wi, 3). The values must lie between 0 and 256. -- sess : current session -- batch_size : the images numpy array is split into batches with batch size batch_size. A reasonable batch size depends on the disposable hardware. -- verbose : If set to True and parameter out_step is given, the number of calculated batches is reported. Returns: -- A numpy array of dimension (num images, 2048) that contains the activations of the given tensor when feeding inception with the query tensor. """ inception_layer = _get_inception_layer(sess) d0 = images.shape[0] if batch_size > d0: print("warning: batch size is bigger than the data size. setting batch size to data size") batch_size = d0 n_batches = d0//batch_size n_used_imgs = n_batches*batch_size pred_arr = np.empty((n_used_imgs,2048)) for i in range(n_batches): if verbose: print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) start = i*batch_size end = start + batch_size batch = images[start:end] pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch}) pred_arr[start:end] = pred.reshape(batch_size,-1) if verbose: print(" done") return pred_arr #------------------------------------------------------------------------------- def calculate_frechet_distance(mu1, sigma1, mu2, sigma2): """Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). Params: -- mu1 : Numpy array containing the activations of the pool_3 layer of the inception net ( like returned by the function 'get_predictions') -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted on an representive data set. -- sigma2: The covariance matrix over activations of the pool_3 layer, precalcualted on an representive data set. Returns: -- dist : The Frechet Distance. Raises: -- InvalidFIDException if nan occures. """ m = np.square(mu1 - mu2).sum() #s = sp.linalg.sqrtm(np.dot(sigma1, sigma2)) # EDIT: commented out s, _ = sp.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False) # EDIT: added dist = m + np.trace(sigma1+sigma2 - 2*s) #if np.isnan(dist): # EDIT: commented out # raise InvalidFIDException("nan occured in distance calculation.") # EDIT: commented out #return dist # EDIT: commented out return np.real(dist) # EDIT: added #------------------------------------------------------------------------------- def calculate_activation_statistics(images, sess, batch_size=50, verbose=False): """Calculation of the statistics used by the FID. Params: -- images : Numpy array of dimension (n_images, hi, wi, 3). The values must lie between 0 and 255. -- sess : current session -- batch_size : the images numpy array is split into batches with batch size batch_size. A reasonable batch size depends on the available hardware. -- verbose : If set to True and parameter out_step is given, the number of calculated batches is reported. Returns: -- mu : The mean over samples of the activations of the pool_3 layer of the incption model. -- sigma : The covariance matrix of the activations of the pool_3 layer of the incption model. """ act = get_activations(images, sess, batch_size, verbose) mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) return mu, sigma #------------------------------------------------------------------------------- #------------------------------------------------------------------------------- # The following functions aren't needed for calculating the FID # they're just here to make this module work as a stand-alone script # for calculating FID scores #------------------------------------------------------------------------------- def check_or_download_inception(inception_path): ''' Checks if the path to the inception file is valid, or downloads the file if it is not present. ''' INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' if inception_path is None: inception_path = '/tmp' inception_path = pathlib.Path(inception_path) model_file = inception_path / 'classify_image_graph_def.pb' if not model_file.exists(): print("Downloading Inception model") from urllib import request import tarfile fn, _ = request.urlretrieve(INCEPTION_URL) with tarfile.open(fn, mode='r') as f: f.extract('classify_image_graph_def.pb', str(model_file.parent)) return str(model_file) def _handle_path(path, sess): if path.endswith('.npz'): f = np.load(path) m, s = f['mu'][:], f['sigma'][:] f.close() else: path = pathlib.Path(path) files = list(path.glob('*.jpg')) + list(path.glob('*.png')) x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) m, s = calculate_activation_statistics(x, sess) return m, s def calculate_fid_given_paths(paths, inception_path): ''' Calculates the FID of two paths. ''' inception_path = check_or_download_inception(inception_path) for p in paths: if not os.path.exists(p): raise RuntimeError("Invalid path: %s" % p) create_inception_graph(str(inception_path)) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) m1, s1 = _handle_path(paths[0], sess) m2, s2 = _handle_path(paths[1], sess) fid_value = calculate_frechet_distance(m1, s1, m2, s2) return fid_value if __name__ == "__main__": from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) parser.add_argument("path", type=str, nargs=2, help='Path to the generated images or to .npz statistic files') parser.add_argument("-i", "--inception", type=str, default=None, help='Path to Inception model (will be downloaded if not provided)') parser.add_argument("--gpu", default="", type=str, help='GPU to use (leave blank for CPU only)') args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu fid_value = calculate_fid_given_paths(args.path, args.inception) print("FID: ", fid_value) #---------------------------------------------------------------------------- # EDIT: added class API: def __init__(self, num_images, image_shape, image_dtype, minibatch_size): import config self.network_dir = os.path.join(config.result_dir, '_inception_fid') self.network_file = check_or_download_inception(self.network_dir) self.sess = tf.get_default_session() create_inception_graph(self.network_file) def get_metric_names(self): return ['FID'] def get_metric_formatting(self): return ['%-10.4f'] def begin(self, mode): assert mode in ['warmup', 'reals', 'fakes'] self.activations = [] def feed(self, mode, minibatch): act = get_activations(minibatch.transpose(0,2,3,1), self.sess, batch_size=minibatch.shape[0]) self.activations.append(act) def end(self, mode): act = np.concatenate(self.activations) mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) if mode in ['warmup', 'reals']: self.mu_real = mu self.sigma_real = sigma fid = calculate_frechet_distance(mu, sigma, self.mu_real, self.sigma_real) return [fid] #----------------------------------------------------------------------------