"""Discriminative zero-shot classification evaluator. |
|
""" |
|
import functools
import time

from absl import logging
from big_vision import input_pipeline
from big_vision import utils
from big_vision.evaluators.proj.image_text import prompt_engineering
# The two imports below are needed for their side effect of registering the
# preprocessing ops referenced by the default `pp_img`/`pp_txt` strings.
from big_vision.pp import ops_general  # pylint: disable=unused-import
from big_vision.pp import ops_image  # pylint: disable=unused-import
import big_vision.pp.builder as pp_builder
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
|
|
# Temporary global flag to facilitate backwards compatibility.
API = "jit"
|
|
DATASET_NAMES = ("imagenet2012", "cifar100", "oxford_iiit_pet")
DEFAULT_OVERRIDES = (
    ("imagenet2012", (
        ("class_names", "clip"),
        ("split", "validation"),
    )),
)
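# A minimal sketch of how per-dataset overrides compose with the defaults
# above (the dataset and values here are illustrative only). Both
# tuple-of-tuples (as above) and plain dicts work:
#
#   Evaluator(
#       predict_fn,
#       batch_size=1024,
#       devices=jax.devices(),
#       dataset_names=("cifar100",),
#       dataset_overrides={"cifar100": {"split": "test[:50%]"}},
#   )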
|
|
def _with_infinite_padding(dataset):
  """Pads `dataset` with an infinite stream of zero-filled filler elements.

  Real examples are marked with `mask=True`; the filler elements appended
  after them have `mask=False`, so every host can keep stepping in lockstep
  after its shard of the data is exhausted.
  """
  filler_element = tf.nest.map_structure(
      lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
  filler_element["mask"] = [False]
  filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
  dataset = dataset.map(
      lambda features: dict(mask=True, **features),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.concatenate(filler_dataset.repeat(None))
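# Schematically, a dataset of N real examples becomes the infinite stream
#
#   {**example_0, "mask": True}, ..., {**example_{N-1}, "mask": True},
#   {zeros, "mask": False}, {zeros, "mask": False}, ...
#
# which downstream code consumes until an entire batch is padding.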
|
|
def _get_dataset_info(builder):
  return builder.info
|
|
def prepare_datasets(img_dataset,
                     class_names,
                     *,
                     prompt_templates,
                     pp_img,
                     pp_txt,
                     cache_final=False,
                     pre_filter_fn=None,
                     class_name_offset=0):
  """Returns unbatched `ds_images, ds_texts` datasets."""

  assert prompt_templates, "Must specify prompt templates (e.g. simply ['{}'])"

  def expand_aliases(idx, class_name):
    # A class name can contain multiple ","-separated aliases, all of which
    # map to the same label.
    aliases = tf.strings.split(class_name, ",")
    return tf.data.Dataset.from_tensor_slices((
        tf.repeat([idx + class_name_offset], len(aliases), axis=0),
        aliases,
    ))

  def add_prompts(idx, class_name):
    return tf.data.Dataset.from_tensor_slices({
        "label": tf.repeat([idx], len(prompt_templates), axis=0),
        "class_name": tf.repeat([class_name], len(prompt_templates), axis=0),
        "prompt_template": prompt_templates,
    })

  def substitute_prompt(features):
    parts = tf.strings.split(features["prompt_template"], "{}")
    tf.debugging.assert_equal(len(parts), 2, features["prompt_template"])
    return {
        "label": features["label"],
        "texts": tf.strings.join([parts[0], features["class_name"], parts[1]])
    }

  if pre_filter_fn:
    img_dataset = img_dataset.filter(pre_filter_fn)
  ds_images = img_dataset.map(
      pp_builder.get_preprocess_fn(f"{pp_img}|keep('label', 'image')"))
  ds_texts = tf.data.Dataset.from_tensor_slices(list(class_names)).enumerate(
  ).flat_map(expand_aliases).flat_map(add_prompts).map(substitute_prompt).map(
      pp_builder.get_preprocess_fn(f"{pp_txt}|keep('label', 'labels')"))

  if cache_final:
    ds_images, ds_texts = ds_images.cache(), ds_texts.cache()

  return ds_images, ds_texts
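# Illustration with assumed inputs: given class_names=["maple,maple tree"]
# and prompt_templates=["a photo of a {}."], `ds_texts` enumerates the cross
# product of aliases and templates before tokenization, i.e. elements
# equivalent to
#
#   {"label": 0, "texts": "a photo of a maple."}
#   {"label": 0, "texts": "a photo of a maple tree."}
#
# Both rows carry label 0, so their embeddings are later averaged into a
# single class embedding by `_average_embeddings()`.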
|
|
def _split_and_batch(dataset_name, data_dir, class_names, batch_size, split,
                     get_ds):
  """Splits dataset, calls `get_ds` and returns padded + batched datasets."""
  assert not batch_size % jax.device_count(), (
      f"batch_size={batch_size} % jax.device_count()={jax.device_count()}")
  builder = tfds.builder(dataset_name, data_dir=data_dir)

  # Class names are sharded across processes: every process embeds an
  # equal-sized contiguous chunk (padded with empty names if there are fewer
  # classes than processes), and the last process takes any remainder.
  if len(class_names) < jax.process_count():
    class_names += [""] * (jax.process_count() - len(class_names))
  per_process = len(class_names) // jax.process_count()
  class_name_offset = per_process * jax.process_index()
  if jax.process_index() == jax.process_count() - 1:
    class_names = class_names[class_name_offset:]
  else:
    class_names = class_names[class_name_offset:class_name_offset + per_process]

  ds_images, ds_texts = get_ds(
      builder.as_dataset(split=tfds.split_for_jax_process(split)),
      class_names,
      class_name_offset=class_name_offset)
  return (
      _with_infinite_padding(ds_images).batch(batch_size),
      _with_infinite_padding(ds_texts).batch(batch_size),
  )
|
|
def _average_embeddings(embeddings, *, labels, num_classes, normalize):
  """Computes per-class averages of `embeddings`."""
  assert embeddings.ndim == 2, f"Expected {embeddings.ndim}==2"
  assert labels.ndim == 1, f"Expected {labels.ndim}==1"
  assert len(labels) == len(embeddings), (
      f"Expected {len(labels)}=={len(embeddings)}")

  byidx = [[] for _ in range(num_classes)]
  for label, embedding in zip(labels, embeddings):
    byidx[label].append(embedding)
  missing = set(range(num_classes)) - set(
      idx for idx, embs in enumerate(byidx) if len(embs))
  assert not missing, f"Classes without embeddings: {missing}"
  embeddings = [np.array(embs).mean(axis=0) for embs in byidx]
  embeddings = np.stack(embeddings)

  assert len(embeddings) == num_classes
  if normalize:
    embeddings /= 1e-8 + np.linalg.norm(embeddings, axis=1, keepdims=True)
  return embeddings
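# A tiny NumPy sketch of the averaging above (values made up):
#
#   embs = np.array([[1., 0.], [0., 1.], [0., 2.]])
#   _average_embeddings(embs, labels=np.array([0, 0, 1]), num_classes=2,
#                       normalize=False)
#   # -> [[0.5, 0.5],    # mean of the two label-0 rows
#   #     [0. , 2. ]]    # the single label-1 row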
|
|
class Evaluator:
  """Zero-shot classification evaluator."""

  def __init__(self,
               predict_fn,
               *,
               batch_size,
               devices,
               dataset_names=DATASET_NAMES,
               data_dir=None,
               class_names="dataset_info:label",
               split="test",
               prompt_templates="clip_paper",
               canonicalize=True,
               pp_img="resize(224)|value_range(-1,1)",
               pp_txt="tokenize(max_len=16, eos='sticky', "
                      "pad_value=1, inkey='texts', outkey='labels')",
               cache_final=False,
               pre_filter_fn=None,
               first_class_name_only=True,
               dataset_overrides=DEFAULT_OVERRIDES,
               async_delay=1):
    """Initializes a new zero-shot classification evaluator.

    See `prepare_datasets()` for details on how the dataset is pre-processed.

    Args:
      predict_fn: Prediction function with signature
        `zimg, ztxt, out = predict_fn(params, images, texts)`
      batch_size: Global batch size.
      devices: List of devices.
      dataset_names: Names of TFDS datasets to evaluate on.
      data_dir: Optional argument to `tfds.builder()`.
      class_names: Usually specified as a string that is interpreted by
        `prompt_engineering.get_class_names()` to look up class names.
        Alternatively, this attribute can be a list of class names (using ","
        to separate multiple aliases).
      split: Which dataset split to use for evaluation.
      prompt_templates: Specifies which prompt templates to use. See module
        `big_vision.evaluators.proj.image_text.prompt_engineering` for valid
        values.
      canonicalize: Whether class names and prompt templates should be
        canonicalized. See `prompt_engineering.py` for details.
      pp_img: Preprocessing string for images. Preprocessed features should
        contain key "image" with a value that can be batched and is suitable
        for the `images` argument of `predict_fn`.
      pp_txt: Preprocessing string for texts. Can expect a "texts" key as
        input (shape=[], dtype=string), and is expected to produce a "labels"
        key that is suitable for the `texts` argument of `predict_fn`.
      cache_final: Whether the preprocessed datasets should be cached.
      pre_filter_fn: Predicate applied to the dataset for filtering records.
      first_class_name_only: Whether only the first class name should be
        considered (i.e. not using any aliases).
      dataset_overrides: Maps `dataset_name` to an optional dictionary that
        can override the parameters `dataset_name`, `data_dir`, `pp_img`,
        `pp_txt`, `class_names`, `split`, `pre_filter_fn`, and the extra
        `class_names_dataset_name`. Works with a tuple/dict of tuples/dicts.
      async_delay: How many steps to wait before checking whether all hosts
        have finished their batch. A value > 1 allows for more parallelized
        processing, but results in more unnecessary steps on padded data.
    """
|
    t0 = time.monotonic()
    self.datasets = {}
    self.prompt_templates = prompt_engineering.get_prompt_templates(
        prompt_templates, canonicalize=canonicalize)
    self._axis_name = "batch"
    dataset_overrides = {k: dict(v) for k, v in dict(dataset_overrides).items()}

    for dataset_name in dataset_names:
      overrides = dataset_overrides.pop(dataset_name, {})
      dataset_name_ = overrides.pop("dataset_name", dataset_name)
      data_dir_ = overrides.pop("data_dir", data_dir)
      class_names_dataset_name = overrides.pop("class_names_dataset_name",
                                               dataset_name_)
      class_names_ = overrides.pop("class_names", class_names)
      class_names_ = prompt_engineering.get_class_names(
          dataset_name=class_names_dataset_name,
          source=class_names_,
          canonicalize=canonicalize)
      pp_img_ = overrides.pop("pp_img", pp_img)
      pp_txt_ = overrides.pop("pp_txt", pp_txt)
      cache_final_ = overrides.pop("cache_final", cache_final)
      split_ = overrides.pop("split", split)
      pre_filter_fn_ = overrides.pop("pre_filter_fn", pre_filter_fn)
      prompt_templates_ = overrides.pop("prompt_templates", prompt_templates)
      canonicalize_ = overrides.pop("canonicalize", canonicalize)
      prompt_templates_ = prompt_engineering.get_prompt_templates(
          prompt_templates_, canonicalize=canonicalize_)
      assert not overrides, f"Unknown overrides {dataset_name}: {overrides}"

      if first_class_name_only:
        class_names_ = [name.split(",")[0] for name in class_names_]
      ds_images, ds_texts = _split_and_batch(
          dataset_name=dataset_name_,
          data_dir=data_dir_,
          class_names=class_names_,
          batch_size=batch_size,
          split=split_,
          get_ds=functools.partial(
              prepare_datasets,
              pp_img=pp_img_,
              pp_txt=pp_txt_,
              cache_final=cache_final_,
              pre_filter_fn=pre_filter_fn_,
              prompt_templates=prompt_templates_))
      self.datasets[dataset_name] = dict(
          images=ds_images, texts=ds_texts, class_names=class_names_,
          dataset_name=dataset_name_, split=split_)

    assert not dataset_overrides, f"Extra overrides: {dataset_overrides}"
|
    def embed_texts(train_state, texts):
      """Returns text embeddings."""
      _, ztxt, _ = predict_fn(train_state, {"labels": texts})
      return ztxt

    def count_correct(train_state, return_embeddings, *, mask, labels, image,
                      ztxt):
      """Returns count of correct predictions (and optionally embeddings)."""
      zimg, _, _ = predict_fn(train_state, {"image": image})
      best_txt = (zimg @ ztxt.T).argmax(axis=1)
      # Labels can be rank-1 (a single label per example) or rank-2 (multiple
      # acceptable labels per example); a prediction counts as correct if it
      # matches any of them.
      if labels.ndim == 1:
        labels = labels[..., None]
      assert labels.ndim == 2, labels.shape
      matching = (best_txt[:, None] == labels).sum(axis=1)
      correct = jnp.where(mask, (matching > 0).astype(jnp.int32), 0).sum()
      if return_embeddings:
        return correct, zimg
      else:
        return correct, None
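    # In NumPy terms, the scoring above is nearest-neighbour retrieval in the
    # shared embedding space (shapes: zimg [batch, dim], ztxt [classes, dim]):
    #
    #   best_txt = (zimg @ ztxt.T).argmax(axis=1)         # [batch] class ids
    #   hits = (best_txt[:, None] == labels).any(axis=1)  # any label matches
    #   correct = hits[mask].sum()                        # ignore padding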
|
    self.devices = devices
    self.mesh = jax.sharding.Mesh(devices, ("devices",))

    self._embed_texts_p = jax.jit(
        embed_texts, out_shardings=NamedSharding(self.mesh, P()))
    self._count_correct_p = jax.jit(count_correct, static_argnums=(1,),
                                    out_shardings=NamedSharding(self.mesh, P()))
    self._count_p = jax.jit(jnp.sum,
                            out_shardings=NamedSharding(self.mesh, P()))
    # Jitted identity with fully-replicated output sharding; applying it
    # effectively all-gathers a sharded array onto every device.
    self._all_gather_p = jax.jit(
        lambda x: x, out_shardings=NamedSharding(self.mesh, P()))

    self._compiled = set()
    assert async_delay > 0, f"async_delay must be >0, not {async_delay}"
    self._async_delay = async_delay
    logging.info("Initialized evaluator in %.1f seconds", time.monotonic() - t0)
|
  def _embed_texts(self, train_state, dataset_name):
    """Returns per-class averaged text embeddings."""
    t0 = time.monotonic()
    logging.info("Starting text embedding...")
    ns = []
    embeddings = []
    data = {"label": [], "mask": []}

    ds_b = input_pipeline.start_global(
        self.datasets[dataset_name]["texts"], self.devices)
    for batch in ds_b:
      # Stop once the batch seen `async_delay` steps ago contained no real
      # (unpadded) examples on any host.
      ns.append(jax.device_get(self._count_p(batch["mask"])))
      if len(ns) >= self._async_delay and ns[-self._async_delay] == 0:
        break

      embeddings.append(jax.device_get(self._embed_texts_p(
          train_state, batch["labels"])))
      for name in data:
        data[name].append(jax.device_get(self._all_gather_p(batch[name])))

      if self._embed_texts_p not in self._compiled:
        logging.info("Compiled text embeddings in %.1fs", time.monotonic() - t0)
        t0 = time.monotonic()
        self._compiled.add(self._embed_texts_p)

    ns = np.array(ns)
    n = ns.sum()
    data["embedding"] = embeddings
    data = {k: np.concatenate(v, axis=0) for k, v in data.items()}
    mask = data.pop("mask").astype(bool)
    data = {k: v[mask] for k, v in data.items()}
    data["average_embedding"] = _average_embeddings(
        data["embedding"],
        labels=data["label"],
        num_classes=len(self.datasets[dataset_name]["class_names"]),
        normalize=True)

    logging.info("Embedded %s texts in %d steps - ...%s", dataset_name, len(ns),
                 ns[-10:])
    logging.info("Totalling %d texts in %.1fs", n, time.monotonic() - t0)
    logging.info("Total text embeddings size %.1fM",
                 data["embedding"].nbytes / 1e6)
    return data
|
  def evaluate(self, train_state, dataset_name, *, return_embeddings=False):
    """Returns evaluation results."""
    texts = self._embed_texts(train_state, dataset_name)
    ztxt_p = texts["average_embedding"]
    ztxt_p = utils.reshard(ztxt_p, NamedSharding(self.mesh, P()))

    t0 = time.monotonic()
    logging.info("Starting image embedding...")

    ns = []
    embeddings = []
    corrects = []
    data = {"mask": [], "label": []} if return_embeddings else {}

    ds_b = input_pipeline.start_global(
        self.datasets[dataset_name]["images"], self.devices)
    for batch in ds_b:
      ns.append(jax.device_get(self._count_p(batch["mask"])))
      if len(ns) >= self._async_delay and ns[-self._async_delay] == 0:
        break

      labels = batch["label"]
      correct_p, embs_p = self._count_correct_p(
          train_state,
          return_embeddings,
          mask=batch["mask"],
          labels=labels,
          image=batch["image"],
          ztxt=ztxt_p,
      )
      corrects.append(jax.device_get(correct_p))
      if self._count_correct_p not in self._compiled:
        logging.info("Compiled image embeddings in %.1fs",
                     time.monotonic() - t0)
        t0 = time.monotonic()
        self._compiled.add(self._count_correct_p)

      if return_embeddings:
        embeddings.append(jax.device_get(self._all_gather_p(embs_p)))
        for name in data:
          data[name].append(jax.device_get(self._all_gather_p(batch[name])))

    ns = np.array(ns)
    n = ns.sum()
    correct = np.array(corrects).sum()

    logging.info("Embedded %s images in %d steps - ...%s", dataset_name,
                 len(ns), ns[-10:])
    logging.info("Totalling %d images in %.1fs", n, time.monotonic() - t0)
    ret = {
        "accuracy": correct / n,
        "correct": correct,
        "count": n,
    }
    logging.info("Dataset %s, results %s", dataset_name, ret)

    if return_embeddings:
      data["embedding"] = embeddings
      data = {k: np.concatenate(v, axis=0) for k, v in data.items()}
      logging.info("Total image embeddings size %.1fM",
                   data["embedding"].nbytes / 1e6)
      mask = data.pop("mask").astype(bool)
      ret["images"] = {k: v[mask] for k, v in data.items()}
      ret["texts"] = texts

    return ret
|
  def run(self, train_state):
    """Returns metrics."""
    return [(f"{dataset_name}_accuracy",
             self.evaluate(train_state, dataset_name)["accuracy"])
            for dataset_name in self.datasets]
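# A minimal usage sketch (hypothetical `train_state`/`predict_fn`; any
# prediction function returning `(zimg, ztxt, out)` tuples as described in
# `Evaluator.__init__` works):
#
#   evaluator = Evaluator(
#       predict_fn,
#       batch_size=1024,
#       devices=jax.devices(),
#       dataset_names=("cifar100",),
#   )
#   for name, accuracy in evaluator.run(train_state):
#     print(name, accuracy)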
|
|