# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluator for the classfication task."""
from functools import partial, lru_cache
from big_vision import input_pipeline
import big_vision.datasets.core as ds_core
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import einops
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
import numpy as np
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = 'jit'
def dist(student, teacher, kind, feat_axis=-1,
epsilon=1e-12, t=1, ls=0.0, k=1):
"""Distance function used for distillation."""
diff = student - teacher
if kind == 'euclidean':
return jnp.sqrt(jnp.sum(diff * diff, axis=feat_axis) + epsilon)
elif kind == 'l2':
return jnp.sum(diff * diff, axis=feat_axis)
elif kind == 'hard':
pseudolabels = jnp.argmax(teacher, feat_axis)
pl = u.onehot(pseudolabels, teacher.shape[feat_axis])
if ls:
pl = (1.0 - ls) * pl + (ls / (pl.shape[-1] - 1)) * (1.0 - pl)
return u.softmax_xent(logits=student, labels=pl,
reduction=False, kl=True, axis=feat_axis)
elif kind == 'kl':
return t**2 * u.softmax_xent(
logits=student / t,
labels=jax.nn.softmax(teacher / t),
reduction=False, kl=True, axis=feat_axis)
elif kind == 'logsoftmax_euclidean':
logsoftmax_diff = (
jax.nn.log_softmax(student, axis=feat_axis) -
jax.nn.log_softmax(teacher, axis=feat_axis))
return jnp.sqrt(
jnp.sum(logsoftmax_diff * logsoftmax_diff, axis=feat_axis) + epsilon)
elif kind == 'agree':
def get_top_k(arr, k, ax):
return jax.lax.top_k(arr.swapaxes(ax, -1), k)[1].swapaxes(ax, -1)
return (get_top_k(student, k, feat_axis) ==
get_top_k(teacher, 1, feat_axis)).sum(feat_axis)
else:
assert False, f'Unknown kind of distance {kind}.'
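# A minimal illustrative sketch (kept as a comment so it does not run at import
# time) of how `dist` is called; the logits below are made-up placeholders and
# each call returns one distance per example:
#
#   s_logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])  # student (2, 3)
#   t_logits = jnp.array([[1.5, 0.0, -0.5], [0.0, 0.0, 1.0]])  # teacher (2, 3)
#   dist(s_logits, t_logits, kind='euclidean')   # norm of logit diff, shape (2,)
#   dist(s_logits, t_logits, kind='kl', t=2.0)   # temperature-scaled KL vs teacher, shape (2,)
#   dist(s_logits, t_logits, kind='agree', k=1)  # 1.0 where top-1 predictions agree, shape (2,)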
@lru_cache(None)
def get_dist_fn(**kw):
return partial(dist, **kw)
# To avoid re-compiling the function for every new instance of the same
# evaluator on a different dataset!
@lru_cache(None)
def get_eval_fn(student_teacher_fwd, what, mesh, distances):
"""Produces eval function, also applies pmap."""
@partial(jax.jit, out_shardings=NamedSharding(mesh, P()))
def _eval_fn(train_state, batch, mask):
(_, out_s), (_, out_t) = student_teacher_fwd(train_state, batch)
repr_s = u.tree_get(out_s, what[0])
repr_t = u.tree_get(out_t, what[1])
# Let's flatten any non-vectors (eg feature-maps).
repr_s = einops.rearrange(repr_s, 'b ... -> b (...)')
repr_t = einops.rearrange(repr_t, 'b ... -> b (...)')
all_ds = []
# NOTE: we're gathering and returning all distances; if this becomes too
# slow, we can switch to computing and returning summary stats instead.
for dist_fn in distances:
ds = dist_fn(repr_s, repr_t)
all_ds.append(ds)
all_masks = mask
return all_ds, all_masks
return _eval_fn
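# The evaluator does not define `student_teacher_fwd`; from the unpacking above
# it is assumed to take (train_state, batch) and return a pair of pairs whose
# second elements hold the intermediate outputs addressed by `what`. A rough,
# hypothetical sketch (model and parameter names are placeholders):
#
#   def student_teacher_fwd(train_state, batch):
#     logits_s, out_s = student.apply({'params': train_state['params']}, batch['image'])
#     logits_t, out_t = teacher.apply({'params': teacher_params}, batch['image'])
#     return (logits_s, out_s), (logits_t, out_t)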
class Evaluator:
"""Distillation distance evaluator."""
def __init__(
self,
student_teacher_fwd,
data,
pp_fn,
distances,
what=('logits', 'logits'),
*,
devices,
**data_kw,
):
data = ds_core.get(**data)
pp_fn = pp_builder.get_preprocess_fn(pp_fn)
prefetch = data_kw.pop('prefetch', 1)
self.ds, self.steps = input_pipeline.make_for_inference(
data.get_tfdata(ordered=True),
pp_fn,
num_ex_per_process=data.num_examples_per_process(),
**data_kw,
)
self.data_iter = input_pipeline.start_global(self.ds, devices, prefetch)
dist_fns = tuple(get_dist_fn(**dist) for dist in distances)
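# Human-readable metric prefixes derived from each distance spec,
# e.g. {'kind': 'kl', 't': 2} -> 'kind=kl_t=2'.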
self.dist_names = [
'_'.join(f'{k}={v}' for k, v in dist.items()) for dist in distances
]
mesh = jax.sharding.Mesh(devices, ('data',))
self.eval_fn = get_eval_fn(student_teacher_fwd, what, mesh, dist_fns)
def run(self, train_state):
"""Computes all metrics."""
all_ds = [[] for _ in self.dist_names]
for _, batch in zip(range(self.steps), self.data_iter):
mask = batch.pop('_mask')
batch_ds, batch_ms = self.eval_fn(train_state, batch, mask)
# With out_shardings=P() the results come back as fully replicated global
# arrays of shape (batch_size, elem_shape...), so we can simply pull them
# to the host as numpy and keep only the non-padded examples.
batch_ms = np.array(batch_ms)
for i, val in enumerate(batch_ds):
all_ds[i].append(np.array(val)[batch_ms == 1])
for name, ds in zip(self.dist_names, all_ds):
ds = np.concatenate(ds)
yield f'{name}/all', ds
yield f'{name}/avg', np.mean(ds)
yield f'{name}/min', np.min(ds)
yield f'{name}/max', np.max(ds)
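# Illustrative usage sketch (kept as a comment; the dataset spec, preprocessing
# string, and batch size below are hypothetical placeholders, and extra keyword
# args such as batch_size are forwarded to the input pipeline):
#
#   evaluator = Evaluator(
#       student_teacher_fwd,
#       data={'name': 'imagenet2012', 'split': 'validation'},
#       pp_fn='decode|resize(256)|value_range(-1, 1)|keep("image")',
#       distances=({'kind': 'kl', 't': 2.0}, {'kind': 'euclidean'}),
#       devices=jax.devices(),
#       batch_size=1024,
#   )
#   for name, value in evaluator.run(train_state):
#     print(name, value if np.ndim(value) == 0 else value.shape)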