"""
General seq2seq / input-output trainer
"""

from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from .default_lm import OurTrainer as DefaultTrainer
from .utils import replace_padding_tokens


def compute_scrolls_metrics(eval_preds, scrolls_metric, tokenizer):
"""
Function to compute metrics that are also in SCROLLS (ROUGE, F1, etc.)
"""
preds, labels = eval_preds
if isinstance(preds, tuple):
preds = preds[0]
# Replace -100s used for padding as we can't decode them
preds = replace_padding_tokens(preds, tokenizer.pad_token_id)
labels = replace_padding_tokens(labels, tokenizer.pad_token_id)
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# Scrolls metric expects predictions to be [pred_1, pred_2, ...]
# and references to be [[ref_1], [ref_2], ... ]
decoded_labels = [[s] for s in decoded_labels]
result = scrolls_metric.compute(predictions=decoded_preds,
references=decoded_labels)
print('----------------')
print('Model generation')
print(decoded_preds[:10])
print('----------------')
print('True answer')
print(decoded_labels[:10])
return result
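

# Illustrative call (a hedged sketch: the token ids below are made-up, and
# `scrolls_metric` is assumed to be a HuggingFace `evaluate`-style object with
# a `.compute(predictions=..., references=...)` method, as used above):
#
#     preds = np.array([[101, 2023, 102]])   # generated token ids
#     labels = np.array([[101, 2023, 102]])  # reference token ids
#     result = compute_scrolls_metrics((preds, labels), scrolls_metric, tokenizer)
#     # -> e.g. {'f1': ..., 'rouge/geometric_mean': ...}, depending on the dataset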


class OurTrainer(DefaultTrainer):
"""
Evaluator for seq-to-seq / generation benchmarks
"""
    def __init__(self, model, args, max_eval_batches: Optional[int] = 100,
                 **kwargs: Any):
        super().__init__(model=model, args=args, **kwargs)
        # Reset + determine metric for best automatically based on the dataset
        self.metric_for_best = None
        self.is_better = lambda x, y: x > y  # Hardcode greater is better for now
        self.print_steps = getattr(args, 'print_steps', 100)
        print(f'self.print_steps: {self.print_steps}')
        self.max_eval_batches = max_eval_batches

    def init_criterion_(self):
        """Generation-based evaluation computes no training criterion"""
        pass

    def compute_loss(self):
        """Loss is not used for generation-based evaluation"""
        pass

    def evaluate(self, *args: Any, **kwargs: Any):
        return self.eval_step(*args, **kwargs)

    def eval_step(self, model: nn.Module, step: int,
                  dataloader: Optional[DataLoader] = None,
                  max_batches: Optional[int] = None,
                  prefix: Optional[str] = None,
                  **kwargs: Any):
        """
        One evaluation pass: generate from the model and score the outputs
        with the dataset's SCROLLS-style metric
        """
        metrics = {}
        max_batches = self.max_eval_batches if max_batches is None else max_batches
dataloader = (dataloader if dataloader is not None else self.eval_loader)
scrolls_metric = dataloader.dataset.metric # Should be assigned in dataset
tokenizer = dataloader.dataset.tokenizer
# Save decoded predictions and references here to compute average metrics
predictions, references = [], []
model.eval()
pbar = tqdm(dataloader, leave=False, colour='green',
desc=f'Evaluating at step {step}')
with torch.no_grad():
for ix, data in enumerate(pbar):
inputs = {k: v.to(self.device) for k, v in data.items()
if k in ['input_ids', 'attention_mask']}
labels = data['labels']
                outputs = model.generate(**inputs,
                                         max_new_tokens=1024,  # hardcoded for now
                                         pad_token_id=tokenizer.pad_token_id,
                                         use_cache=True).cpu()
# Only save newly generated tokens
pred_ids = outputs[:, data['input_ids'].shape[1]:]
predictions.append(pred_ids)
references.append(labels)
                pbar.set_description(
                    f"Evaluating at step {step} | "
                    f"input_len: {data['input_ids'].shape[1]} | "
                    f"output_len: {labels.shape[1]}")
                if (ix + 1) % self.print_steps == 0:
                    print('Model input:\n', tokenizer.batch_decode(inputs['input_ids'].detach().cpu())[0])
                    print('Model output:\n', tokenizer.batch_decode(pred_ids)[0])
                    print('True output:\n', tokenizer.batch_decode(labels)[0])
                if (ix + 1) == max_batches:
                    break
# Compute and save metrics
        try:
            predictions = torch.cat(predictions, dim=0)
            references = torch.cat(references, dim=0)
        except RuntimeError:
            # Sequence lengths differ across batches; keep the per-batch lists
            pass
_metric = compute_scrolls_metrics((predictions, references),
scrolls_metric, tokenizer)
        if self.metric_for_best is None:  # Hard-coded for now
            if 'f1' in _metric:
                self.metric_for_best = 'eval/f1'
            elif 'exact_match' in _metric:
                self.metric_for_best = 'eval/exact_match'
            elif 'rouge/geometric_mean' in _metric:
                self.metric_for_best = 'eval/rouge/geometric_mean'
for k, v in _metric.items():
if 'display' not in k:
_k = f'{prefix}/eval/{k}' if prefix is not None else f'eval/{k}'
metrics[_k] = v
return metrics
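

# A minimal usage sketch (hypothetical names: `model`, `args`, and the eval
# dataloader come from the surrounding pipeline; the base trainer is assumed
# to accept an `eval_loader` keyword, as `self.eval_loader` above suggests,
# and the eval dataset must expose `.metric` and `.tokenizer`):
#
#     trainer = OurTrainer(model=model, args=args, eval_loader=eval_loader)
#     metrics = trainer.evaluate(model, step=0, max_batches=10, prefix='val')
#     print(metrics)  # e.g. {'val/eval/f1': ..., 'val/eval/rouge/geometric_mean': ...}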