"""Evaluator for the contrastive task. |
|
|
|
DON'T COMPARE ACROSS RUNS, use for training health monitoring only. |
|
|
|
Note that this evaluator's `ncorrect_minibatch` is only a rough proxy for |
|
training progress and does not report the actual `ncorrect`: when the same |
|
labels found multiple times in a batch, then the reported value is biased |
|
towards lower values. |
|
|
|
Also note that the `ncorrect_minibatch` is a function of batch size (it's a lot |
|
easier to find correct values in small batches). |
|
""" |

import functools

from big_vision import input_pipeline
import big_vision.datasets.core as ds_core
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import jax
import jax.numpy as jnp
import numpy as np


def _all_gather(z):
  """All gather and flatten first two dims."""
  gather_flat = lambda x: jnp.concatenate(jax.lax.all_gather(x, "batch"), 0)
  return jax.tree_map(gather_flat, z)
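# Shape sketch for the gather above (illustrative): inside the pmap each
# leaf of `z` is (batch_per_device, ...); `all_gather` turns that into
# (num_devices, batch_per_device, ...), and the concatenate along axis 0
# flattens it to (num_devices * batch_per_device, ...).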


@functools.lru_cache(None)
def get_eval_fn(predict_fn, use_global_batch):
  """Produces eval function, also applies pmap."""

  @functools.partial(jax.pmap, axis_name="batch")
  def _eval_fn(params, images, labels, mask):
    zimg, ztxt, extras = predict_fn(params, images, labels)

    if use_global_batch:
      zimg, ztxt, mask = _all_gather((zimg, ztxt, mask))

    losses, measurements = u.bidirectional_contrastive_loss(
        zimg, ztxt, extras["t"], mask, reduction=False)
    l = jax.lax.psum(losses * mask, axis_name="batch")
    c = jax.lax.psum(measurements["ncorrect"] * mask, axis_name="batch")
    n = jax.lax.psum(mask, axis_name="batch")
    return c, l, n

  return _eval_fn
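

# For reference, a minimal sketch of what a bidirectional contrastive
# (InfoNCE-style) loss computes, assuming L2-normalized embeddings. This is
# illustrative only: the real `u.bidirectional_contrastive_loss` lives in
# `big_vision.utils` and may differ in its masking, temperature handling,
# and reported measurements.
def _contrastive_loss_sketch(zimg, ztxt, t, mask):
  """Returns per-example losses and 0/1 `ncorrect`, like `reduction=False`."""
  logits = jnp.dot(zimg, ztxt.T) * t  # (n, n) pairwise similarities.
  # Push padded columns far down so they are never predicted as matches.
  logits = jnp.where(mask[None, :] > 0, logits, -1e9)
  diag = jnp.arange(logits.shape[0])
  loss_i2t = -jax.nn.log_softmax(logits, axis=1)[diag, diag]  # image -> text
  loss_t2i = -jax.nn.log_softmax(logits, axis=0)[diag, diag]  # text -> image
  losses = 0.5 * (loss_i2t + loss_t2i)
  ncorrect = (jnp.argmax(logits, axis=1) == diag).astype(jnp.float32)
  return losses, {"ncorrect": ncorrect}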


class Evaluator:
  """Contrastive evaluator."""

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               use_global_batch, cache_final=True,
               cache_raw=False, prefetch=1, label_key="labels"):
    data = ds_core.get(**data)
    pp_fn = pp_builder.get_preprocess_fn(pp_fn)
    self.ds, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), pp_fn, batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.ds, prefetch)
    self.eval_fn = get_eval_fn(predict_fn, use_global_batch)
    self.label_key = label_key

  def run(self, params):
    """Computes all metrics."""
    l, c, nseen = 0, 0, 0
    for _, batch in zip(range(self.steps), self.data_iter):
      labels, mask = batch.pop(self.label_key), batch.pop("_mask")
      batch_ncorrect, batch_losses, batch_n = self.eval_fn(
          params, batch["image"], labels, mask)
      # The psum'd results are replicated across local devices, so taking
      # index [0] brings one copy to the host as numpy.
      c += np.sum(np.array(batch_ncorrect[0]))
      l += np.sum(np.array(batch_losses[0]))
      nseen += np.sum(np.array(batch_n[0]))
    yield ("ncorrect_minibatch", c / nseen)
    yield ("loss", l / nseen)
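
# A minimal usage sketch (hypothetical dataset and preprocessing values; the
# real ones come from the experiment config):
#
#   evaluator = Evaluator(
#       predict_fn,
#       data={"name": "coco_captions", "split": "val"},
#       pp_fn="decode|resize(224)|value_range(-1, 1)|keep('image', 'labels')",
#       batch_size=1024,
#       use_global_batch=True)
#   for metric_name, value in evaluator.run(params):
#     print(metric_name, value)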
|