# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for few-shot evaluation."""
# pylint: disable=consider-using-from-import,g-importing-member
import functools
import big_vision.datasets.core as ds_core
import big_vision.input_pipeline as input_pipeline
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding as Sharding
from jax.sharding import PartitionSpec as P
import numpy as np

BIAS_CONSTANT = 100.0

# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"


# Setup function for few-shot regression on CPU to avoid "polluting" the TPU.
@u.jit_cpu(static_argnums=(2,))
def _precompute_cache(x, y, num_classes):
  """Cache quantities to speed up the computation of L2-regularized least-sq."""
  # Whiten
  mean = jnp.mean(x, axis=0, keepdims=True)
  std = jnp.std(x, axis=0, keepdims=True) + 1e-5
  x = (x - mean) / std

  # Add a constant feature for the bias, large so it's almost unregularized:
  x = jnp.pad(x, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  # To one-hot representation rescaled into {-1, 1}
  y = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0

  num_points, dim = x.shape
  # Let N be the number of points, D the dimension and C the number of classes.
  # We have x of shape (N, D) and y of shape (N, C).
  # For least-squares, we can compute
  #
  # (A) when N >= D, (x^T x + l2 Id)^{-1} x^T y
  # (B) when D > N, x^T (x x^T + l2 Id)^{-1} y
  #
  # We pre-compute the eigen-decomposition of either x^T x or x x^T which
  # becomes q diag(eigs) q^T with q unitary matrix either (D, D) or (N, N)
  # and eigs a vector (D,) or (N,).
  #
  # For any l2 > 0, we can compute (x^T x + l2 Id)^{-1} or (x x^T + l2 Id)^{-1}
  # by simply computing q (diag(eigs) + l2 Id)^{-1} q^T.
  # (SVD would be more natural here, but it proved slower, so we use eigh.)
  #
  # Both cases (A) and (B) can be viewed as lhs (diag(eigs) + l2 Id)^{-1} rhs,
  # where lhs/rhs are pre-computed left/right-hand sides to specify.
  #
  # Detailed evaluation in terms of time and fewshot metrics can be found in
  # (internal link)
  #
  # Implemented by Rodolphe Jenatton.
  if num_points >= dim:
    eigs, q = jnp.linalg.eigh(x.T @ x)
    rhs = q.T @ (x.T @ y)
    lhs = q
  else:
    eigs, q = jnp.linalg.eigh(x @ x.T)
    rhs = q.T @ y
    lhs = x.T @ q

  cache = {
      "eigs": eigs,
      "rhs": rhs,
      "lhs": lhs,
      "mean": mean,
      "std": std,
  }
  return cache
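
# A small self-check sketch of the identity above (case (A); hypothetical
# shapes, plain NumPy, not executed as part of this module):
#
#   rng = np.random.default_rng(0)
#   x, y = rng.normal(size=(50, 8)), rng.normal(size=(50, 3))  # N >= D.
#   l2 = 0.5
#   eigs, q = np.linalg.eigh(x.T @ x)
#   w_eig = q @ ((q.T @ (x.T @ y)) / (eigs + l2)[:, None])
#   w_ridge = np.linalg.solve(x.T @ x + l2 * np.eye(8), x.T @ y)
#   np.testing.assert_allclose(w_eig, w_ridge, rtol=1e-6, atol=1e-8)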


@u.jit_cpu()
def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg):
  """Computes (x,y) linear regression accuracy on (x_test, y_test)."""
  x_test = (x_test - cache["mean"]) / cache["std"]
  x_test = jnp.pad(x_test, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  rhs = cache["rhs"]
  lhs = cache["lhs"]
  eigs = cache["eigs"]

  # See comments in _precompute_cache for context about the formula.
  scaling = 1.0 / (eigs + l2_reg * jnp.ones_like(eigs))
  scaling = scaling.reshape((1, -1))
  w = (lhs * scaling) @ rhs
  # Predict test-set values and measure their accuracy.
  preds = jnp.argmax(x_test @ w, axis=1)
  return jnp.mean(preds == y_test)
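
# Usage sketch (hypothetical arrays, not executed by this module): since only
# the (eigs + l2)^{-1} rescaling depends on l2_reg, a single cache supports
# cheap sweeps over regularization strengths:
#
#   cache = _precompute_cache(x_train, y_train, num_classes)
#   accs = {l2: _eig_fewshot_acc_fn(cache, x_test, y_test, u.put_cpu(l2))
#           for l2 in (2.0**6, 2.0**8, 2.0**10)}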


class Evaluator:
  """Class for few-shot evaluation."""

  def __init__(self, predict_fn, batch_size,
               datasets, shots, l2_reg,
               pp_train, pp_eval, display_first,
               representation_layer=None, num_seeds=3,
               label_key="label", mask_key="_mask", data_dir=None, *,
               devices):
    self.datasets = datasets
    self.shots = shots
    self.l2_reg = l2_reg
    self.batch_size = batch_size
    self.pp_tr = pp_train
    self.pp_te = pp_eval
    self.display_first = display_first
    self._datasets = {}  # Cache for tfds data. Persists while object is alive.
    self._repr = {}  # Cache for precomputed repr. Persists within the run call.
    self.num_seeds = num_seeds
    self.label_key = label_key
    self.mask_key = mask_key
    self.data_dir = data_dir
    self.devices = devices
    self.mesh = jax.sharding.Mesh(devices, ("devices",))
    self.repr_fn = self.get_representation_fn(
        predict_fn, representation_layer)

  def get_representation_fn(self, predict_fn, representation_layer):
    # `out_shardings=Sharding(self.mesh, P())` will "all_gather" the outputs.
    @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P()))
    def _repr_fn(train_state, batch, labels, mask):
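      # NOTE: `predict_fn` is assumed (from the unpacking below) to return a
      # tuple whose first element is the pooled image representation and whose
      # last element is a dict of intermediate outputs; `representation_layer`
      # (if set) selects an entry from that dict, e.g. "pre_logits" (an
      # example name, not required by this code).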
      zimg, *_, out = predict_fn(train_state, batch)
      if representation_layer is not None:
        rep = u.tree_get(out, representation_layer)
      else:
        rep = zimg
      return rep, labels, mask
    return _repr_fn

  # Setup input pipeline.
  def _get_dataset(self, dataset, train_split, test_split):
    """Lazy-loads given dataset."""
    key = (dataset, train_split, test_split)
    try:
      return self._datasets[key]
    except KeyError:
      # NOTE: only supporting TFDS data for now, for bwd compat/laziness.
      train_data = ds_core.get(
          name=dataset, split=train_split, data_dir=self.data_dir)
      test_data = ds_core.get(
          name=dataset, split=test_split, data_dir=self.data_dir)
      train_ds, batches_tr = input_pipeline.make_for_inference(
          train_data.get_tfdata(ordered=True),
          num_ex_per_process=train_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr))
      test_ds, batches_te = input_pipeline.make_for_inference(
          test_data.get_tfdata(ordered=True),
          num_ex_per_process=test_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te))
      num_classes = (
          train_data.builder.info.features[self.label_key].num_classes)
      return self._datasets.setdefault(
          key, (train_ds, batches_tr, test_ds, batches_te, num_classes))

  def _get_repr(self, params, data, steps):
    """Compute representation for the whole dataset."""
    pre_logits_list = []
    labels_list = []
    for batch, _ in zip(
        input_pipeline.start_global(data, self.devices, 0), range(steps)):
      labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key)
      pre_logits, labels, mask = jax.device_get(self.repr_fn(
          params, batch, labels, mask))
      # Keep only real examples; `mask` is 0 for padding used to fill batches.
      mask = mask.astype(bool)
      pre_logits_list.append(pre_logits[mask])
      labels_list.append(labels[mask])
    pre_logits = np.concatenate(pre_logits_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)
    return pre_logits, labels

  def compute_fewshot_metrics(self, train_state, seed,
                              dataset, train_split, test_split):
    """Compute few-shot metrics on one dataset."""
    if dataset in self._repr:
      repr_train, labels_train, repr_test, labels_test, num_classes = (
          self._repr[dataset])
    else:
      train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset(
          dataset, train_split, test_split)
      repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr)
      repr_test, labels_test = self._get_repr(train_state, test_ds, steps_te)
      self._repr[dataset] = (repr_train, labels_train,
                             repr_test, labels_test,
                             num_classes)

    # Collect where we have samples of which classes.
    rng = np.random.default_rng(seed)
    class_indices = [rng.permutation(np.where(labels_train == cls_i)[0])
                     for cls_i in range(num_classes)]

    results = {}
    for shots in self.shots:
      all_idx = [indices[:shots] for indices in class_indices]
      all_idx = np.concatenate(all_idx, axis=0)
      x = u.put_cpu(repr_train[all_idx])
      y = u.put_cpu(labels_train[all_idx])
      repr_test, labels_test = u.put_cpu((repr_test, labels_test))
      # Note the code is optimized to solve multiple LSR tasks for changing l2
      # strength, even though we currently use a fixed l2_reg constant.
      cache = _precompute_cache(x, y, num_classes)
      acc = _eig_fewshot_acc_fn(
          cache, repr_test, labels_test, u.put_cpu(self.l2_reg))
      results[shots] = jax.device_get(acc)
    return results

  def run(self, train_state):
    """New API executed in terms of old API."""
    self._repr = {}
    for seed in range(self.num_seeds):
      for name, dataset_args in self.datasets.items():
        result = self.compute_fewshot_metrics(train_state, seed, *dataset_args)
        for shots, v in result.items():
          prefix = "a/" if (name, shots) in self.display_first else "z/"
          suffix = f"-seed-{seed}"
          yield f"{prefix}{name}_{shots}shot{suffix}", v