import evaluate
import datasets
from typing import Union, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tqdm import tqdm
_DESCRIPTION = """
Perplexity metric implemented by d-Matrix.
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
For more information, see https://huggingface.co/docs/transformers/perplexity
"""
_KWARGS_DESCRIPTION = """
Args:
    model (Union[str, AutoModelForCausalLM]): model used for calculating Perplexity
            NOTE: Perplexity can only be calculated for causal language models.
                    This includes models such as gpt2, causal variations of bert,
                    causal versions of t5, and more (the full list can be found
                    in the AutoModelForCausalLM documentation here:
                    https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
    references (list of str): input text, where each separate text snippet is one list entry.
    device (str): device to run on; defaults to 'cuda' when available.
    max_length (int): maximum sequence length; defaults to the model's configured
            context size (max_position_embeddings or n_positions), or 2048 if neither is set.
Returns:
    perplexity: dictionary containing the perplexity score and loss.
Examples:
    >>> from datasets import load_dataset
    >>> perplexity = evaluate.load("dmx_perplexity", module_type="metric")
    >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10]  # doctest: +SKIP
    >>> results = perplexity.compute(model='distilgpt2',
    ...                              references=input_texts)
    >>> print(list(results.keys()))
    ['loss', 'perplexity']
    >>> print(results['loss'])  # doctest: +SKIP
    3.8299286365509033
    >>> print(results['perplexity'])  # doctest: +SKIP
    46.05925369262695
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DmxPerplexity(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )
    def _compute(
        self,
        references,
        model: Union[str, AutoModelForCausalLM],
        device=None,
        max_length=None,
        **kwargs,
    ):
        if device is not None:
            assert device in [
                "gpu",
                "cpu",
                "cuda",
            ], "device should be either cpu, cuda, or gpu."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(model, str):
            tokenizer = AutoTokenizer.from_pretrained(model)
            model = AutoModelForCausalLM.from_pretrained(model)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path, **kwargs)

        # Pick the window size: an explicit `max_length` wins, otherwise fall back
        # to the model's configured context size, and finally to 2048.
        if max_length:
            max_seq_len = max_length
        elif hasattr(model.config, "max_position_embeddings"):
            max_seq_len = model.config.max_position_embeddings
        elif hasattr(model.config, "n_positions"):
            max_seq_len = model.config.n_positions
        else:
            max_seq_len = 2048

        # Only move the model if it is not already dispatched across devices.
        if not hasattr(model, "hf_device_map") and (
            not hasattr(model, "model_parallel") or not model.model_parallel
        ):
            model = model.to(device)
        model.eval()
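        # Score the concatenated references with non-overlapping windows of
        # `max_seq_len` tokens (stride == window size), following
        # https://huggingface.co/docs/transformers/perplexity. The sequence is
        # truncated to a whole number of windows, so a reference set shorter
        # than one full window contributes no scored tokens.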
        encodings = tokenizer("\n\n".join(references), return_tensors="pt")
        stride = max_seq_len
        seq_len = encodings.input_ids.size(1)
        seq_len = (seq_len // stride) * stride

        nlls = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            end_loc = min(begin_loc + max_seq_len, seq_len)
            trg_len = end_loc - prev_end_loc  # number of tokens scored in this window
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
            target_ids = input_ids.clone()
            # Positions outside the target span are set to -100 and ignored by the loss.
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
                # Rescale the per-token loss by trg_len to accumulate a sum of
                # negative log-likelihoods over all scored tokens.
                if isinstance(outputs, Dict):
                    neg_log_likelihood = outputs["loss"] * trg_len
                else:
                    neg_log_likelihood = outputs.loss * trg_len

            nlls.append(neg_log_likelihood.to(device))
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        # Mean negative log-likelihood over all scored tokens, exponentiated
        # (base e) to obtain perplexity.
        loss = torch.stack(nlls).float().sum() / end_loc
        ppl = torch.exp(loss)

        return dict(
            loss=loss.item(),
            perplexity=ppl.item(),
        )
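

# Minimal smoke test: a sketch only, not part of the metric itself; it assumes
# network access to download `distilgpt2` from the Hugging Face Hub. The tiny
# `max_length` keeps the short example text long enough to fill at least one
# full scoring window.
if __name__ == "__main__":
    metric = DmxPerplexity()
    texts = [
        "Perplexity is the exponentiated average negative log-likelihood.",
        "Lower perplexity indicates a better fit to the reference text.",
    ]
    results = metric.compute(references=texts, model="distilgpt2", max_length=8)
    print(results)  # e.g. {'loss': ..., 'perplexity': ...}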