|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for perplexity of a model.""" |
|
from big_vision.evaluators import mean |
|
import big_vision.utils as u |
|
import jax.numpy as jnp |
|
|
|
|
|
|
|
|
|
# NOTE(review): presumably tells the evaluation harness to drive this
# evaluator through the jit-based API (vs. an alternative like pmap) —
# confirm against the code that dispatches on evaluator `API`.
API = 'jit'
|
|
|
|
|
def perplexity(predict_fn, normalize_by_seqlen):
  """Returns a function that computes perplexity.

  Args:
    predict_fn: callable taking (train_state, batch, **kw) and returning a
      (logits, aux) pair; only the logits are used here.
    normalize_by_seqlen: forwarded to the xent loss; when true, each
      example's loss is normalized by its (weighted) sequence length.

  Returns:
    A function mapping (train_state, batch) to {'perplexity': per-example
    losses}, suitable for the mean evaluator.
  """

  def _perplexity_fn(train_state, batch, pad_token=0, **kw):
    logits, _ = predict_fn(train_state, batch, **kw)

    # Only non-pad label positions contribute to the loss.
    valid = (batch['labels'] != pad_token).astype(jnp.float32)
    extra_mask = batch.get('label_masks')
    if extra_mask is not None:
      # Optional per-position mask supplied by the data pipeline.
      valid = valid * extra_mask

    xent = u.weighted_softmax_xent(
        logits=logits,
        labels=batch['labels'],
        weights=valid,
        label_smoothing=0.0,
        reduction=False,
        normalize=normalize_by_seqlen)

    return {'perplexity': xent}
  return _perplexity_fn
|
|
|
|
|
class Evaluator(mean.Evaluator):
  """Perplexity evaluator.

  Thin wrapper that plugs the `perplexity` metric function into the
  generic mean evaluator, which handles batching and averaging.
  """

  def __init__(self, predict_fn, *a, normalize_by_seqlen=False, **kw):
    metric_fn = perplexity(predict_fn, normalize_by_seqlen)
    super().__init__(metric_fn, *a, **kw)
|
|