|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utils for few-shot evaluation.""" |
|
|
|
|
|
import functools |
|
|
|
import big_vision.datasets.core as ds_core |
|
import big_vision.input_pipeline as input_pipeline |
|
import big_vision.pp.builder as pp_builder |
|
import big_vision.utils as u |
|
import jax |
|
import jax.numpy as jnp |
|
from jax.sharding import NamedSharding as Sharding |
|
from jax.sharding import PartitionSpec as P |
|
import numpy as np |
|
|
|
# Value of the extra constant feature column appended to the inputs. A large
# constant means the corresponding learned weight (which acts as the bias
# term) is effectively unregularized by the L2 penalty — presumably the
# intent, since the weight is scaled down by 1/BIAS_CONSTANT.
BIAS_CONSTANT = 100.0


# Declares which big_vision evaluator API this module implements.
API = "jit"
|
|
|
|
|
|
|
@u.jit_cpu(static_argnums=(2,))
def _precompute_cache(x, y, num_classes):
  """Cache quantities to speed-up the computation of L2-regularized least-sq."""
  # Whiten the features: zero mean, (approximately) unit variance per dim.
  # The epsilon guards against division by zero for constant features.
  feat_mean = jnp.mean(x, axis=0, keepdims=True)
  feat_std = jnp.std(x, axis=0, keepdims=True) + 1e-5
  feats = (x - feat_mean) / feat_std

  # Append a constant column so a bias term is learned; the large constant
  # keeps the bias weight small, so L2 barely regularizes it.
  feats = jnp.pad(feats, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  # One-vs-all regression targets in {-1, +1}.
  targets = 2.0 * jax.nn.one_hot(y, num_classes) - 1.0

  n, d = feats.shape
  # Eigendecompose whichever Gram matrix is smaller: (d, d) when there are
  # at least as many points as dimensions, otherwise (n, n). Both branches
  # produce (eigs, rhs, lhs) such that the ridge solution for any l2_reg is
  # w = lhs @ diag(1 / (eigs + l2_reg)) @ rhs.
  if n >= d:
    eigs, q = jnp.linalg.eigh(feats.T @ feats)
    rhs = q.T @ (feats.T @ targets)
    lhs = q
  else:
    eigs, q = jnp.linalg.eigh(feats @ feats.T)
    rhs = q.T @ targets
    lhs = feats.T @ q

  return {
      "eigs": eigs,
      "rhs": rhs,
      "lhs": lhs,
      "mean": feat_mean,
      "std": feat_std,
  }
|
|
|
|
|
@u.jit_cpu()
def _eig_fewshot_acc_fn(cache, x_test, y_test, l2_reg):
  """Computes (x,y) linear regression accuracy on (x_test, y_test)."""
  # Apply the exact same whitening + bias column as _precompute_cache did
  # on the training data.
  feats = (x_test - cache["mean"]) / cache["std"]
  feats = jnp.pad(feats, ((0, 0), (0, 1)), constant_values=BIAS_CONSTANT)

  # Assemble the ridge-regression weights from the cached eigendecomposition:
  # each eigen-direction is shrunk by 1 / (eig + l2_reg).
  shrink = 1.0 / (cache["eigs"] + l2_reg * jnp.ones_like(cache["eigs"]))
  w = (cache["lhs"] * shrink.reshape((1, -1))) @ cache["rhs"]

  # Predicted class = argmax over the one-vs-all regression scores.
  predictions = jnp.argmax(feats @ w, axis=1)
  return jnp.mean(predictions == y_test)
|
|
|
|
|
class Evaluator:
  """Few-shot linear-probe evaluator.

  For each configured dataset, extracts representations with the (frozen)
  model, fits an L2-regularized least-squares one-vs-all classifier on
  `shots` examples per class, and reports test-set accuracy, repeated for
  `num_seeds` random subsamplings of the training examples.
  """

  def __init__(self, predict_fn, batch_size,
               datasets, shots, l2_reg,
               pp_train, pp_eval, display_first,
               representation_layer=None, num_seeds=3,
               label_key="label", mask_key="_mask", data_dir=None, *,
               devices):
    """Creates the evaluator.

    Args:
      predict_fn: model forward pass, called as `predict_fn(train_state,
        batch)`; expected to return `(zimg, ..., out)` where `out` is a
        pytree of intermediate outputs.
      batch_size: global batch size for representation extraction.
      datasets: dict mapping a display name to a `(dataset, train_split,
        test_split)` tuple.
      shots: iterable of shot counts (training examples per class).
      l2_reg: L2 regularization strength of the least-squares classifier.
      pp_train: preprocessing spec string for the few-shot train split.
      pp_eval: preprocessing spec string for the few-shot test split.
      display_first: collection of `(name, shots)` pairs reported under the
        "a/" metric prefix; all other results go under "z/".
      representation_layer: optional key into the model's `out` pytree
        selecting the representation; if None, `zimg` is used.
      num_seeds: number of random subsamplings per (dataset, shots).
      label_key: feature name of the class label in the dataset.
      mask_key: batch key marking padded (invalid) examples.
      data_dir: optional data directory forwarded to the dataset builder.
      devices: jax devices to use for representation extraction.
    """
    self.datasets = datasets
    self.shots = shots
    self.l2_reg = l2_reg
    self.batch_size = batch_size
    self.pp_tr = pp_train
    self.pp_te = pp_eval
    self.display_first = display_first
    self._datasets = {}  # Lazy dataset cache, see _get_dataset.
    self._repr = {}  # Representation cache, cleared at every run().
    self.num_seeds = num_seeds
    self.label_key = label_key
    self.mask_key = mask_key
    self.data_dir = data_dir
    self.devices = devices
    self.mesh = jax.sharding.Mesh(devices, ("devices",))
    self.repr_fn = self.get_representation_fn(
        predict_fn, representation_layer)

  def get_representation_fn(self, predict_fn, representation_layer):
    """Builds the jitted representation-extraction function.

    Args:
      predict_fn: see __init__.
      representation_layer: see __init__.

    Returns:
      A jitted `(train_state, batch, labels, mask) -> (rep, labels, mask)`
      function whose outputs are fully replicated (P()) across the mesh.
    """

    @functools.partial(jax.jit, out_shardings=Sharding(self.mesh, P()))
    def _repr_fn(train_state, batch, labels, mask):
      zimg, *_, out = predict_fn(train_state, batch)
      if representation_layer is not None:
        rep = u.tree_get(out, representation_layer)
      else:
        rep = zimg
      return rep, labels, mask
    return _repr_fn

  def _get_dataset(self, dataset, train_split, test_split):
    """Lazy-loads given dataset.

    Returns:
      `(train_ds, batches_tr, test_ds, batches_te, num_classes)`, cached
      per `(dataset, train_split, test_split)` key.
    """
    key = (dataset, train_split, test_split)
    try:
      return self._datasets[key]
    except KeyError:
      train_data = ds_core.get(
          name=dataset, split=train_split, data_dir=self.data_dir
      )
      test_data = ds_core.get(
          name=dataset, split=test_split, data_dir=self.data_dir
      )
      train_ds, batches_tr = input_pipeline.make_for_inference(
          train_data.get_tfdata(ordered=True),
          num_ex_per_process=train_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_tr))
      test_ds, batches_te = input_pipeline.make_for_inference(
          test_data.get_tfdata(ordered=True),
          num_ex_per_process=test_data.num_examples_per_process(),
          batch_size=self.batch_size,
          preprocess_fn=pp_builder.get_preprocess_fn(self.pp_te))

      num_classes = train_data.builder.info.features[self.label_key].num_classes
      return self._datasets.setdefault(
          key, (train_ds, batches_tr, test_ds, batches_te, num_classes))

  def _get_repr(self, params, data, steps):
    """Compute representation for the whole dataset.

    Iterates `steps` batches of `data`, extracts representations, and drops
    padded examples (mask == 0) before concatenating on the host.

    Returns:
      `(pre_logits, labels)` numpy arrays covering all valid examples.
    """
    pre_logits_list = []
    labels_list = []
    for batch, _ in zip(
        input_pipeline.start_global(data, self.devices, 0), range(steps)):
      labels, mask = batch.pop(self.label_key), batch.pop(self.mask_key)
      pre_logits, labels, mask = jax.device_get(self.repr_fn(
          params, batch, labels, mask))
      mask = mask.astype(bool)
      pre_logits_list.append(pre_logits[mask])
      labels_list.append(labels[mask])
    pre_logits = np.concatenate(pre_logits_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)

    return pre_logits, labels

  def compute_fewshot_metrics(self, train_state, seed,
                              dataset, train_split, test_split):
    """Compute few-shot metrics on one dataset.

    Returns:
      Dict mapping each shot count in `self.shots` to test accuracy.

    NOTE(review): the representation cache is keyed by `dataset` only, not
    by the splits — presumably no config evaluates the same dataset with
    two different splits in one run; confirm before relying on that.
    """
    if dataset in self._repr:
      repr_train, labels_train, repr_test, labels_test, num_classes = (
          self._repr[dataset])
    else:
      train_ds, steps_tr, test_ds, steps_te, num_classes = self._get_dataset(
          dataset, train_split, test_split)
      repr_train, labels_train = self._get_repr(train_state, train_ds, steps_tr)
      repr_test, labels_test = self._get_repr(train_state, test_ds, steps_te)
      self._repr[dataset] = (repr_train, labels_train,
                             repr_test, labels_test,
                             num_classes)

    # Deterministic per-seed shuffle of each class's example indices; the
    # first `shots` of each permutation form that class's training subset.
    rng = np.random.default_rng(seed)
    class_indices = [rng.permutation(np.where(labels_train == cls_i)[0])
                     for cls_i in range(num_classes)]

    # The test data does not depend on the shot count, so transfer it to
    # CPU once instead of re-putting it on every loop iteration.
    repr_test, labels_test = u.put_cpu((repr_test, labels_test))

    results = {}
    for shots in self.shots:
      all_idx = [indices[:shots] for indices in class_indices]
      all_idx = np.concatenate(all_idx, axis=0)
      x = u.put_cpu(repr_train[all_idx])
      y = u.put_cpu(labels_train[all_idx])

      cache = _precompute_cache(x, y, num_classes)
      acc = _eig_fewshot_acc_fn(
          cache, repr_test, labels_test, u.put_cpu(self.l2_reg))
      results[shots] = jax.device_get(acc)

    return results

  def run(self, train_state):
    """New API executed in terms of old API."""
    # Representations depend on the current train_state; invalidate cache.
    self._repr = {}
    for seed in range(self.num_seeds):
      for name, dataset_args in self.datasets.items():
        result = self.compute_fewshot_metrics(train_state, seed, *dataset_args)
        for shots, v in result.items():
          prefix = "a/" if (name, shots) in self.display_first else "z/"
          suffix = f"-seed-{seed}"
          yield f"{prefix}{name}_{shots}shot{suffix}", v
|
|