"""
General seq2seq / input-output trainer
"""
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from .default_lm import OurTrainer as DefaultTrainer
from .utils import replace_padding_tokens


def compute_scrolls_metrics(eval_preds, scrolls_metric, tokenizer):
"""
Function to compute metrics that are also in SCROLLS (ROUGE, F1, etc.)
"""
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
# Replace -100s used for padding as we can't decode them
preds = replace_padding_tokens(preds, tokenizer.pad_token_id)
labels = replace_padding_tokens(labels, tokenizer.pad_token_id)
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Scrolls metric expects predictions to be [pred_1, pred_2, ...]
# and references to be [[ref_1], [ref_2], ... ]
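    # e.g. (hypothetical values): predictions=['Paris'], references=[['Paris']]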
decoded_labels = [[s] for s in decoded_labels]
result = scrolls_metric.compute(predictions=decoded_preds,
references=decoded_labels)
print('----------------')
print('Model generation')
print(decoded_preds[:10])
print('----------------')
print('True answer')
print(decoded_labels[:10])
return result


class OurTrainer(DefaultTrainer):
"""
Evaluator for seq-to-seq / generation benchmarks
"""
    def __init__(self, model, args,  # max_eval_batches: Optional[int] = 100,
                 **kwargs: Any):
        super().__init__(model=model, args=args, **kwargs)
        # Reset + determine metric for best automatically based on the dataset
        self.metric_for_best = None
        self.is_better = lambda x, y: x > y  # Hardcode greater is better for now
        self.print_steps = getattr(args, 'print_steps', 100)
        print(f'self.print_steps: {self.print_steps}')
        self.max_eval_batches = 10  # hard-coded for the ablation sweep

    def init_criterion_(self):
        """No criterion needed; this trainer only evaluates with generation metrics"""
        pass

    def compute_loss(self):
        pass

    def evaluate(self, *args: Any, **kwargs: Any):
        return self.eval_step(*args, **kwargs)

    def eval_step(self, model: nn.Module, step: int,
                  dataloader: Optional[DataLoader] = None,
                  max_batches: Optional[int] = None,
                  prefix: Optional[str] = None,
                  **kwargs: Any):
"""
One evaluation step
"""
total = 0
total_loss = 0
metrics = {}
        max_batches = self.max_eval_batches if max_batches is None else max_batches
        max_batches = 10  # hard-coded for the ablation sweep; overrides the value above
dataloader = (dataloader if dataloader is not None else self.eval_loader)
scrolls_metric = dataloader.dataset.metric # Should be assigned in dataset
tokenizer = dataloader.dataset.tokenizer
# Save decoded predictions and references here to compute average metrics
predictions, references = [], []
model.eval()
pbar = tqdm(dataloader, leave=False, colour='green',
desc=f'Evaluating at step {step}')
with torch.no_grad():
for ix, data in enumerate(pbar):
inputs = {k: v.to(self.device) for k, v in data.items()
if k in ['input_ids', 'attention_mask']}
labels = data['labels']
outputs = model.generate(**inputs,
max_new_tokens=1024, # hardcoded for now
pad_token_id=tokenizer.pad_token_id,
use_cache=True,).cpu()
# Only save newly generated tokens
pred_ids = outputs[:, data['input_ids'].shape[1]:]
predictions.append(pred_ids)
references.append(labels)
pbar.set_description(f"Evaluating at step {step} | input_len: {data['input_ids'].shape[1]} | output_len: {labels.shape[1]}")
                if ix + 1 >= max_batches:  # cap the number of eval batches
                    break
                if (ix + 1) % self.print_steps == 0:
                    print('Model input:\n', tokenizer.batch_decode(inputs['input_ids'].detach().cpu())[0])
                    print('Model output:\n', tokenizer.batch_decode(pred_ids)[0])
                    print('True output:\n', tokenizer.batch_decode(labels)[0])
# Compute and save metrics
        try:
            predictions = torch.cat(predictions, dim=0)
            references = torch.cat(references, dim=0)
        except Exception:  # keep per-batch lists if sequence lengths differ and cat fails
            pass
_metric = compute_scrolls_metrics((predictions, references),
scrolls_metric, tokenizer)
        if self.metric_for_best is None:  # Hard-coded for now
            if 'f1' in _metric:
                self.metric_for_best = 'eval/f1'
            elif 'exact_match' in _metric:
                self.metric_for_best = 'eval/exact_match'
            elif 'rouge/geometric_mean' in _metric:
                self.metric_for_best = 'eval/rouge/geometric_mean'
for k, v in _metric.items():
if 'display' not in k:
_k = f'{prefix}/eval/{k}' if prefix is not None else f'eval/{k}'
metrics[_k] = v
return metrics
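

# Minimal usage sketch (hypothetical wiring; the actual dataloaders and constructor
# arguments come from the lolcats training scripts and .default_lm.OurTrainer,
# which this class inherits from):
#
#   trainer = OurTrainer(model=model, args=args, eval_loader=eval_loader, ...)
#   metrics = trainer.evaluate(model, step=0)
#   print(metrics)  # e.g. {'eval/f1': ..., 'eval/rouge/geometric_mean': ...}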