# NOTE(review): removed paste artifacts ("Spaces:" / "Runtime error" lines)
# left over from extraction; they were not part of the program.
import logging
import torch
from accelerate import Accelerator
from arguments import EvaluationArguments
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
class ConstantLengthDataset(IterableDataset):
    """Iterable dataset that packs raw text examples into fixed-length token chunks.

    Examples are concatenated, separated by the tokenizer's BOS token, and sliced
    into windows of exactly ``seq_length`` token ids; a trailing remainder shorter
    than ``seq_length`` is dropped. Yields 1-D ``torch.LongTensor`` items.
    """

    def __init__(self, tokenizer, dataset, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6):
        # tokenizer: must expose ``bos_token_id`` and be callable on a list of strings,
        #   returning a mapping with an "input_ids" entry (HF tokenizer contract).
        # dataset: iterable of dicts with a "content" field holding raw text.
        # chars_per_token: rough chars-per-token estimate used to size the text buffer.
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        # Characters to accumulate before each tokenization pass — enough for
        # roughly ``num_of_sequences`` output chunks per pass.
        self.input_characters = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            # Fill a character buffer up to the estimated budget.
            while True:
                if buffer_len >= self.input_characters:
                    break
                try:
                    buffer.append(next(iterator)["content"])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    # Source exhausted: tokenize what we have, then stop.
                    more_examples = False
                    break
            # BUG FIX: use the tokenizer passed to the constructor rather than
            # an implicit module-level ``tokenizer`` global.
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            # Emit only full-length windows; the final partial window is discarded.
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)
def create_dataloader(args):
    """Build the evaluation DataLoader from a streamed HF dataset.

    Reads the module-level ``tokenizer``; ``args`` must provide
    ``dataset_name``, ``seq_length`` and ``batch_size``.
    NOTE(review): the dataset is loaded with split="train" even though it is
    used for evaluation — presumably ``dataset_name`` points at a dedicated
    validation corpus; confirm against the configuration.
    """
    streamed = load_dataset(args.dataset_name, split="train", streaming=True)
    packed = ConstantLengthDataset(tokenizer, streamed, seq_length=args.seq_length)
    return DataLoader(packed, batch_size=args.batch_size)
def evaluate(args):
    """Run evaluation over the prepared dataloader.

    Reads the module-level ``model``, ``eval_dataloader`` and ``accelerator``.
    Returns ``(mean_loss, perplexity)`` as Python floats; perplexity is
    ``float("inf")`` on overflow.
    """
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            # Causal LM: the inputs double as labels for next-token prediction.
            outputs = model(batch, labels=batch)
        # Repeat the scalar loss so gather() weights each process by batch size.
        loss = outputs.loss.repeat(args.batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))
    # BUG FIX: torch.exp on a tensor never raises OverflowError (it returns inf),
    # and the old fallback bound ``perplexity`` to a plain float, so the final
    # ``perplexity.item()`` would raise AttributeError if it ever ran. Convert
    # to a Python float inside the try and return the fallback float directly.
    try:
        perplexity = torch.exp(loss).item()
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity
# --- Script body: evaluate a causal-LM checkpoint on a streamed dataset ---

# Setup Accelerator (device placement and cross-process gather for evaluate()).
accelerator = Accelerator()

# Parse configuration.
# NOTE(review): HfArgumentParser subclasses argparse.ArgumentParser, so
# parse_args() returns a Namespace whose attributes mirror the
# EvaluationArguments fields (model_ckpt, seed, dataset_name, ...).
parser = HfArgumentParser(EvaluationArguments)
args = parser.parse_args()
set_seed(args.seed)

# Logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)

# Load model and tokenizer from the same checkpoint. The module-level
# ``tokenizer`` must be bound before create_dataloader() runs, because that
# function reads it as a global.
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)

# Load dataset and dataloader
eval_dataloader = create_dataloader(args)

# Prepare everything with our `accelerator`.
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)

# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
logger.info(f"loss/eval: {eval_loss}, perplexity: {perplexity}")