|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluation producing ColTran FID-5K metric.""" |
|
|
|
import functools |
|
import os |
|
|
|
from absl import logging |
|
import einops |
|
import jax |
|
import numpy as np |
|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
import tensorflow_gan as tfgan |
|
import tensorflow_hub as tfhub |
|
|
|
from tensorflow.io import gfile |
|
|
|
|
|
# Directory that holds the expected file-name lists (`eval_file_names.txt` and
# `reference_file_names.txt`). Overridable through the FID_DATA_DIR environment
# variable; defaults to the current working directory.
ROOT = os.getenv("FID_DATA_DIR", ".")
|
|
|
|
|
def _preprocess(image, resolution=512):
  """Square-crop and resize an image the way the ColTran data pipeline does.

  See,
  github.com/google-research/google-research/blob/master/coltran/datasets.py#L44

  Args:
    image: ImageNet example from TFDS.
    resolution: Integer representing output size.

  Returns:
    An int32 image of size (resolution, resolution, 3).
  """
  dims = tf.shape(image)
  # The target side is the shorter image side, so this step only ever crops
  # (centrally) and never pads.
  short_side = tf.minimum(dims[0], dims[1])
  square = tf.image.resize_with_crop_or_pad(
      image, target_height=short_side, target_width=short_side)
  resized = tf.image.resize(
      square,
      size=(resolution, resolution),
      method="area",
      antialias=True)
  # Resizing yields floats; round before casting so values are quantized to
  # the nearest integer rather than truncated.
  return tf.cast(tf.round(resized), dtype=tf.int32)
|
|
|
|
|
def _normalize(x):
  """ColTran normalization to the input range expected by the Inception module.

  Args:
    x: Image with values in [0,255].

  Returns:
    Float32 image computed as `x / 128 - 1`, i.e. values in [-1,1] (an input
    of 255 maps to 127/128, so the upper bound is not reached exactly).
  """
  as_float = tf.cast(x, tf.float32)
  return as_float / 128.0 - 1.0
|
|
|
|
|
class Evaluator:
  """ColTran FID-5K Evaluator.

  This Evaluator aims to mirror the evaluation pipeline used by Kumar et al.
  in Colorization Transformer (https://arxiv.org/abs/2102.04432).

  To be clear: much of this code is direct snippets from ColTran code.

  See,
  github.com/google-research/google-research/blob/master/coltran/datasets.py#L44

  The ColTran pipeline has numerous stages, where serialized data is passed
  between binaries via file, etc... While we don't physically write the same
  files, we simulate the effects of the serialization (e.g., quantization).
  """

  # Number of images in the colorized pool and in the reference pool.
  _NUM_IMAGES = 5_000

  def __init__(self,
               predict_fn,
               batch_size,
               device_batch_size=5,
               coltran_seed=1,
               predict_kwargs=None):
    """Create Evaluator.

    Loads the Inception module, builds the colorization and reference input
    pipelines, and reads the expected file-name lists from `ROOT`.

    Args:
      predict_fn: Colorization prediction function. Expects grayscale images
        of size (512, 512, 3) in keys `image` and `image_ctx` with values in
        the range [-1,1]. Outputs `color` image in range [-1,1].
      batch_size: ignored.
      device_batch_size: number of images per batch, per device.
      coltran_seed: used to specify the block of 5_000 images used to generate
        the reference pool. Value of `1` matches default ColTran code.
      predict_kwargs: arguments passed to `predict_fn`.

    Raises:
      AssertionError: if 5_000 is not divisible by
        `device_batch_size * jax.local_device_count()`.
    """
    del batch_size

    self.num_devices = jax.local_device_count()
    self.device_batch_size = device_batch_size
    logging.log(logging.INFO, "Colorizing with batch size %i on %i devices.",
                self.device_batch_size, self.num_devices)
    # Every batch must be full: the (devices, per-device batch) layout fed to
    # the pmapped predict function has no room for a ragged final batch.
    global_batch_size = self.device_batch_size * self.num_devices
    assert self._NUM_IMAGES % global_batch_size == 0, (
        f"{self._NUM_IMAGES} must be divisible by device_batch_size * "
        f"num_devices = {global_batch_size}.")

    predict = functools.partial(predict_fn, **(predict_kwargs or {}))
    self.predict_fn = jax.pmap(predict)

    module = tfhub.load(tfgan.eval.INCEPTION_TFHUB)

    def _pools(x):
      activations = module(x)[tfgan.eval.INCEPTION_FINAL_POOL].numpy()
      return self._flatten_pools(activations)

    self.inception_pool = _pools

    # Colorization dataset: the first 5_000 validation images, in the order
    # TFDS returns them by default.
    def _eval_data_preprocess(example):
      # The model colorizes 512x512 inputs.
      image = _normalize(_preprocess(example["image"], resolution=512))
      grayscale = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
      return {
          "image": image,
          "grayscale": grayscale,
          "file_name": example["file_name"],
      }

    ds = tfds.load("imagenet2012", split="validation")
    ds = ds.take(self._NUM_IMAGES)
    ds = ds.map(_eval_data_preprocess)
    # Two-level batching gives a leading device axis: (devices, batch, ...).
    ds = ds.batch(self.device_batch_size)
    ds = ds.batch(self.num_devices)
    self.eval_data = ds.cache().prefetch(tf.data.AUTOTUNE)

    # Reference dataset: a different block of 5_000 validation images.
    def _reference_data_preprocess(example):
      # Inception pools for the reference are computed at 256x256, the same
      # size the colorized outputs are resized to in `run`.
      image = _normalize(_preprocess(example["image"], resolution=256))
      return {"image": image, "file_name": example["file_name"]}

    ds = tfds.load("imagenet2012", split="validation")
    # Select the block *before* mapping so that skipped examples are never
    # run through the preprocessing function. The selected elements are the
    # same as with map-then-skip.
    ds = ds.skip(self._NUM_IMAGES)  # Images used for colorization.
    ds = ds.skip(coltran_seed * self._NUM_IMAGES)
    ds = ds.take(self._NUM_IMAGES)
    ds = ds.map(_reference_data_preprocess)
    ds = ds.batch(device_batch_size)
    self.reference_data = ds.cache().prefetch(tf.data.AUTOTUNE)

    # Expected file names, used in `run` to verify that exactly the intended
    # images were consumed.
    self.eval_file_names = self._read_file_names("eval_file_names.txt")
    self.reference_file_names = self._read_file_names(
        "reference_file_names.txt")

  @staticmethod
  def _flatten_pools(pools):
    """Reshapes pool activations to 2D, always keeping the batch axis.

    A bare `np.squeeze` would also drop the batch axis for a batch of one
    image, which breaks the later concatenation along axis 0.

    Args:
      pools: array-like whose first axis is the batch axis.

    Returns:
      Array of shape (batch, features).
    """
    pools = np.asarray(pools)
    return pools.reshape(pools.shape[0], -1)

  @staticmethod
  def _read_file_names(name):
    """Reads file `name` under `ROOT` and returns its lines as a frozenset."""
    with gfile.GFile(os.path.join(ROOT, name)) as f:
      return frozenset(f.read().splitlines())

  @staticmethod
  def _check_file_names(seen, expected):
    """Raises ValueError unless `seen` and `expected` hold the same names."""
    if seen != expected:
      raise ValueError("unknown: {}\nmissing: {}".format(
          seen - expected, expected - seen))

  def run(self, params):
    """Run eval.

    Args:
      params: passed as the first argument to the pmapped `predict_fn`.
        Presumably already replicated across local devices - confirm with the
        caller.

    Yields:
      A single `("FID_5k", value)` pair on process 0. Other processes yield
      nothing.

    Raises:
      ValueError: if the set of processed file names differs from the expected
        lists, or if either pool does not contain exactly 5_000 activations.
    """
    if jax.process_index():
      return

    color_pools = []
    color_file_names = set()
    for i, batch in enumerate(self.eval_data.as_numpy_iterator()):
      predict_batch = {
          "labels": batch["image"],
          "image": batch["grayscale"],
          "image_ctx": batch["grayscale"],
      }
      y = self.predict_fn(params, predict_batch)["color"]
      # Merge the device axis into the batch axis.
      y = einops.rearrange(y, "d b h w c -> (d b) h w c")

      # Reference pools are computed on 256x256 images; match that size.
      y = tf.image.resize(y, (256, 256), "area")

      # Simulate serialization: quantize to integers in [0, 255], then map
      # back to the Inception input range.
      y = np.clip(np.round((y + 1.) * 128.), 0, 255)
      y = _normalize(y)

      color_pools.append(self.inception_pool(y))

      file_names = einops.rearrange(batch["file_name"], "d b -> (d b)")
      color_file_names.update(f.decode() for f in file_names)

      logging.log_every_n_seconds(
          logging.INFO,
          "ColTran FID eval: processed %i colorized examples so far.", 30,
          (i + 1) * self.device_batch_size * self.num_devices)

    reference_pools = []
    reference_file_names = set()
    for i, batch in enumerate(self.reference_data.as_numpy_iterator()):
      image = batch["image"]
      assert np.array_equal(image.shape, (self.device_batch_size, 256, 256, 3))
      reference_pools.append(self.inception_pool(image))
      reference_file_names.update(f.decode() for f in batch["file_name"])

      logging.log_every_n_seconds(
          logging.INFO,
          "ColTran FID eval: processed %i reference examples so far.", 30,
          (i + 1) * self.device_batch_size)

    self._check_file_names(color_file_names, self.eval_file_names)
    self._check_file_names(reference_file_names, self.reference_file_names)

    color = np.concatenate(color_pools, axis=0)
    reference = np.concatenate(reference_pools, axis=0)

    if color.shape[0] != self._NUM_IMAGES:
      raise ValueError(color.shape)

    if reference.shape[0] != self._NUM_IMAGES:
      raise ValueError(reference.shape)

    yield "FID_5k", tfgan.eval.frechet_classifier_distance_from_activations(
        color, reference)
|
|