# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluator for computing mean of per-example metrics."""
import functools
from typing import Mapping

from big_vision import input_pipeline
from big_vision.datasets import core as ds_core
from big_vision.pp import builder as pp_builder

import jax
import jax.numpy as jnp
import numpy as np

# Note: global to avoid jax re-compiling across different evaluator instances.
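# `predict_fn` is a static (non-traced) argument, so each distinct function
# triggers its own compilation; the mapped axis is named 'batch' to match
# the `lax.psum` below.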
@functools.partial(jax.pmap, static_broadcasted_argnums=0, axis_name='batch')
def _run_predict_fn(predict_fn, params, batch):
"""Sum per-example metrics weighted by `_mask`."""
mask = batch['_mask']
metrics = predict_fn(params, batch)
# Sanity check output format of predict_fn.
assert isinstance(metrics, Mapping), 'predict_fn must return a dict'
for y in jax.tree_leaves(metrics):
if y.shape != mask.shape:
raise ValueError(
f'Expected per-example metrics of shape {mask.shape} found '
f'{jax.tree_map(lambda x: x.shape, metrics)}.')
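  # Weight each metric (and the mask itself) by the mask and reduce over the
  # local batch: jnp.inner(x, mask) == sum_i x[i] * mask[i], so padded
  # examples (mask 0) contribute nothing and inner(mask, mask) counts the
  # real examples.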
  metrics = {**metrics, '_mask': mask}
  metrics = jax.tree_map(lambda x: jnp.inner(x, mask), metrics)
  return jax.lax.psum(metrics, axis_name='batch')


class Evaluator:
"""Report the mean of per-example metrics computed by predict_fn.
`predict_fn(params, batch)` must return a dict from metric name to
per-example metrics of shape [batch_size].
"""

  def __init__(self, predict_fn, data, pp_fn, batch_size,
               cache_final=True, cache_raw=False, prefetch=1):
    data = ds_core.get(**data)
    self.dataset, self.steps = input_pipeline.make_for_inference(
        data.get_tfdata(ordered=True), batch_size=batch_size,
        num_ex_per_process=data.num_examples_per_process(),
        preprocess_fn=pp_builder.get_preprocess_fn(pp_fn),
        cache_final=cache_final, cache_raw=cache_raw)
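    # make_for_inference pads the final batch to full size and flags padded
    # examples with '_mask' == 0, which _run_predict_fn uses as weights.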
    self.data_iter = input_pipeline.start_input_pipeline(self.dataset, prefetch)
    self.predict_fn = predict_fn

  def run(self, params):
    """Computes all metrics."""
    metrics = []

    # Compute batch metrics without blocking.
    for _, batch in zip(range(self.steps), self.data_iter):
      batch_metrics = _run_predict_fn(self.predict_fn, params, batch)
      metrics.append(batch_metrics)

    # Transfer metrics from device 0 to host (blocking).
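    # (After lax.psum every device holds identical totals, so index 0 suffices.)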
    metrics = jax.device_get(jax.tree_map(lambda x: x[0], metrics))
    metrics_sum = jax.tree_map(lambda *x: np.sum(x), *metrics)
    mask_sum = metrics_sum.pop('_mask')
    for key, value_sum in metrics_sum.items():
      yield (key, value_sum / mask_sum)
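

# A minimal usage sketch, assuming a compatible `predict_fn` and `params`
# already replicated across local devices for jax.pmap; the dataset name,
# split, and pp string are hypothetical examples, not defaults of this module:
#
#   evaluator = Evaluator(
#       predict_fn, data={'name': 'imagenet2012', 'split': 'validation'},
#       pp_fn='decode|resize(224)|value_range(-1, 1)|keep("image", "labels")',
#       batch_size=1024)
#   for name, value in evaluator.run(params):
#     print(name, value)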