pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-host image->text and text->image retrieval evaluation.
Example how to add to config:
config.evals = {}
config.evals.retrieval = dict(log_steps=1200, type='proj.image_text.retrieval')
config.evals.retrieval.dataset = 'coco_captions'
config.evals.retrieval.txt_name = ('captions', 'text')
# Note that initial "decode|" is not needed.
config.evals.retrieval.pp_img = 'resize(224)|value_range(-1,1)'
# Raw text strings use key "texts" in feature dict. The evaluator expects
# tokenized text with key "labels".
config.evals.retrieval.pp_txt = (
'tokenize(max_len=16, eos="sticky", pad_value=1, inkey="texts", '
' outkey="labels")')
Example to support precomputed data:
See `big_vision/configs/proj/image_text/lit.py`.
"""
import functools
import operator
import time
from absl import logging
from big_vision import input_pipeline
from big_vision.evaluators.proj.image_text import image_text_retrieval
import big_vision.pp.builder as pp_builder
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
# NOTE(review): presumably read by the training loop to decide how this
# evaluator is driven (value "jit"); the consumer is not visible here — confirm.
API = "jit"
def _with_infinite_padding(dataset):
  """Appends an endless stream of all-zero, masked-out examples to `dataset`.

  Every real example gets an extra boolean "mask" feature set to True. It is
  followed by filler examples with the same structure, zero-valued features
  and "mask" set to False. Batching the result therefore never produces a
  short final batch, and consumers can drop padding rows using "mask".

  Args:
    dataset: A `tf.data.Dataset` whose elements are dicts of features. The
      feature shapes should be fully defined so the zero filler can be built.

  Returns:
    A never-ending dataset: the masked original followed by the filler.
  """
  def _zero_row(spec):
    # Leading axis of size 1 so `from_tensor_slices` yields exactly one element.
    return tf.zeros(spec.shape, spec.dtype)[None]

  padding = tf.nest.map_structure(_zero_row, dataset.element_spec)
  padding["mask"] = [False]
  padding_ds = tf.data.Dataset.from_tensor_slices(padding)

  def _mark_real(features):
    return dict(mask=True, **features)

  masked = dataset.map(
      _mark_real, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return masked.concatenate(padding_ds.repeat(None))
# This is needed so retrieval_test can replace dataset info.
def _get_dataset_info(builder):
return builder.info
def prepare_datasets(
    dataset, *, pp_img, pp_txt, txt_name, offset=0, cache_final=False
):
  """Returns unbatched `ds_images, ds_texts` datasets.

  Args:
    dataset: An image-text `tf.data.Dataset` that is expected to contain the
      following features: "image" (dtype=uint8, shape=[None, None, 3]),
      `txt_name` (dtype=string, shape=[None]).
    pp_img: String defining pre-processing for images. The pre-processing can
      expect the following features to be prepared: "image", "id". The
      pre-processing should convert the "image" (dtype=uint8,
      shape=[None, None, 3]) to "image" (dtype=float32, shape=[sz, sz, 3]).
    pp_txt: String defining pre-processing for text. The pre-processing can
      expect the following features to be prepared: "texts", "id",
      "caption_i". The pre-processing should convert the "texts"
      (dtype=string, shape=[]) into a tokenized "labels" (dtype=int32,
      shape=[max_len]).
    txt_name: Name of the text feature to unroll in the original `dataset`. Can
      be a simple string feature name, or an iterable of strings to specify a
      nested feature (e.g. for "coco_captions", this would be
      `('captions', 'text')`).
    offset: Offset that should be added to enumerated examples to generate IDs.
      In a multi-host setup, this is typically set to a value large enough to
      make all IDs distinct.
    cache_final: Whether the dataset should be cached.

  Returns:
    A tuple `(ds_images, ds_texts)`:
      * `ds_images` has one element per example, keeping "id" and "image".
      * `ds_texts` has one element per caption, keeping "id" (the ID of the
        example the caption belongs to), "caption_i" (index of the caption
        within its example) and "labels".
  """

  def get_feature_value(data, feature_name):
    """Looks up a feature by name, or by a path of names for nested dicts."""
    if isinstance(feature_name, str):
      feature_name = [feature_name]
    return functools.reduce(operator.getitem, feature_name, data)

  def get_captions(idx, features):
    """Returns a dataset with one element per caption of a single example."""
    texts = get_feature_value(features, txt_name)
    texts = tf.experimental.numpy.atleast_1d(texts)  # For single-text GT.
    texts_n = tf.shape(texts)[0]
    return tf.data.Dataset.from_tensor_slices({
        # Every caption carries the ID of its example; this is what links
        # texts back to images during evaluation.
        "id": tf.tile([idx + offset], [texts_n]),
        "caption_i": tf.range(texts_n),
        "texts": texts,
    })

  def add_id(idx, features):
    return {**features, "id": idx + offset}

  ds_images = dataset.enumerate().map(add_id).map(
      pp_builder.get_preprocess_fn(f"{pp_img}|keep('id', 'image')"))
  ds_texts = dataset.enumerate().flat_map(get_captions).map(
      pp_builder.get_preprocess_fn(
          f"{pp_txt}|keep('id', 'caption_i', 'labels')"))
  if cache_final:
    ds_images, ds_texts = ds_images.cache(), ds_texts.cache()
  return ds_images, ds_texts
def _split_and_batch(dataset_name, batch_size, split, get_ds, data_dir=None):
  """Loads this process's shard of a TFDS split, then pads and batches it.

  Args:
    dataset_name: TFDS name of the dataset to load.
    batch_size: Global batch size; must be divisible by `jax.device_count()`.
    split: Name of the TFDS split to load.
    get_ds: Callable invoked as `get_ds(dataset, offset=...)` that returns a
      pair of datasets `(ds_images, ds_texts)`.
    data_dir: Optional directory to load the TFDS dataset from.

  Returns:
    `(ds_images, ds_texts)`, each infinitely padded and then batched.
  """
  n_dev = jax.device_count()
  assert batch_size % n_dev == 0, (
      f"batch_size={batch_size} % jax.device_count()={n_dev}")

  builder = tfds.builder(dataset_name, data_dir=data_dir)
  # IDs are offset by process_index * (size of the whole split). No process's
  # shard can exceed that size, so IDs stay unique across processes.
  split_size = _get_dataset_info(builder).splits[split].num_examples
  local_split = tfds.split_for_jax_process(split)
  ds_images, ds_texts = get_ds(
      builder.as_dataset(split=local_split),
      offset=jax.process_index() * split_size,
  )
  return tuple(
      _with_infinite_padding(ds).batch(batch_size)
      for ds in (ds_images, ds_texts))
class Evaluator:
  """Image/text retrieval evaluator."""

  def __init__(self,
               predict_fn,
               *,
               dataset,
               pp_img,
               pp_txt,
               txt_name,
               batch_size,
               devices,
               data_dir=None,
               split="test",
               cache_final=True):
    """Initializes a new zero-shot image/text retrieval evaluator.

    See `prepare_datasets()` for details on how the dataset is pre-processed.

    Args:
      predict_fn: Prediction function with signature
        `zimg, ztxt, out = predict_fn(train_state, batch)`. It is called with
        `batch={"image": images}` (only `zimg` is used) and with
        `batch={"labels": texts}` (only `ztxt` is used).
      dataset: The TFDS dataset name of the eval data.
      pp_img: Preprocessing string for images. Preprocessed features should
        contain key "image" with a value that can be batched and is suitable
        as the "image" input of `predict_fn`.
      pp_txt: Preprocessing string for texts. Can expect "texts" key as an
        input (shape=[], dtype=string), and is expected to produce "labels"
        key that is suitable as the "labels" input of `predict_fn`.
      txt_name: The name of the feature of captions. Expected shape=[None],
        dtype=string. If a tuple is specified, its items are used as a lookup
        path into a nested feature dictionary.
      batch_size: Global batch size; must be divisible by
        `jax.device_count()`.
      devices: List of devices used to build the evaluation mesh.
      data_dir: Optional dir to load the TFDS dataset from.
      split: The split of the eval data.
      cache_final: Whether the preprocessed dataset should be cached.
    """
    self.ds_images, self.ds_texts = _split_and_batch(
        dataset,
        batch_size,
        split,
        functools.partial(
            prepare_datasets,
            pp_img=pp_img,
            pp_txt=pp_txt,
            txt_name=txt_name,
            cache_final=cache_final,
        ),
        data_dir=data_dir,
    )
    self._axis_name = "batch"
    self.devices = devices
    mesh = jax.sharding.Mesh(devices, ("devices",))
    # Empty PartitionSpec => outputs are fully replicated, so each process can
    # fetch complete results with `jax.device_get`.
    replicated = NamedSharding(mesh, P())

    def embed_images(train_state, images):
      zimg, _, _ = predict_fn(train_state, {"image": images})
      return zimg

    def embed_texts(train_state, texts):
      _, ztxt, _ = predict_fn(train_state, {"labels": texts})
      return ztxt

    self._embed_images_p = jax.jit(embed_images, out_shardings=replicated)
    self._embed_texts_p = jax.jit(embed_texts, out_shardings=replicated)
    self._all_gather_p = jax.jit(lambda x: x, out_shardings=replicated)
    self._count_p = jax.jit(jnp.sum, out_shardings=replicated)
    # Embedding functions that have been called at least once; used only to
    # log compilation time on the first call.
    self._compiled = set()

  def _embed(self, name, train_state, ds, embed_fn, id_names):
    """Embeds features named `name` using `embed_fn`.

    Args:
      name: Feature name to be embedded.
      train_state: train_state for the predict_fn.
      ds: The batched, infinitely padded dataset.
      embed_fn: A jitted function `embed_fn(train_state, x)` that returns
        replicated embeddings.
      id_names: An iterable of feature names that should be collected.

    Returns:
      A dictionary with "embeddings" and `id_names` as keys, with padding
      rows (mask == False) removed.
    """
    ns = []
    embeddings = []
    ids = {id_name: [] for id_name in list(id_names) + ["mask"]}

    t0 = time.time()

    ds_b = input_pipeline.start_global(ds, self.devices)
    for batch in ds_b:
      # Number of real (non-padding) examples in this batch.
      # NOTE(review): presumably a global count across processes, since
      # `batch` comes from `start_global`; that keeps every process breaking
      # on the same step — confirm.
      ns.append(jax.device_get(self._count_p(batch["mask"])))

      # Due to infinite padding, this loop would never end on its own. Stop
      # once the *previous* batch contained only padding. The first
      # all-padding batch is still embedded, but its rows are dropped by the
      # mask below. (`jax.device_get` already makes `ns[-1]` a host value.)
      if len(ns) >= 2 and ns[-2] == 0:
        break

      embs = embed_fn(train_state, batch[name])
      if embed_fn not in self._compiled:
        logging.info("Compiled %s embeddings in %.3fs", name, time.time() - t0)
        t0 = time.time()
        self._compiled.add(embed_fn)

      embeddings.append(jax.device_get(embs))
      for id_name in ids:
        ids[id_name].append(jax.device_get(self._all_gather_p(batch[id_name])))

    ns = np.array(ns)
    embeddings = np.concatenate(embeddings)
    ids = {k: np.concatenate(v) for k, v in ids.items()}
    masks = ids.pop("mask").astype(bool)
    logging.info("Processed %s in %d steps - ...%s", name, len(ns), ns[-10:])
    n = ns.sum()
    logging.info("Totalling %d %s in %.3fs", n, name, time.time() - t0)
    return {
        "embeddings": embeddings[masks],
        **{k: v[masks] for k, v in ids.items()},
    }

  def evaluate(self, train_state):
    """Returns evaluation results.

    Returns:
      A dict with the image and text embedding dicts ("images", "texts") and
      the metric dicts returned by `image_text_retrieval` ("img2txt",
      "txt2img").
    """
    images = self._embed("image", train_state, self.ds_images,
                         self._embed_images_p, ("id",))
    texts = self._embed("labels", train_state, self.ds_texts,
                        self._embed_texts_p, ("id", "caption_i"))
    # Shapes: (nimg, emb) * (emb, ntxt) -> (nimg, ntxt)
    similarities = np.dot(images["embeddings"], texts["embeddings"].T)

    t0 = time.time()
    # For every text, the row index (into `images`) of its ground-truth image.
    id2img = {id_: i for i, id_ in enumerate(images["id"])}
    text_image_correspondence = [id2img[id_] for id_ in texts["id"]]
    # NOTE(review): similarities are negated, presumably because the helpers
    # rank by ascending distance — confirm against image_text_retrieval.
    img2txt = image_text_retrieval.image_to_text_retrieval_eval(
        -similarities, text_image_correspondence)
    txt2img = image_text_retrieval.text_to_image_retrieval_eval(
        -similarities, text_image_correspondence)
    logging.info("Computed retrieval metrics in %.3fs", time.time() - t0)

    return dict(
        images=images,
        texts=texts,
        img2txt=img2txt,
        txt2img=txt2img,
    )

  def run(self, train_state):
    """Returns metrics as a list of `(name, value)` pairs.

    Names have the form `"{direction}_{metric}"`, where direction is
    "img2txt" or "txt2img" and the metric key is lower-cased.
    """
    results = self.evaluate(train_state)
    return [(f"{direction}_{k.lower()}", v)
            for direction in ("img2txt", "txt2img")
            for k, v in results[direction].items()]