"""Multi-host image->text and text->image retrieval evaluation. |
|
|
|
Example of how to add this evaluator to a config:
|
|
|
config.evals = {}
config.evals.retrieval = dict(log_steps=1200, type='proj.image_text.retrieval')
|
config.evals.retrieval.dataset = 'coco_captions' |
|
config.evals.retrieval.txt_name = ('captions', 'text') |
|
# Note that initial "decode|" is not needed. |
|
config.evals.retrieval.pp_img = 'resize(224)|value_range(-1,1)' |
|
# Raw text strings use key "texts" in feature dict. The evaluator expects |
|
# tokenized text with key "labels". |
|
config.evals.retrieval.pp_txt = ( |
|
'tokenize(max_len=16, eos="sticky", pad_value=1, inkey="texts", ' |
|
' outkey="labels")') |
|
|
|
Example to support precomputed data: |
|
See `big_vision/configs/proj/image_text/lit.py`. |
|
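Example of direct (non-config) use, a minimal sketch assuming a compatible
`predict_fn` and `train_state` (both are placeholders, not defined in this
module):

  evaluator = Evaluator(
      predict_fn,
      dataset='coco_captions',
      txt_name=('captions', 'text'),
      pp_img='resize(224)|value_range(-1,1)',
      pp_txt=('tokenize(max_len=16, eos="sticky", pad_value=1, '
              'inkey="texts", outkey="labels")'),
      batch_size=512,
      devices=jax.devices())
  metrics = dict(evaluator.run(train_state))
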
""" |
|
|
|
import functools |
|
import operator |
|
import time |
|
|
|
from absl import logging |
|
from big_vision import input_pipeline |
|
from big_vision.evaluators.proj.image_text import image_text_retrieval |
|
import big_vision.pp.builder as pp_builder |
|
import jax |
|
import jax.numpy as jnp |
|
from jax.sharding import NamedSharding |
|
from jax.sharding import PartitionSpec as P |
|
import numpy as np |
|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
|
|
|
|
|
|
|
|
API = "jit" |
|
|
|
|
|
def _with_infinite_padding(dataset): |
|
"""Adds "infinite padding" to the dataset.""" |
|
filler_element = tf.nest.map_structure( |
|
lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec) |
|
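  # Filler elements are all zeros with mask=False; real examples get
  # mask=True via the map below.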
filler_element["mask"] = [False] |
|
filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element) |
|
dataset = dataset.map( |
|
lambda features: dict(mask=True, **features), |
|
num_parallel_calls=tf.data.experimental.AUTOTUNE) |
|
return dataset.concatenate(filler_dataset.repeat(None)) |
|
|
|
|
|
|
|
def _get_dataset_info(builder): |
|
return builder.info |
|
|
|
|
|
def prepare_datasets( |
|
dataset, *, pp_img, pp_txt, txt_name, offset=0, cache_final=False |
|
): |
|
"""Returns unbatched `ds_images, ds_texts` datasets. |
|
|
|
Args: |
|
dataset: An image-text `tf.data.Dataset` that is expected to contain the |
|
following features: "image" (dtype=uint8, shape=[None, None, 3]), |
|
`txt_name` (dtype=string, shape=[None]). |
|
pp_img: String defining pre-processing for images. The pre-processing can |
|
expect the following features to be prepared: "image", "id". The |
|
pre-processing should convert the "image" (dtype=uint8, |
|
shape=[None, None, 3]) to "image" (dtype=float32, shape=[sz, sz, 3]). |
|
pp_txt: String defining pre-processing for text. The pre-processing can |
|
      expect the following features to be prepared: "texts", "id", "caption_i".
|
The pre-processing should convert the "texts" (dtype=string, shape=[]) |
|
into a tokenized "labels" (dtype=int32, shape=[max_len]). |
|
txt_name: Name of the text feature to unroll in the original `dataset`. Can |
|
be a simple string feature name, or an iterable of strings to specify a |
|
nested feature (e.g. for "coco_captions", this would be |
|
`('captions', 'text')`). |
|
offset: Offset that should be added to enumerated examples to generate IDs. |
|
In a multi-host setup, this is typically set to a value large enough to |
|
make all IDs distinct. |
|
cache_final: Whether the dataset should be cached. |
|
|
|
Returns: |
|
Image and text datasets. |
|
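  Example (a minimal sketch; assumes the TFDS "coco_captions" dataset is
  available locally):

    ds = tfds.builder('coco_captions').as_dataset(split='val')
    ds_images, ds_texts = prepare_datasets(
        ds,
        pp_img='resize(224)|value_range(-1,1)',
        pp_txt=('tokenize(max_len=16, eos="sticky", pad_value=1, '
                'inkey="texts", outkey="labels")'),
        txt_name=('captions', 'text'))
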
""" |
|
|
|
def get_feature_value(data, feature_name): |
|
if isinstance(feature_name, str): |
|
feature_name = [feature_name] |
|
return functools.reduce(operator.getitem, feature_name, data) |
|
|
|
def get_captions(idx, features): |
|
"""Returns a dataset with unrolled "caption" for every example.""" |
|
texts = get_feature_value(features, txt_name) |
|
texts = tf.experimental.numpy.atleast_1d(texts) |
|
texts_n = tf.shape(texts)[0] |
|
return tf.data.Dataset.from_tensor_slices({ |
|
"id": tf.tile([idx + offset], [texts_n]), |
|
"caption_i": tf.stack(tf.range(texts_n)), |
|
"texts": tf.stack(texts), |
|
}) |
|
|
|
def add_id(idx, features): |
|
return {**features, "id": idx + offset} |
|
|
|
ds_images = dataset.enumerate().map(add_id).map( |
|
pp_builder.get_preprocess_fn(f"{pp_img}|keep('id', 'image')")) |
|
ds_texts = dataset.enumerate().flat_map(get_captions).map( |
|
pp_builder.get_preprocess_fn( |
|
f"{pp_txt}|keep('id', 'caption_i', 'labels')")) |
|
if cache_final: |
|
ds_images, ds_texts = ds_images.cache(), ds_texts.cache() |
|
return ds_images, ds_texts |
|
|
|
|
|
def _split_and_batch(dataset_name, batch_size, split, get_ds, data_dir=None): |
|
"""Splits dataset, calls `get_ds` and returns padded + batched datasets.""" |
|
  assert not batch_size % jax.device_count(), (
      f"batch_size={batch_size} must be divisible by "
      f"jax.device_count()={jax.device_count()}")
|
builder = tfds.builder(dataset_name, data_dir=data_dir) |
|
info = _get_dataset_info(builder) |
|
num_examples = info.splits[split].num_examples |
|
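  # Each process reads a disjoint shard of `split`; offsetting example ids by
  # `process_index * num_examples` keeps ids globally unique across hosts.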
ds_images, ds_texts = get_ds( |
|
builder.as_dataset(split=tfds.split_for_jax_process(split)), |
|
offset=jax.process_index() * num_examples, |
|
) |
|
return ( |
|
_with_infinite_padding(ds_images).batch(batch_size), |
|
_with_infinite_padding(ds_texts).batch(batch_size), |
|
) |
|
|
|
|
|
class Evaluator: |
|
"""Image/text retrieval evaluator.""" |
|
|
|
def __init__(self, |
|
predict_fn, |
|
*, |
|
dataset, |
|
pp_img, |
|
pp_txt, |
|
txt_name, |
|
batch_size, |
|
devices, |
|
data_dir=None, |
|
split="test", |
|
cache_final=True): |
|
"""Initializes a new zero-shot image/text retrieval evaluator. |
|
|
|
See `prepare_datasets()` for details on how the dataset is pre-processed. |
|
|
|
Args: |
|
      predict_fn: Prediction function with signature
        `zimg, ztxt, out = predict_fn(train_state, batch)`, where `batch` is
        a dict with an "image" and/or "labels" key.
|
dataset: The TFDS dataset name of the eval data. |
|
      pp_img: Preprocessing string for images. Preprocessed features should
        contain key "image" with a value that can be batched and is suitable
        as the "image" input of `predict_fn`.
|
      pp_txt: Preprocessing string for texts. Can expect the "texts" key as
        input (shape=[], dtype=string) and is expected to produce a "labels"
        key that is suitable as the "labels" input of `predict_fn`.
|
      txt_name: The name of the captions feature (expected shape=[None],
        dtype=string). Can be a tuple, in which case its items are used as a
        lookup path into a nested feature dictionary.
|
batch_size: Global batch size. |
|
      devices: List of devices to run the evaluation on.
|
data_dir: Optional dir to load the TFDS dataset from. |
|
split: The split of the eval data. |
|
      cache_final: Whether the preprocessed dataset should be cached.
|
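    Example `predict_fn` (a minimal sketch; `model` is a placeholder for a
    two-tower image-text model and is not part of this evaluator):

      def predict_fn(train_state, batch):
        zimg, ztxt, out = model.apply({'params': train_state['params']},
                                      batch.get('image'), batch.get('labels'))
        return zimg, ztxt, out
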
""" |
|
self.ds_images, self.ds_texts = _split_and_batch( |
|
dataset, |
|
batch_size, |
|
split, |
|
functools.partial( |
|
prepare_datasets, |
|
pp_img=pp_img, |
|
pp_txt=pp_txt, |
|
txt_name=txt_name, |
|
cache_final=cache_final, |
|
), |
|
data_dir=data_dir, |
|
) |
|
self._axis_name = "batch" |
|
|
|
self.devices = devices |
|
mesh = jax.sharding.Mesh(devices, ("devices",)) |
|
|
|
def embed_images(train_state, images): |
|
zimg, _, _ = predict_fn(train_state, {"image": images}) |
|
return zimg |
|
|
|
def embed_texts(train_state, texts): |
|
_, ztxt, _ = predict_fn(train_state, {"labels": texts}) |
|
return ztxt |
|
|
|
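    # jit-compile the embedding functions with fully-replicated outputs
    # (PartitionSpec()) so every host holds the complete output arrays.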
self._embed_images_p = jax.jit(embed_images, |
|
out_shardings=NamedSharding(mesh, P())) |
|
self._embed_texts_p = jax.jit(embed_texts, |
|
out_shardings=NamedSharding(mesh, P())) |
|
self._all_gather_p = jax.jit( |
|
lambda x: x, out_shardings=NamedSharding(mesh, P())) |
|
self._count_p = jax.jit(jnp.sum, out_shardings=NamedSharding(mesh, P())) |
|
self._compiled = set() |
|
|
|
def _embed(self, name, train_state, ds, embed_fn, id_names): |
|
"""Embeds features name `name` using `embed_fn`. |
|
|
|
Args: |
|
name: Feature name to be embedded. |
|
train_state: train_state for the predict_fn. |
|
ds: The dataset. |
|
      embed_fn: A jitted function that returns the embeddings.
|
id_names: An iterable of feature names that should be collected. |
|
|
|
Returns: |
|
A dictionary with "embeddings" and `id_names` as keys. |
|
""" |
|
ns = [] |
|
embeddings = [] |
|
ids = {id_name: [] for id_name in list(id_names) + ["mask"]} |
|
|
|
t0 = time.time() |
|
|
|
ds_b = input_pipeline.start_global(ds, self.devices) |
|
for batch in ds_b: |
|
ns.append(jax.device_get(self._count_p(batch["mask"]))) |
|
|
|
|
|
|
|
|
|
|
|
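      # The padded dataset is infinite, so stop once the previous batch
      # contained only padding (its mask summed to zero); padded entries are
      # filtered out via the collected masks at the end.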
if len(ns) >= 2 and ns[-2] == 0: |
|
break |
|
|
|
embs = embed_fn(train_state, batch[name]) |
|
if embed_fn not in self._compiled: |
|
logging.info("Compiled %s embeddings in %.3fs", name, time.time() - t0) |
|
t0 = time.time() |
|
self._compiled.add(embed_fn) |
|
|
|
embeddings.append(jax.device_get(embs)) |
|
for id_name in ids: |
|
ids[id_name].append(jax.device_get(self._all_gather_p(batch[id_name]))) |
|
|
|
|
|
ns = np.array(ns) |
|
embeddings = np.concatenate(embeddings) |
|
ids = {k: np.concatenate(v) for k, v in ids.items()} |
|
masks = ids.pop("mask").astype(bool) |
|
logging.info("Processed %s in %d steps - ...%s", name, len(ns), ns[-10:]) |
|
n = ns.sum() |
|
logging.info("Totalling %d %s in %.3fs", n, name, time.time() - t0) |
|
return { |
|
"embeddings": embeddings[masks], |
|
**{k: v[masks] for k, v in ids.items()}, |
|
} |
|
|
|
def evaluate(self, train_state): |
|
"""Returns evaluation results.""" |
|
images = self._embed("image", train_state, self.ds_images, |
|
self._embed_images_p, ("id",)) |
|
texts = self._embed("labels", train_state, self.ds_texts, |
|
self._embed_texts_p, ("id", "caption_i")) |
|
|
|
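    # Pairwise scores between all image and text embeddings; assuming
    # `predict_fn` returns L2-normalized embeddings, these are cosine
    # similarities.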
similarities = np.dot(images["embeddings"], texts["embeddings"].T) |
|
|
|
t0 = time.time() |
|
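    # Map each text embedding to the row index of its ground-truth image. The
    # retrieval helpers rank by distance, hence the negated similarities.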
id2img = {id_: i for i, id_ in enumerate(images["id"])} |
|
text_image_correspondence = [id2img[id_] for id_ in texts["id"]] |
|
img2txt = image_text_retrieval.image_to_text_retrieval_eval( |
|
-similarities, text_image_correspondence) |
|
txt2img = image_text_retrieval.text_to_image_retrieval_eval( |
|
-similarities, text_image_correspondence) |
|
logging.info("Computed retrieval metrics in %.3fs", time.time() - t0) |
|
|
|
return dict( |
|
images=images, |
|
texts=texts, |
|
img2txt=img2txt, |
|
txt2img=txt2img, |
|
) |
|
|
|
def run(self, train_state): |
|
"""Returns metrics.""" |
|
results = self.evaluate(train_state) |
|
return [(f"{direction}_{k.lower()}", v) |
|
for direction in ("img2txt", "txt2img") |
|
for k, v in results[direction].items()] |
|
|