import evaluate
from datasets import load_dataset
from transformers import AutoModelForCausalLM

# Load the d-Matrix perplexity metric and a small sample of the WikiText-2 test split.
perplexity = evaluate.load("d-matrix/dmx_perplexity", module_type="metric")
input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10]

# Load the causal LM to evaluate; device_map="auto" places weights on available devices.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="d-matrix/gpt-j-6b",
    # trust_remote_code=True,  # uncomment if the checkpoint ships custom modeling code
    device_map="auto",
    use_auth_token=True,  # authenticate for gated or private checkpoints
)
results = perplexity.compute(model=model, references=input_texts)
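
The returned `results` can then be inspected directly; a minimal sketch, assuming only that `compute` returns a plain dictionary of metric values (the exact key names depend on the dmx_perplexity implementation):

# Print every metric returned by compute(); key names (e.g. loss, perplexity)
# are determined by the metric implementation, so iterate generically.
for name, value in results.items():
    print(f"{name}: {value}")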