""" |
Default trainer class for training models |
""" |
from collections import OrderedDict |
from os.path import join |
from argparse import ArgumentParser |
from tqdm import tqdm |
import pandas as pd |
import torch |
import torch.nn as nn |
from torch.utils.data import DataLoader |
from torch.optim import Optimizer |
from torch.optim.lr_scheduler import LRScheduler |
from .optim import get_optimizer, get_scheduler |
from .utils import decode_samples |
class OurTrainer():
    """
    Basic parent trainer class. Defaults to language modeling.
    -> Replacement for Hugging Face Trainer

    Two counters are tracked:
    - self.step: number of batches (micro-steps) processed so far
    - self.grad_step: number of optimizer updates performed so far
    Logging, evaluation, checkpointing and `max_steps` are all keyed on
    `self.grad_step`.
    """

    def __init__(self,
                 model: nn.Module,
                 args: ArgumentParser,
                 train_loader: DataLoader,
                 eval_loader: DataLoader,
                 optimizer_and_scheduler: tuple[Optimizer, LRScheduler],
                 device: torch.device,
                 wandb,
                 checkpoint_suffix: str = None,
                 save_checkpoints: bool = True,
                 save_results: bool = True,
                 optimizer_args: dict = None,
                 lr_scheduler_args: dict = None,
                 greater_is_better: bool = False,
                 metric_for_best_model: str = 'eval/loss',
                 num_train_epochs: int = 2,
                 gradient_accumulation_steps: int = 1,
                 evaluation_strategy: str = 'steps',
                 load_best_model_at_end: bool = True,
                 logging_steps: int = 100,
                 max_steps: int = -1,
                 eval_steps: int = 100,
                 max_eval_batches: int = -1,
                 print_samples: bool = False,
                 initial_eval: bool = True,
                 num_save_ckpt_steps: int = 1000,
                 **kwargs: any):
        """
        Set up the trainer state.

        Args:
        - model: model to train
        - args: config object; `args.lr_scheduler['lr_scheduler_type']` is
          read here, and `checkpoint_dir`, `run_name`, `results_dir` are read
          by `init_checkpointing` when checkpointing is enabled
        - train_loader / eval_loader: dataloaders for training / evaluation
        - optimizer_and_scheduler: (optimizer, scheduler) pair; if None, both
          are built from `optimizer_args` and `lr_scheduler_args`
        - device: stored on the trainer
        - wandb: object with a `.log(metrics, step=...)` method, or None
        - checkpoint_suffix: optional string appended to the checkpoint name
        - save_checkpoints: if True, set up checkpoint (and results) paths
        - save_results: if True, eval metrics are also written to a CSV
        - greater_is_better: direction of `metric_for_best_model`
        - metric_for_best_model: key of the eval metric used for model selection
        - evaluation_strategy: 'steps' or 'epoch'
        - max_steps: stop once this many optimizer updates have been done
        - max_eval_batches: cap on eval batches per evaluation (<= 0: no cap)
        - initial_eval: run one evaluation before the first training step
        - num_save_ckpt_steps: interval (in optimizer updates) for periodic
          checkpoints written from `eval_step`
        """
        super().__init__()
        self.model = model
        self.step = 0       # Total batches (micro-steps) seen
        self.grad_step = 0  # Total optimizer updates done
        self.compute_loss_backprop = False  # Whether `compute_loss` already calls backward itself

        if optimizer_and_scheduler is None:
            assert optimizer_args is not None and lr_scheduler_args is not None, (
                'optimizer_args and lr_scheduler_args are required when '
                'optimizer_and_scheduler is None')
            self.optimizer = get_optimizer(model=self.model, **optimizer_args)
            self.scheduler = get_scheduler(optimizer=self.optimizer, **lr_scheduler_args)
        else:
            self.optimizer, self.scheduler = optimizer_and_scheduler

        # Plateau-style schedulers are stepped with the validation metric in
        # `eval_step` instead of after every optimizer update. A missing or
        # malformed scheduler config falls back to per-update stepping.
        try:
            self.scheduler_step_after_epoch = 'plateau' in args.lr_scheduler['lr_scheduler_type']
        except (AttributeError, KeyError, TypeError):
            self.scheduler_step_after_epoch = False

        # Dataloaders
        self.train_loader = train_loader
        self.eval_loader = eval_loader

        self.device = device
        self.wandb = wandb

        # Training arguments
        self.metric_for_best_model = metric_for_best_model
        self.num_train_epochs = num_train_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.evaluation_strategy = evaluation_strategy
        self.greater_is_better = greater_is_better
        if greater_is_better:
            self.is_better = lambda x, y: x > y
        else:
            self.is_better = lambda x, y: x < y
        self.load_best_model_at_end = load_best_model_at_end
        self.logging_steps = logging_steps
        self.max_steps = max_steps
        self.eval_steps = eval_steps
        self.max_eval_batches = max_eval_batches
        self.print_samples = print_samples
        self.initial_eval = initial_eval
        self.num_save_ckpt_steps = num_save_ckpt_steps

        # Metric bookkeeping
        self.train_metrics = {'train/loss': None,
                              'train/epoch': None,
                              'train/step': None}
        self.eval_metrics = {metric_for_best_model: None}
        self.eval_metrics_by_step = {'eval_step': []}
        self.criterion = nn.CrossEntropyLoss(reduction='mean')
        try:
            self.tokenizer = self.train_loader.dataset.tokenizer
        except AttributeError:
            self.tokenizer = None

        # Checkpointing / results. Both paths stay None unless
        # `init_checkpointing` sets them, so later code can test for them.
        self.save_results = save_results
        self.results_path = None
        self.best_val_checkpoint_path = None
        self.best_val_metric = 0 if greater_is_better else 1e10
        self.best_val_metric_epoch = 0
        self.best_val_metric_step = 0
        if save_checkpoints:
            self.init_checkpointing(args=args, checkpoint_suffix=checkpoint_suffix)

    def train(self) -> nn.Module:
        """
        Entire training run

        Runs `train_step` once per epoch (stopping early if it reports that
        `max_steps` was reached), evaluates after each epoch when
        `evaluation_strategy == 'epoch'`, and optionally reloads the best
        checkpoint before returning the model.
        """
        model = self.model
        pbar = tqdm(range(self.num_train_epochs), leave=False, colour='white',
                    desc='Training')
        for epoch in pbar:
            model, early_stopping = self.train_step(model, epoch)
            if self.evaluation_strategy == 'epoch':
                _eval_metrics = self.eval_step(model, step=self.grad_step)
                print(f'Epoch {epoch} metrics:', _eval_metrics)
            if early_stopping:
                break

        if self.load_best_model_at_end:
            if self.best_val_checkpoint_path is None:
                # Checkpointing was never initialised, so nothing was saved
                print('-> No checkpoint path configured')
                print('-> Returning most recent model instead')
            else:
                try:
                    state_dict = torch.load(self.best_val_checkpoint_path)['model_state_dict']
                    # Checkpoints only hold trainable weights -> strict=False
                    model.load_state_dict(state_dict, strict=False)
                    print(f'-> Loading best checkpoint from {self.best_val_checkpoint_path}')
                except FileNotFoundError as e:
                    print(e)
                    print('-> Returning most recent model instead')
        return model

    def train_step(self, model: nn.Module, epoch: int) -> tuple[nn.Module, bool]:
        """
        Training loop over one epoch

        Returns:
        - (model, early_stopping): early_stopping is True when
          `self.grad_step` reached `self.max_steps` during this epoch
        """
        if self.gradient_accumulation_steps is None:
            accum_iter = 1
        else:
            accum_iter = self.gradient_accumulation_steps

        model.train()
        model.zero_grad()
        pbar = tqdm(self.train_loader, leave=False, colour='blue',
                    desc=f'-> Training (epoch {epoch} / {self.num_train_epochs})')
        total_loss = 0
        eval_for_step = False  # True once the current grad_step has been evaluated

        # Sanity-check evaluation before any training has happened. Only run
        # it once (not at the start of every epoch).
        if self.initial_eval and self.step == 0:
            print('')
            print('-> Initial eval')
            self.compute_eval_metrics(model, step=self.grad_step)
            model.train()  # compute_eval_metrics switches the model to eval mode

        for ix, data in enumerate(pbar):
            loss, train_metrics = self.compute_loss(model, data,
                                                    sample_idx=ix)
            loss /= accum_iter
            if not self.compute_loss_backprop:
                # Anomaly detection is kept on so a failing backward pass
                # reports the forward op responsible; the error propagates
                # to the caller instead of being swallowed.
                with torch.autograd.set_detect_anomaly(True):
                    loss.backward()

            if (self.step + 1) % accum_iter == 0:
                self.optimizer.step()
                if not self.scheduler_step_after_epoch and self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                self.grad_step += 1

            self.step += 1
            # From here on only the scalar value of the loss is needed
            loss = loss if isinstance(loss, float) else loss.item()
            total_loss += loss

            desc = f"Training epoch {epoch} | loss: {total_loss / (ix + 1):.3f} | lr: {self.optimizer.param_groups[0]['lr']:.5f}"
            desc += f' | gradient step: {self.grad_step}'
            for k, v in train_metrics.items():
                desc += f' | {k}: {v:.3f}'
            pbar.set_description(desc)

            # Log every `logging_steps` optimizer updates
            if self.grad_step % self.logging_steps == 0:
                self.train_metrics['train/loss'] = loss
                self.train_metrics['train/epoch'] = epoch
                self.train_metrics['train/step'] = self.grad_step
                self.train_metrics['train/lr'] = self.optimizer.param_groups[0]['lr']
                for k, v in train_metrics.items():
                    self.train_metrics[f'train/{k}'] = v
                if self.wandb is not None:
                    self.wandb.log(self.train_metrics, step=self.grad_step)

            if self.evaluation_strategy == 'steps':
                # With gradient accumulation several batches share one
                # grad_step; `eval_for_step` makes sure each grad_step is
                # evaluated at most once.
                if (self.grad_step % self.eval_steps == 0 and self.grad_step > 0
                        and not eval_for_step):
                    _eval_metrics = self.eval_step(model, step=self.grad_step)
                    print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics)
                    eval_for_step = True
                    model.train()  # Need to set back to train mode
                elif (self.grad_step == 0 and self.num_save_ckpt_steps < 1000
                        and not eval_for_step):
                    _eval_metrics = self.eval_step(model, step=self.grad_step)
                    print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics)
                    eval_for_step = True
                    model.train()  # Need to set back to train mode
                elif (self.grad_step % self.eval_steps == 0 and self.grad_step > 0
                        and eval_for_step):
                    pass  # Already evaluated this grad_step
                elif self.grad_step > 0:
                    eval_for_step = False

            if self.grad_step == self.max_steps:
                early_stopping = True
                return model, early_stopping

        early_stopping = False
        return model, early_stopping

    def eval_step(self, model: nn.Module, step: int = None,
                  **kwargs: any) -> dict[str, float]:
        """
        Evaluation loop over one epoch

        Computes eval metrics, logs them (wandb / results CSV), saves the
        best and periodic checkpoints when a checkpoint path is configured,
        and steps a plateau-style scheduler with the validation metric.
        """
        with torch.no_grad():
            self.eval_metrics = self.compute_eval_metrics(model, step=step, **kwargs)
            val_metric = self.eval_metrics[self.metric_for_best_model]

            if self.wandb is not None:
                self.wandb.log(self.eval_metrics, step=self.grad_step)

            if self.results_path is not None:
                self.eval_metrics_by_step['eval_step'].append(step)
                for k, v in self.eval_metrics.items():
                    self.eval_metrics_by_step.setdefault(k, []).append(v)
                # Rewrite the full history so the CSV is always up to date
                pd.DataFrame(self.eval_metrics_by_step).to_csv(self.results_path)

            ckpt_path = self.best_val_checkpoint_path  # None -> checkpointing disabled

            if self.grad_step % self.eval_steps == 0:
                if self.is_better(val_metric, self.best_val_metric):
                    self.best_val_metric = val_metric
                    self.best_val_metric_step = self.grad_step
                    if ckpt_path is not None:
                        torch.save({
                            'model_state_dict': self.save_trainable_weights(model),
                            'step': self.grad_step,
                            self.metric_for_best_model: val_metric
                        }, ckpt_path)
                        print(f'\n-> Saved best model checkpoint to: {ckpt_path}!')

            # Periodic checkpoint, independent of whether the metric improved
            if ckpt_path is not None and self.grad_step % self.num_save_ckpt_steps == 0:
                # Only swap the trailing extension; never touch '.pt' that
                # happens to appear elsewhere in the path.
                save_path = f"{ckpt_path.removesuffix('.pt')}_{self.grad_step}.pt"
                torch.save({
                    'model_state_dict': self.save_trainable_weights(model),
                    'step': self.grad_step,
                    self.metric_for_best_model: val_metric
                }, save_path)
                print(f'\n-> Saved step {self.grad_step} model checkpoint to: {save_path}!')

            if self.scheduler_step_after_epoch and self.scheduler is not None:
                self.scheduler.step(val_metric)
            return self.eval_metrics

    def compute_eval_metrics(self,
                             model: nn.Module, step: int,
                             max_batches: int = None,
                             dataloader: DataLoader = None,
                             **kwargs: any) -> dict[str, float]:
        """
        One evaluation loop over a validation dataset

        Args:
        - model: model to evaluate (left in eval mode on return)
        - step: step number, used only for progress / print messages
        - max_batches: evaluate at most this many batches; defaults to
          `self.max_eval_batches`. None or a non-positive value means no cap
        - dataloader: defaults to `self.eval_loader`

        Returns:
        - dict mapping each metric name to its mean over evaluated batches
          (empty if the dataloader yields no batches)
        """
        max_batches = (self.max_eval_batches if max_batches is None else max_batches)
        batch_limit = max_batches if max_batches is not None and max_batches > 0 else None
        dataloader = self.eval_loader if dataloader is None else dataloader
        pbar = tqdm(dataloader, leave=False, colour='green',
                    desc=f'Evaluating at step {step}')

        model.eval()
        step_loss = 0
        step_eval_metrics = {}
        loss = None
        with torch.no_grad():
            for ix, data in enumerate(pbar):
                # Stop once exactly `batch_limit` batches have been evaluated
                if batch_limit is not None and ix >= batch_limit:
                    break
                loss, eval_metrics = self.compute_loss(model, data)
                if not self.compute_loss_backprop:
                    loss = loss.item()

                step_eval_metrics.setdefault(self.metric_for_best_model, []).append(loss)
                for k, v in eval_metrics.items():
                    step_eval_metrics.setdefault(f'eval/{k}', []).append(v)

                step_loss += loss
                desc = f"Evaluating at step {step} | loss: {step_loss / (ix + 1):.3f}"
                if self.optimizer is not None:
                    desc += f" | lr: {self.optimizer.param_groups[0]['lr']:.5f}"
                pbar.set_description(desc)

        # Average every metric over the evaluated batches
        for k, v in step_eval_metrics.items():
            step_eval_metrics[k] = sum(v) / len(v)
        print(f'Eval step {step}:', step_eval_metrics)
        del loss  # Drop any tensor reference before releasing cached memory
        torch.cuda.empty_cache()
        return step_eval_metrics

    def compute_loss(self, model: nn.Module, data: torch.Tensor,
                     sample_idx: int = None, **kwargs: any,
                     ) -> tuple[torch.Tensor, dict[str, float]]:
        """
        Main method to determine how models are trained.
        -> Defaults to next-token prediction / classification,
           but override in child classes

        Args:
        - model: nn.Module, HF model to train
        - data: dict[torch.Tensor], HF datasets batch of data
        - sample_idx: int, index of batch in dataset

        Returns:
        - loss: cross-entropy between logits at position t and labels at t + 1
        - metrics: dict with 'ppl' (exp of the loss) and 'seq_len' (length of
          the last dimension of the labels)
        """
        input_keys = {'input_ids', 'attention_mask'}
        inputs = {k: v.to(model.device)
                  for k, v in data.items() if k in input_keys}

        outputs = model(**inputs, output_attentions=False, use_cache=False)
        # Shift so that the logits at position t are scored against token t + 1
        outputs = outputs.get('logits')[..., :-1, :].contiguous()
        targets = data.get('labels')[..., 1:].contiguous()
        # Read the sequence length before flattening merges batch and sequence
        seq_len = targets.shape[-1] + 1

        if self.print_samples and sample_idx is not None and (sample_idx + 1) % 100 == 0:
            decode_samples(outputs, targets, self.tokenizer, sample_idx)

        # Flatten to (num_tokens, vocab) and (num_tokens,) for the criterion
        outputs = outputs.view(-1, outputs.shape[-1])
        targets = targets.view(-1).to(outputs.device)
        try:
            loss = self.criterion(outputs, targets)
        except Exception:
            print('outputs.shape', outputs.shape)
            print('targets.shape', targets.shape)
            raise

        return loss, {'ppl': torch.exp(loss).item(), 'seq_len': seq_len}

    def save_trainable_weights(self, model: nn.Module):
        """
        Save checkpoint with only weights actively being trained (e.g., for adapters).
        Make sure to later load with model.load_state_dict(state_dict, strict=False)

        Returns:
        - OrderedDict mapping parameter name -> detached CPU tensor, for
          every parameter with requires_grad=True
        """
        with torch.no_grad():
            state_dict = OrderedDict()
            for n, p in model.named_parameters():
                if p.requires_grad:
                    # detach() so the saved entry is a plain tensor rather
                    # than the live Parameter when the model is on CPU
                    state_dict[n] = p.detach().cpu()
            return state_dict

    def init_checkpointing(self,
                           args: ArgumentParser,
                           checkpoint_suffix: str) -> None:
        """
        Initialize checkpointing attributes

        Inputs:
        - args: Argparse or HuggingFace TrainingArguments object; must provide
          `checkpoint_dir` and `run_name` (and `results_dir` when
          `self.save_results` is set)
        - checkpoint_suffix: str to append to checkpoint name
        """
        suffix = '' if checkpoint_suffix is None else checkpoint_suffix
        # Build both paths from their parts instead of str.replace, which
        # would also rewrite '.pt' or the directory name anywhere in the path.
        file_stem = f'{args.run_name}{suffix}'
        self.best_val_checkpoint_path = join(args.checkpoint_dir, f'{file_stem}.pt')
        print(f'-> Saving best model checkpoint to {self.best_val_checkpoint_path}')

        if self.save_results:
            self.results_path = join(args.results_dir, f'{file_stem}.csv')
            print(f'-> Saving results to {self.results_path}')

        # Best metric so far
        self.best_val_metric = 0 if self.greater_is_better else 1e10
        self.best_val_metric_epoch = 0
        self.best_val_metric_step = 0
        self.best_train_metric = 0 if self.greater_is_better else 1e10
        self.best_train_metric_epoch = 0
        self.best_train_metric_step = 0

        # Model selection always uses an eval metric, so add the prefix if missing
        if self.metric_for_best_model is not None:
            if 'eval' not in self.metric_for_best_model:
                self.metric_for_best_model = f'eval/{self.metric_for_best_model}'