|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for computing mean of per-example metrics.""" |
|
import functools |
|
from typing import Mapping |
|
|
|
from big_vision import input_pipeline |
|
from big_vision.datasets import core as ds_core |
|
from big_vision.pp import builder as pp_builder |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
|
|
|
|
@functools.partial(jax.pmap, static_broadcasted_argnums=0, axis_name='batch') |
|
def _run_predict_fn(predict_fn, params, batch): |
|
"""Sum per-example metrics weighted by `_mask`.""" |
|
mask = batch['_mask'] |
|
metrics = predict_fn(params, batch) |
|
|
|
assert isinstance(metrics, Mapping), 'predict_fn must return a dict' |
|
for y in jax.tree_leaves(metrics): |
|
if y.shape != mask.shape: |
|
raise ValueError( |
|
f'Expected per-example metrics of shape {mask.shape} found ' |
|
f'{jax.tree_map(lambda x: x.shape, metrics)}.') |
|
metrics = {**metrics, '_mask': mask} |
|
metrics = jax.tree_map(lambda x: jnp.inner(x, mask), metrics) |
|
return jax.lax.psum(metrics, axis_name='batch') |
|
|
|
|
|
class Evaluator:
  """Report the mean of per-example metrics computed by predict_fn.

  `predict_fn(params, batch)` must return a dict from metric name to
  per-example metrics of shape [batch_size].
  """

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               cache_final=True, cache_raw=False, prefetch=1):
    """Builds the inference input pipeline and stores `predict_fn`.

    Args:
      predict_fn: callable later invoked as `predict_fn(params, batch)`.
      data: mapping of keyword arguments forwarded to `ds_core.get`.
      pp_fn: preprocessing spec handed to `pp_builder.get_preprocess_fn`.
      batch_size: forwarded to `input_pipeline.make_for_inference`.
      cache_final: forwarded to `input_pipeline.make_for_inference`.
      cache_raw: forwarded to `input_pipeline.make_for_inference`.
      prefetch: forwarded to `input_pipeline.start_input_pipeline`.
    """
    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
    self.data_iter = input_pipeline.start_input_pipeline(self.dataset, prefetch)
    self.predict_fn = predict_fn

  def run(self, params):
    """Computes all metrics.

    This is a generator: nothing is evaluated until it is iterated.

    Args:
      params: passed through to `predict_fn` for every batch.

    Yields:
      `(name, value)` pairs, one per metric, where `value` is the metric's
      mask-weighted sum over all processed batches divided by the summed
      `'_mask'` entry.
    """
    metrics = []

    # `range` comes first in the `zip` on purpose: zip stops as soon as the
    # range is exhausted, so `self.data_iter` is never advanced past
    # `self.steps` batches.
    for _, batch in zip(range(self.steps), self.data_iter):
      batch_metrics = _run_predict_fn(self.predict_fn, params, batch)
      metrics.append(batch_metrics)

    # `_run_predict_fn` ends with a `psum`, so every entry along the leading
    # axis holds the same value; keep only the first one. `jax.tree_util` is
    # used because newer JAX releases dropped the top-level `jax.tree_map`.
    metrics = jax.device_get(
        jax.tree_util.tree_map(lambda x: x[0], metrics))

    # Sum each metric (and the mask) across batches.
    metrics_sum = jax.tree_util.tree_map(lambda *x: np.sum(x), *metrics)
    mask_sum = metrics_sum.pop('_mask')
    for key, value_sum in metrics_sum.items():
      yield (key, value_sum / mask_sum)
|
|