# StoryLlama — metric.py
# Perplexity evaluation utility for the StoryLlama model.
# Author: YuvrajSingh9886 (uploaded to the Hugging Face Hub)
import evaluate

from config import ModelArgs
from model import Llama

# Load the Hub "perplexity" metric once at import time so repeated
# compute_perplexity() calls reuse the same metric object instead of
# re-downloading/re-initialising it per call.
perplexity = evaluate.load("perplexity")
def compute_perplexity(model_name, text):
    """Score a single string with the Hub perplexity metric.

    Args:
        model_name: identifier forwarded as ``model_id`` to
            ``perplexity.compute`` (the metric resolves it to a model).
        text: the one string whose perplexity is computed.

    Returns:
        The perplexity of ``text`` — the first (and only) entry of the
        metric's ``perplexities`` list.
    """
    scores = perplexity.compute(model_id=model_name, predictions=[text])
    return scores["perplexities"][0]
# Example Usage
# Build the local StoryLlama model from the hyperparameters declared on
# ModelArgs, then move it to the configured device.
llama = Llama(device=ModelArgs.device, embeddings_dims=ModelArgs.embeddings_dims, no_of_decoder_layers=ModelArgs.no_of_decoder_layers, block_size=ModelArgs.block_size, vocab_size=ModelArgs.vocab_size, dropout=ModelArgs.dropout)
llama = llama.to(ModelArgs.device)
text = "This is an example sentence for perplexity calculation."
# NOTE(review): compute_perplexity forwards its first argument as
# ``model_id`` to evaluate's perplexity metric, which expects a Hugging
# Face Hub model identifier (a string) — not an instantiated model
# object. Passing ``llama`` here most likely fails at runtime. Confirm
# against the evaluate metric docs; either pass a Hub model-id string or
# compute perplexity directly from this model's own loss.
ppl = compute_perplexity(llama, text)
print(f"Perplexity: {ppl}")