# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for few-shot evaluation."""
# pylint: disable=consider-using-from-import,g-importing-member

import functools

import big_vision.datasets.core as ds_core
import big_vision.input_pipeline as input_pipeline
import big_vision.pp.builder as pp_builder
import big_vision.utils as u
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding as Sharding
from jax.sharding import PartitionSpec as P
import numpy as np


BIAS_CONSTANT = 100.0

# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = "jit"


# Setup function for few-shot regression on CPU to avoid "polluting" the TPU.
@u.jit_cpu(static_argnums=(2,))
def _precompute_cache(x, y, num_classes):
  """Cache quantities to speed up the computation of L2-regularized least-sq."""
  # Whiten
  mean = jnp.mean(x, axis=0, keepdims=True)
  std = jnp.std(x, axis=0, keepdims=True) + 1e-5
  x = (x - mean) / std

  # Add a constant feature for the bias, large so it's almost unregularized:
  x = jnp.pad(x, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  # To one-hot representation rescaled into {-1, 1}
  y = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0

  num_points, dim = x.shape
  # Let N be the number of points, D the dimension and C the number of classes.
  # We have x of shape (N, D) and y of shape (N, C).
  # For least-squares, we can compute
  #
  # (A) when N >= D, (x^T x + l2 Id)^{-1} x^T y
  # (B) when D > N,  x^T (x x^T + l2 Id)^{-1} y
  #
  # We pre-compute the eigen-decomposition of either x^T x or x x^T, which
  # becomes q diag(eigs) q^T with q a unitary matrix, either (D, D) or (N, N),
  # and eigs a vector of shape (D,) or (N,).
  #
  # For any l2 > 0, we can compute (x^T x + l2 Id)^{-1} or (x x^T + l2 Id)^{-1}
  # by simply computing q (diag(eigs) + l2 Id)^{-1} q^T.
  # (SVD would be more natural here, but it proved slower, so we use eigh.)
  #
  # Both cases (A) and (B) can be viewed as lhs (diag(eigs) + l2 Id)^{-1} rhs,
  # where lhs/rhs are pre-computed left/right-hand sides to specify.
  #
  # Detailed evaluation in terms of time and fewshot metrics can be found in
  # (internal link)
  #
  # Implemented by Rodolphe Jenatton.
  if num_points >= dim:
    eigs, q = jnp.linalg.eigh(x.T @ x)
    rhs = q.T @ (x.T @ y)
    lhs = q
  else:
    eigs, q = jnp.linalg.eigh(x @ x.T)
    rhs = q.T @ y
    lhs = x.T @ q

  cache = {
      "eigs": eigs,
      "rhs": rhs,
      "lhs": lhs,
      "mean": mean,
      "std": std,
  }
  return cache


@u.jit_cpu()
def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg):
  """Computes (x,y) linear regression accuracy on (x_test, y_test)."""

  x_test = (x_test - cache["mean"]) / cache["std"]
  x_test = jnp.pad(x_test, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  rhs = cache["rhs"]
  lhs = cache["lhs"]
  eigs = cache["eigs"]

  # See comments in _precompute_cache for context about the formula.
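  # Concretely, both cases (A) and (B) above reduce to
  #   w = lhs @ diag(1 / (eigs + l2_reg)) @ rhs,
  # and the element-wise multiply below applies the diagonal scaling to the
  # columns of lhs without materializing the diagonal matrix.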
  scaling = 1.0 / (eigs + l2_reg * jnp.ones_like(eigs))
  scaling = scaling.reshape((1, -1))
  w = (lhs * scaling) @ rhs

  # Predict test-set values and measure their accuracy.
  preds = jnp.argmax(x_test @ w, axis=1)
  return jnp.mean(preds == y_test)


class Evaluator:
  """Class for few-shot evaluation."""

  def __init__(self, predict_fn, batch_size,
               datasets, shots, l2_reg,
               pp_train, pp_eval, display_first,
               representation_layer=None, num_seeds=3,
               label_key="label", mask_key="_mask", data_dir=None,
               *, devices):
    self.datasets = datasets
    self.shots = shots
    self.l2_reg = l2_reg
    self.batch_size = batch_size
    self.pp_tr = pp_train
    self.pp_te = pp_eval
    self.display_first = display_first
    self._datasets = {}  # Cache for tfds data. Persists while object is alive.
    self._repr = {}  # Cache for precomputed repr. Persists within the run call.
    self.num_seeds = num_seeds
    self.label_key = label_key
    self.mask_key = mask_key
    self.data_dir = data_dir

    self.devices = devices
    self.mesh = jax.sharding.Mesh(devices, ("devices",))

    self.repr_fn = self.get_representation_fn(
        predict_fn, representation_layer)

  def get_representation_fn(self, predict_fn, representation_layer):
    # `out_shardings=Sharding(self.mesh, P())` will "all_gather" the outputs.
    @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P()))
    def _repr_fn(train_state, batch, labels, mask):
      zimg, *_, out = predict_fn(train_state, batch)
      if representation_layer is not None:
        rep = u.tree_get(out, representation_layer)
      else:
        rep = zimg
      return rep, labels, mask
    return _repr_fn

  # Setup input pipeline.
  def _get_dataset(self, dataset, train_split, test_split):
    """Lazily loads the given dataset."""
    key = (dataset, train_split, test_split)
    try:
      return self._datasets[key]
    except KeyError:
      # NOTE: only supporting TFDS data for now, for bwd compat/laziness.
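      # Build deterministic (ordered) train/test pipelines once and cache the
      # resulting tuple under `key`, so repeated evaluations of the same
      # (dataset, splits) combination skip this setup.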
      train_data = ds_core.get(
          name=dataset, split=train_split, data_dir=self.data_dir)
      test_data = ds_core.get(
          name=dataset, split=test_split, data_dir=self.data_dir)
      train_ds, batches_tr = input_pipeline.make_for_inference(
          train_data.get_tfdata(ordered=True),
          num_ex_per_process=train_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr))
      test_ds, batches_te = input_pipeline.make_for_inference(
          test_data.get_tfdata(ordered=True),
          num_ex_per_process=test_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te))
      num_classes = train_data.builder.info.features[self.label_key].num_classes
      return self._datasets.setdefault(
          key, (train_ds, batches_tr, test_ds, batches_te, num_classes))

  def _get_repr(self, params, data, steps):
    """Compute representations for the whole dataset."""
    pre_logits_list = []
    labels_list = []
    for batch, _ in zip(
        input_pipeline.start_global(data, self.devices, 0), range(steps)):
      labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key)
      pre_logits, labels, mask = jax.device_get(self.repr_fn(
          params, batch, labels, mask))
      mask = mask.astype(bool)
      pre_logits_list.append(pre_logits[mask])
      labels_list.append(labels[mask])
    pre_logits = np.concatenate(pre_logits_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)

    return pre_logits, labels

  def compute_fewshot_metrics(self, train_state, seed,
                              dataset, train_split, test_split):
    """Compute few-shot metrics on one dataset."""
    if dataset in self._repr:
      repr_train, labels_train, repr_test, labels_test, num_classes = (
          self._repr[dataset])
    else:
      train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset(
          dataset, train_split, test_split)
      repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr)
      repr_test, labels_test = self._get_repr(train_state, test_ds, steps_te)
      self._repr[dataset] = (repr_train, labels_train,
                             repr_test, labels_test,
                             num_classes)

    # Collect where we have samples of which classes.
    rng = np.random.default_rng(seed)
    class_indices = [rng.permutation(np.where(labels_train == cls_i)[0])
                     for cls_i in range(num_classes)]

    results = {}
    for shots in self.shots:
      all_idx = [indices[:shots] for indices in class_indices]
      all_idx = np.concatenate(all_idx, axis=0)
      x = u.put_cpu(repr_train[all_idx])
      y = u.put_cpu(labels_train[all_idx])
      repr_test, labels_test = u.put_cpu((repr_test, labels_test))

      # Note that the code is optimized to solve multiple LSR tasks for a
      # changing l2 strength, even though we currently use the fixed l2_reg
      # constant.
      cache = _precompute_cache(x, y, num_classes)
      acc = _eig_fewshot_acc_fn(
          cache, repr_test, labels_test, u.put_cpu(self.l2_reg))
      results[shots] = jax.device_get(acc)

    return results

  def run(self, train_state):
    """New API executed in terms of old API."""
    self._repr = {}
    for seed in range(self.num_seeds):
      for name, dataset_args in self.datasets.items():
        result = self.compute_fewshot_metrics(train_state, seed, *dataset_args)
        for shots, v in result.items():
          prefix = "a/" if (name, shots) in self.display_first else "z/"
          suffix = f"-seed-{seed}"
          yield f"{prefix}{name}_{shots}shot{suffix}", v
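
# Minimal usage sketch (illustrative only; `predict_fn`, the pp strings, and
# the dataset/split names below are placeholder assumptions, not part of this
# module):
#
#   evaluator = Evaluator(
#       predict_fn=predict_fn,  # Returns (zimg, ..., out) for a batch.
#       batch_size=256,
#       datasets={"birds": ("caltech_birds2011", "train", "test")},
#       shots=(1, 5, 10, 25),
#       l2_reg=2.0 ** 10,
#       pp_train="decode|resize(256)|value_range(-1,1)",
#       pp_eval="decode|resize(256)|value_range(-1,1)",
#       display_first=[("birds", 5)],
#       devices=jax.devices(),
#   )
#   for metric_name, value in evaluator.run(train_state):
#     print(metric_name, value)  # e.g. "a/birds_5shot-seed-0".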