|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for perplexity of a model.""" |
|
import functools |
|
|
|
from big_vision.evaluators import mean |
|
import big_vision.utils as u |
|
import jax.numpy as jnp |
|
|
|
|
|
|
|
|
|
# Module-level API marker. Nothing in this file reads it; presumably it is
# inspected by whatever code loads this evaluator module (TODO confirm).
API = 'jit'
|
|
|
|
|
|
|
# Memoized: calling this again with the same (predict_fn, key, shift_labels)
# returns the very same function object rather than a fresh closure.
# NOTE(review): presumably this lets a consumer that compiles/caches per
# function object reuse its work -- confirm against the mean evaluator.
@functools.cache
def perplexity(
    predict_fn, key='labels', shift_labels=True):
    """Returns a function that computes perplexity.

    Args:
      predict_fn: Callable invoked as `predict_fn(train_state, batch, **kw)`.
        It must return a 2-tuple whose first element is the logits; the
        second element is ignored. Must be hashable, because it is part of
        the memoization key.
      key: Name of the `batch` entry that holds the target labels.
      shift_labels: If True, position 0 along axis 1 is dropped from both the
        labels and the weights before the loss is computed.

    Returns:
      A function `_perplexity_fn(train_state, batch, **kw)` returning a dict:
        'sum': the output of `u.weighted_softmax_xent` called with
          `reduction=False, normalize=False`.
        'avg': 'sum' divided by the weights summed over axis 1.
    """

    def _perplexity_fn(train_state, batch, **kw):
        logits, _ = predict_fn(train_state, batch, **kw)

        labels = batch[key]
        # 'mask_loss' is optional; without it every label position gets
        # weight 1.
        weights = batch.get('mask_loss', jnp.ones_like(labels))

        if shift_labels:
            # Indexing assumes labels/weights have at least two axes
            # (presumably (batch, sequence) -- TODO confirm).
            labels = labels[:, 1:]
            weights = weights[:, 1:]

        losses = u.weighted_softmax_xent(
            logits=logits, labels=labels, weights=weights,
            reduction=False, normalize=False)
        # Floor the denominator at a tiny positive value so that a row whose
        # weights sum to zero does not cause a division by zero.
        normalizer = jnp.clip(weights.sum(axis=1), 2e-38)

        return {'sum': losses, 'avg': losses / normalizer}

    return _perplexity_fn
|
|
|
|
|
class Evaluator(mean.Evaluator):
    """Perplexity evaluator.

    A `mean.Evaluator` whose first constructor argument is the function
    built by `perplexity(predict_fn, key, shift_labels)`.
    """

    def __init__(self, predict_fn, *a, key='labels', shift_labels=False, **kw):
        """Builds the perplexity function and hands it to the base class.

        Args:
          predict_fn: Forwarded to `perplexity`.
          *a: Extra positional arguments, forwarded unchanged to
            `mean.Evaluator.__init__`.
          key: Forwarded to `perplexity` as the name of the labels entry.
          shift_labels: Forwarded to `perplexity`. Note the default here is
            False, whereas `perplexity` itself defaults to True.
          **kw: Extra keyword arguments, forwarded to
            `mean.Evaluator.__init__`.
        """
        # Only takes effect when the caller did not pass `prefetch`.
        kw.setdefault('prefetch', 0)
        super().__init__(perplexity(predict_fn, key, shift_labels), *a, **kw)
|
|