"""
Default trainer class for training models
"""
from collections import OrderedDict
from os.path import join
from argparse import ArgumentParser
from typing import Any

from tqdm import tqdm
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from .optim import get_optimizer, get_scheduler
from .utils import decode_samples


class OurTrainer():
    """
    Basic parent trainer class. Defaults to language modeling.
    -> Replacement for Hugging Face Trainer
    """
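    # Example usage (a sketch; assumes `args`, the data loaders, an (optimizer,
    # scheduler) pair, and a wandb run are constructed elsewhere):
    #
    #   trainer = OurTrainer(model=model, args=args,
    #                        train_loader=train_loader, eval_loader=eval_loader,
    #                        optimizer_and_scheduler=(optimizer, scheduler),
    #                        device=torch.device('cuda'), wandb=wandb)
    #   model = trainer.train()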
    def __init__(self,
                 model: nn.Module,
                 args: ArgumentParser,
                 train_loader: DataLoader,
                 eval_loader: DataLoader,
                 optimizer_and_scheduler: tuple[Optimizer, LRScheduler],
                 device: torch.device,
                 wandb,  # wandb run object, or None to disable logging
                 checkpoint_suffix: str = None,
                 save_checkpoints: bool = True,
                 save_results: bool = True,
                 # Training arguments (largely mirror Hugging Face TrainingArguments)
                 optimizer_args: dict = None,
                 lr_scheduler_args: dict = None,
                 greater_is_better: bool = False,
                 metric_for_best_model: str = 'eval/loss',
                 num_train_epochs: int = 2,
                 gradient_accumulation_steps: int = 1,
                 evaluation_strategy: str = 'steps',
                 load_best_model_at_end: bool = True,
                 logging_steps: int = 100,
                 max_steps: int = -1,
                 eval_steps: int = 100,
                 max_eval_batches: int = -1,
                 print_samples: bool = False,
                 initial_eval: bool = True,
                 num_save_ckpt_steps: int = 1000,
                 **kwargs: Any):
        super().__init__()
        self.model = model
        self.step = 0       # total optimizer steps (counts gradient accumulation)
        self.grad_step = 0  # number of actual gradient updates
        self.compute_loss_backprop = False  # True if compute_loss() handles backward() itself

        # Build the optimizer / scheduler here if they were not passed in
        if optimizer_and_scheduler is None:
            assert optimizer_args is not None and lr_scheduler_args is not None
            self.optimizer = get_optimizer(model=self.model, **optimizer_args)
            self.scheduler = get_scheduler(optimizer=self.optimizer, **lr_scheduler_args)
        else:
            self.optimizer, self.scheduler = optimizer_and_scheduler

        # Plateau-style schedulers step on the validation metric after each eval,
        # not after every gradient update
        try:
            self.scheduler_step_after_epoch = 'plateau' in args.lr_scheduler['lr_scheduler_type']
        except KeyError:
            self.scheduler_step_after_epoch = False

        self.train_loader = train_loader
        self.eval_loader = eval_loader

        self.device = device
        self.wandb = wandb

        self.metric_for_best_model = metric_for_best_model
        self.num_train_epochs = num_train_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.evaluation_strategy = evaluation_strategy
        self.greater_is_better = greater_is_better
        self.is_better = (lambda x, y: (x > y) if greater_is_better else (x < y))
        self.load_best_model_at_end = load_best_model_at_end
        self.logging_steps = logging_steps
        self.max_steps = max_steps
        self.eval_steps = eval_steps
        self.max_eval_batches = max_eval_batches
        self.print_samples = print_samples
        self.initial_eval = initial_eval
        self.num_save_ckpt_steps = num_save_ckpt_steps

        self.train_metrics = {'train/loss': None,
                              'train/epoch': None,
                              'train/step': None}
        self.eval_metrics = {metric_for_best_model: None}
        self.eval_metrics_by_step = {'eval_step': []}  # eval metrics over training, written to CSV
        self.criterion = nn.CrossEntropyLoss(reduction='mean')
        try:
            self.tokenizer = self.train_loader.dataset.tokenizer
        except AttributeError:
            self.tokenizer = None

        self.save_results = save_results
        self.results_path = None
        self.best_val_metric = 0 if greater_is_better else 1e10
        self.best_val_metric_epoch = 0
        self.best_val_metric_step = 0
        if save_checkpoints:
            self.init_checkpointing(args=args, checkpoint_suffix=checkpoint_suffix)

    def train(self) -> nn.Module:
        """
        Entire training run
        """
        model = self.model
        pbar = tqdm(range(self.num_train_epochs), leave=False, colour='white',
                    desc='Training')
        for ix, epoch in enumerate(pbar):
            model, early_stopping = self.train_step(model, epoch)
            if self.evaluation_strategy == 'epoch':
                _eval_metrics = self.eval_step(model, step=self.grad_step)
                print(f'Epoch {ix} metrics:', _eval_metrics)
            if early_stopping:
                break

        if self.load_best_model_at_end:
            # Checkpoints only contain the trainable weights, so load with strict=False
            try:
                state_dict = torch.load(self.best_val_checkpoint_path)['model_state_dict']
                model.load_state_dict(state_dict, strict=False)
                print(f'-> Loading best checkpoint from {self.best_val_checkpoint_path}')
            except FileNotFoundError as e:
                print(e)
                print('-> Returning most recent model instead')
        return model

    def train_step(self, model: nn.Module, epoch: int) -> tuple[nn.Module, bool]:
        """
        Training loop over one epoch
        """
        if self.gradient_accumulation_steps is None:
            accum_iter = 1
        else:
            accum_iter = self.gradient_accumulation_steps

        model.train()
        model.zero_grad()
        pbar = tqdm(self.train_loader, leave=False, colour='blue',
                    desc=f'-> Training (epoch {epoch} / {self.num_train_epochs})')
        total_loss = 0
        eval_for_step = False  # only evaluate once per gradient step

        if self.initial_eval:
            print('')
            print('-> Initial eval')
            self.compute_eval_metrics(model, step=self.grad_step)

        for ix, data in enumerate(pbar):
            loss, train_metrics = self.compute_loss(model, data,
                                                    sample_idx=ix)
            # Scale the loss so gradients accumulated over `accum_iter` batches
            # average out to one effective batch
            loss /= accum_iter
            if not self.compute_loss_backprop:
                # compute_loss() did not already call backward()
                try:
                    with torch.autograd.set_detect_anomaly(True):
                        loss.backward()
                except Exception as e:
                    breakpoint()  # drop into the debugger if backward() fails
            if (self.step + 1) % accum_iter == 0:
                self.optimizer.step()
                if not self.scheduler_step_after_epoch and self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                self.grad_step += 1
            if not self.compute_loss_backprop:
                loss = loss.detach().cpu().item()

            self.step += 1
            if not isinstance(loss, float):
                total_loss += loss.item()
            else:
                total_loss += loss
            desc = f"Training epoch {epoch} | loss: {total_loss / (ix + 1):.3f} | lr: {self.optimizer.param_groups[0]['lr']:.5f}"
            desc += f' | gradient step: {self.grad_step}'
            for k, v in train_metrics.items():
                desc += f' | {k}: {v:.3f}'
            pbar.set_description(desc)

            # Logging
            if self.grad_step % self.logging_steps == 0:
                self.train_metrics['train/loss'] = loss.item() if not isinstance(loss, float) else loss
                self.train_metrics['train/epoch'] = epoch
                self.train_metrics['train/step'] = self.grad_step
                self.train_metrics['train/lr'] = self.optimizer.param_groups[0]['lr']
                for k, v in train_metrics.items():
                    self.train_metrics[f'train/{k}'] = v

                if self.wandb is not None:
                    self.wandb.log(self.train_metrics, step=self.grad_step)

            if self.evaluation_strategy == 'steps':
                if (self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and not eval_for_step):
                    _eval_metrics = self.eval_step(model, step=self.grad_step)
                    print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics)
                    eval_for_step = True
                    model.train()
                elif self.grad_step == 0 and self.num_save_ckpt_steps < 1000 and not eval_for_step:
                    _eval_metrics = self.eval_step(model, step=self.grad_step)
                    print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics)
                    eval_for_step = True
                    model.train()
                elif self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and eval_for_step:
                    pass  # already evaluated at this gradient step
                else:
                    if self.grad_step > 0:
                        eval_for_step = False
            if self.grad_step == self.max_steps:
                early_stopping = True
                return model, early_stopping

        early_stopping = False
        return model, early_stopping

    def eval_step(self, model: nn.Module, step: int = None,
                  **kwargs: Any) -> dict[str, Any]:
        """
        Evaluation loop over one epoch
        """
        with torch.no_grad():
            self.eval_metrics = self.compute_eval_metrics(model, step=step, **kwargs)
            val_metric = self.eval_metrics[self.metric_for_best_model]

        if self.wandb is not None:
            self.wandb.log(self.eval_metrics, step=self.grad_step)

        if self.results_path is not None:
            self.eval_metrics_by_step['eval_step'].append(step)
            for k, v in self.eval_metrics.items():
                if k not in self.eval_metrics_by_step:
                    self.eval_metrics_by_step[k] = [v]
                else:
                    self.eval_metrics_by_step[k].append(v)
            pd.DataFrame(self.eval_metrics_by_step).to_csv(self.results_path)

        # Save a new best checkpoint if the validation metric improved
        if self.grad_step % self.eval_steps == 0:
            if self.is_better(val_metric, self.best_val_metric):
                self.best_val_metric = val_metric
                self.best_val_metric_step = self.grad_step
                torch.save({
                    'model_state_dict': self.save_trainable_weights(model),
                    'step': self.grad_step,
                    self.metric_for_best_model: val_metric
                }, self.best_val_checkpoint_path)
                print(f'\n-> Saved best model checkpoint to: {self.best_val_checkpoint_path}!')

        # Also save a periodic, step-numbered checkpoint
        if self.grad_step % self.num_save_ckpt_steps == 0:
            save_path = self.best_val_checkpoint_path.replace('.pt', f'_{self.grad_step}.pt')
            torch.save({
                'model_state_dict': self.save_trainable_weights(model),
                'step': self.grad_step,
                self.metric_for_best_model: val_metric
            }, save_path)
            print(f'\n-> Saved model checkpoint to: {save_path}!')

        if self.scheduler_step_after_epoch and self.scheduler is not None:
            self.scheduler.step(val_metric)
        return self.eval_metrics

    def compute_eval_metrics(self,
                             model: nn.Module, step: int,
                             max_batches: int = None,
                             dataloader: DataLoader = None,
                             **kwargs: Any) -> dict[str, Any]:
        """
        One evaluation loop over a validation dataset
        """
        max_batches = (self.max_eval_batches if max_batches is None else max_batches)
        dataloader = self.eval_loader if dataloader is None else dataloader
        pbar = tqdm(dataloader, leave=False, colour='green',
                    desc=f'Evaluating at step {step}')

        model.eval()
        step_loss = 0
        step_eval_metrics = {}
        with torch.no_grad():
            for ix, data in enumerate(pbar):
                loss, eval_metrics = self.compute_loss(model, data)
                if not self.compute_loss_backprop:
                    loss = loss.item()
                # Collect per-batch metrics; averaged over batches below
                if ix == 0:
                    step_eval_metrics[self.metric_for_best_model] = [loss]
                    for k, v in eval_metrics.items():
                        step_eval_metrics[f'eval/{k}'] = [v]
                else:
                    step_eval_metrics[self.metric_for_best_model].append(loss)
                    for k, v in eval_metrics.items():
                        step_eval_metrics[f'eval/{k}'].append(v)

                step_loss += loss
                desc = f"Evaluating at step {step} | loss: {step_loss / (ix + 1):.3f}"
                if self.optimizer is not None:
                    desc += f" | lr: {self.optimizer.param_groups[0]['lr']:.5f}"
                pbar.set_description(desc)
                if ix == max_batches:
                    break

        # Average metrics over evaluation batches
        for k, v in step_eval_metrics.items():
            step_eval_metrics[k] = sum(v) / len(v)
        print(f'Eval step {step}:', step_eval_metrics)
        del loss
        torch.cuda.empty_cache()
        return step_eval_metrics

    def compute_loss(self, model: nn.Module, data: dict[str, torch.Tensor],
                     sample_idx: int = None, **kwargs: Any
                     ) -> tuple[torch.Tensor, dict[str, Any]]:
        """
        Main method to determine how models are trained.
        -> Defaults to next-token prediction / classification,
           but override in child classes

        Args:
        - model: nn.Module, HF model to train
        - data: dict[str, torch.Tensor], HF datasets batch of data
        - sample_idx: int, index of batch in dataset
        """
        input_keys = {'input_ids', 'attention_mask'}
        inputs = {k: v.to(model.device)
                  for k, v in data.items() if k in input_keys}

        outputs = model(**inputs, output_attentions=False, use_cache=False)

        # Shift for next-token prediction: logits at position t predict the label at t + 1
        outputs = outputs.get('logits')[..., :-1, :].contiguous()
        targets = data.get('labels')[..., 1:].contiguous()

        if self.print_samples and sample_idx is not None and (sample_idx + 1) % 100 == 0:
            decode_samples(outputs, targets, self.tokenizer, sample_idx)

        # Flatten batch and sequence dimensions for the cross-entropy loss
        outputs = outputs.view(-1, outputs.shape[-1])
        targets = targets.view(-1).to(outputs.device)
        try:
            loss = self.criterion(outputs, targets)
        except Exception as e:
            print('outputs.shape', outputs.shape)
            print('targets.shape', targets.shape)
            raise e

        targets = targets.cpu()
        outputs = outputs.cpu()
        return loss, {'ppl': torch.exp(loss).item(), 'seq_len': targets.shape[-1] + 1}
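
    # Sketch of a child-class override (hypothetical example, not part of this module):
    # train_step / eval_step only need compute_loss() to return (loss, metrics_dict),
    # so a subclass can swap in any objective.
    #
    #   class OurClassificationTrainer(OurTrainer):
    #       def compute_loss(self, model, data, sample_idx=None, **kwargs):
    #           logits = model(input_ids=data['input_ids'].to(model.device),
    #                          attention_mask=data['attention_mask'].to(model.device)).get('logits')
    #           loss = self.criterion(logits, data['labels'].to(logits.device))
    #           acc = (logits.argmax(-1).cpu() == data['labels']).float().mean().item()
    #           return loss, {'acc': acc}
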
    def save_trainable_weights(self, model: nn.Module):
        """
        Save checkpoint with only weights actively being trained (e.g., for adapters).
        Make sure to later load with model.load_state_dict(state_dict, strict=False)
        """
        with torch.no_grad():
            state_dict = OrderedDict()
            for n, p in model.named_parameters():
                if p.requires_grad:
                    state_dict[n] = p.cpu()
            return state_dict
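
    # Reloading one of these checkpoints (a sketch, mirroring the logic in train();
    # strict=False because only the trainable weights were saved):
    #
    #   state_dict = torch.load(trainer.best_val_checkpoint_path)['model_state_dict']
    #   model.load_state_dict(state_dict, strict=False)
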
    def init_checkpointing(self,
                           args: ArgumentParser,
                           checkpoint_suffix: str) -> None:
        """
        Initialize checkpointing attributes

        Inputs:
        - args: Argparse or HuggingFace TrainingArguments object
        - checkpoint_suffix: str to append to checkpoint name
        """
        self.best_val_checkpoint_path = f'{join(args.checkpoint_dir, args.run_name)}.pt'
        if checkpoint_suffix is not None:
            self.best_val_checkpoint_path = self.best_val_checkpoint_path.replace(
                '.pt', f'{checkpoint_suffix}.pt')
        print(f'-> Saving best model checkpoint to {self.best_val_checkpoint_path}')
        if self.save_results:
            self.results_path = self.best_val_checkpoint_path.replace(
                '.pt', '.csv').replace(args.checkpoint_dir, args.results_dir)
            print(f'-> Saving results to {self.results_path}')

        # Best metric tracking
        self.best_val_metric = 0 if self.greater_is_better else 1e10
        self.best_val_metric_epoch = 0
        self.best_val_metric_step = 0
        self.best_train_metric = 0 if self.greater_is_better else 1e10
        self.best_train_metric_epoch = 0
        self.best_train_metric_step = 0
        if self.metric_for_best_model is not None:
            if 'eval' not in self.metric_for_best_model:
                self.metric_for_best_model = f'eval/{self.metric_for_best_model}'
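
    # With hypothetical values args.checkpoint_dir='./checkpoints', args.run_name='my_run',
    # and checkpoint_suffix='-seed0', the best checkpoint is written to
    # './checkpoints/my_run-seed0.pt', periodic copies to './checkpoints/my_run-seed0_{step}.pt'
    # (see eval_step), and eval metrics to the matching .csv under args.results_dir.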