Spaces:
Running
on
T4
Running
on
T4
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B2. Training (Lightning).ipynb. | |
# %% auto 0 | |
# Nothing from this module is exported by `from <module> import *`.
__all__: list = []
# %% ../nbs/B2. Training (Lightning).ipynb 2 | |
import io | |
import time | |
import random | |
from pathlib import Path | |
from fastprogress import progress_bar, master_bar | |
import fastprogress | |
import wandb | |
import numpy as np | |
import pylab as plt | |
import torch | |
import torch.nn as nn | |
from torch.utils.data.dataloader import DataLoader | |
from torch.profiler import record_function | |
# %% ../nbs/B2. Training (Lightning).ipynb 3 | |
import lightning.pytorch as pl | |
import math | |
class TrainingTask(pl.LightningModule):
    """LightningModule that trains a wrapped model with AdamW and a OneCycleLR schedule.

    The wrapped model's `forward(*batch)` is called with the unpacked batch and
    must return a `(logits, loss)` pair; only the loss is logged and returned.

    Args:
        model: the network to train; stored as `self.model`.
        model_hparams: dict read by `configure_optimizers` (keys 'lr0',
            'weight_decay', 'epochs', 'warmup_steps'). `configure_optimizers`
            also writes the derived 'pct_start' back into it.
    """

    def __init__(self, model, model_hparams=None):
        super().__init__()
        self.model = model
        self.model_hparams = model_hparams

    def on_fit_start(self):
        """Call the model's optional `setup(device)` hook when fitting starts."""
        # The None default turns a model without a `setup` attribute into a
        # no-op instead of an AttributeError.
        if getattr(self.model, 'setup', None):
            self.model.setup(self.device)

    def configure_optimizers(self):
        """Initialize the AdamW optimizer and its per-step OneCycleLR scheduler.

        Submodules that define a `no_weight_decay` attribute get weight decay 0;
        submodules that define `lr_scale` get learning rate `lr0 * lr_scale`.
        Modules with identical (weight_decay, lr) settings share one parameter
        group; every remaining parameter goes into a final "other" group.

        Returns:
            ([optimizer], [{'scheduler': OneCycleLR, 'interval': 'step'}])
        """
        lr = self.model_hparams['lr0']
        weight_decay = self.model_hparams['weight_decay']

        # NOTE(review): m.parameters() recurses into child modules, so a flagged
        # module nested inside another flagged module contributes its parameters
        # twice; PyTorch likely warns (same group) or raises (different groups)
        # on duplicated parameters — confirm flags are only set on non-nested modules.
        customized_params = set()
        groups = []
        group_map = {}
        for name, m in self.model.named_modules():
            if not (hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale')):
                continue
            customized_params |= set(m.parameters())
            m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
            m_lr = lr * getattr(m, 'lr_scale', 1)
            group = group_map.get((m_wd, m_lr))
            if group is None:
                group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                groups.append(group)
                group_map[(m_wd, m_lr)] = group
            group['params'] += m.parameters()
            group['names'].append(name)

        # Remaining parameters, kept in model.parameters() order. Iterating a set of
        # tensors yields an id-dependent order that differs between runs, and the
        # optimizer state_dict matches parameters by position, so a set here would
        # misalign AdamW state when resuming from a checkpoint.
        other_params = [p for p in self.model.parameters() if p not in customized_params]
        param_groups = groups + [
            {"names": ["other"], "params": other_params, "weight_decay": weight_decay},
        ]
        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)

        # modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
        def num_steps_per_epoch() -> int:
            """Return the number of optimizer steps in one epoch."""
            # Accessing _data_source is flaky and might break
            loader = self.trainer.fit_loop._data_source.dataloader()
            # math.ceil so always overestimate (underestimating throws exceptions)
            return math.ceil(len(loader) / self.trainer.accumulate_grad_batches)

        # Query the dataloader once and reuse the result below.
        steps_per_epoch = num_steps_per_epoch()
        total_steps = self.model_hparams['epochs'] * steps_per_epoch
        # Rising phase lasts `warmup_steps` updates, capped at 30% of the whole cycle.
        self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)
        print(f"{self.model_hparams['epochs']=} epochs x num_steps_per_epoch()={steps_per_epoch!r} steps")
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            pct_start=self.model_hparams['pct_start'],
            # Groups without an explicit "lr" (the "other" group) use the base lr.
            max_lr=[pg.get('lr', lr) for pg in param_groups],
            steps_per_epoch=steps_per_epoch,
            epochs=int(self.model_hparams['epochs']),
            final_div_factor=25,
        )
        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

    def training_step(self, train_batch, batch_idx):
        """Run the model on one training batch; log and return its loss."""
        _train_logits, train_loss = self.model.forward(*train_batch)
        self.log("train_loss", train_loss, sync_dist=True)
        return train_loss

    def validation_step(self, val_batch, batch_idx):
        """Run the model on one validation batch; log and return its loss."""
        _val_logits, val_loss = self.model.forward(*val_batch)
        self.log("val_loss", val_loss, sync_dist=True)
        return val_loss

    def on_validation_epoch_end(self):
        """Log the model's optional `get_metrics()` dict under a 'metrics/' prefix."""
        if hasattr(self.model, 'get_metrics'):
            self.log_dict({'metrics/' + k: v for k, v in self.model.get_metrics().items()}, sync_dist=True)

    def test_step(self, val_batch, batch_idx):
        """Run the model on one test batch; log and return its loss."""
        _test_logits, test_loss = self.model.forward(*val_batch)
        self.log("test_loss", test_loss, sync_dist=True)
        return test_loss
# %% ../nbs/B2. Training (Lightning).ipynb 4 | |
from fastcore.script import anno_parser | |
import shlex | |
def parse_and_call(name, fun, args, kwargs=None, log_to_wandb=True):
    """Parse command-line style `args` against `fun`'s signature, then call `fun`.

    Watch out: arbitrary Python values can only be passed through `kwargs`
    (as keyword arguments, not positional); everything in `args` has to be a
    string because it goes through the `anno_parser` argument parser.

    Args:
        name: key under which the parsed arguments are stored in the W&B config.
        fun: the callable whose annotations define the accepted arguments.
        args: list of string tokens, as they would appear on a command line.
        kwargs: optional extra keyword arguments passed to `fun` verbatim; they
            override parsed values with the same name.
        log_to_wandb: when true and the module-level `wandb_logger` exposes a
            real W&B config, record the final arguments (minus 'dataset' and
            'tunables') under `name`.

    Returns:
        Whatever `fun(**arguments)` returns.
    """
    arguments = vars(anno_parser(fun).parse_args(args))
    # anno_parser adds fastcore's own --xtra/--pdb options; drop them before
    # forwarding the rest to `fun`.
    arguments.pop('xtra', None)
    arguments.pop('pdb', None)
    if kwargs:
        arguments.update(kwargs)
    # `wandb_logger` is a module-level global; it must exist before this is
    # called with log_to_wandb=True.
    if log_to_wandb and isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config):
        wandb_logger.experiment.config[name] = {
            k: v for k, v in arguments.items() if k not in ['dataset', 'tunables']
        }
    return fun(**arguments)
# %% ../nbs/B2. Training (Lightning).ipynb 8 | |
import argparse | |
# Command-line interface. `--task`, `--input-dir` and `--tunables` are single
# strings that are further split with shlex below.
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, required=True, help='Task to train')
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
parser.add_argument('--input-dir', type=str, default='', help='input data path') # fixed in the model for now
parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how training steps to run between validations')
parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
parser.add_argument('--warmup-steps', type=int, default=10000, help='total number steps during which the learning rate rises (defaults to 10k updates)')
parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')
args = vars(parser.parse_args())

# The first token of --task is the task name; the remaining tokens are the
# task's own arguments. `required=True` above prevents shlex.split(None),
# and an empty string is rejected here instead of raising IndexError.
task_tokens = shlex.split(args.pop("task"))
if not task_tokens:
    parser.error("--task must name a task module")
task_name: str = task_tokens[0]
task_args: list = task_tokens[1:]
input_args: list = shlex.split(args.pop("input_dir"))
checkpoint_dir: str = args.pop("checkpoint_dir")
num_workers: int = args.pop("workers")
batch_size: int = args.pop("batch_size")
epochs: int = args.pop("epochs")
tunables_args: list = shlex.split(args.pop("tunables"))

# Training hyperparameters collected from the command line.
hyp_params = {
    'batch_size': batch_size,
    'warmup_steps': args['warmup_steps'],
    'weight_decay': args['weight_decay'],
    'clip_gradient_norm': args['clip_gradient_norm'],
    'accumulate_grad_batches': args['accumulate_grad_batches'],
    'precision': args['precision'],
    'lr0': args['lr0'],
    'epochs': epochs,
    'strategy': args['strategy'],
}
# %% ../nbs/B2. Training (Lightning).ipynb 9 | |
from lightning.pytorch.loggers import WandbLogger | |
from lightning.pytorch.callbacks import LearningRateMonitor | |
import datetime | |
import webdataset as wds | |
import importlib | |
torch.set_float32_matmul_precision('medium')

# W&B project: "WhisperSpeech-<task>" (task part overridable with
# --wandb-task-name), optionally followed by "-<--wandb-suffix>".
project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
if args['wandb_suffix']:
    project += "-" + args['wandb_suffix']
wandb_logger = WandbLogger(project=project)

# NOTE(review): `checkpoint_dir` (--checkpoint-dir) is not used here; checkpoints
# are written to f'{task_name}-{epochs}e' relative to the working directory.
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath=f'{task_name}-{epochs}e',
    filename=task_name+"-{epoch}-{step}-{val_loss:.2f}",
    monitor="val_loss",
    save_top_k=4,
    train_time_interval=datetime.timedelta(minutes=5),
)
lr_monitor_callback = LearningRateMonitor(logging_interval='step')

from torch.utils.data import DataLoader

# The task module supplies load_datasets / make_model and, optionally, Tunables.
task = importlib.import_module("whisperspeech." + task_name)
train_ds, val_ds = parse_and_call('dataset', task.load_datasets, input_args)

tunables = None
if hasattr(task, "Tunables"):
    import dataclasses
    tunables = parse_and_call('tunables', task.Tunables, tunables_args, log_to_wandb=False)
    if isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config):
        wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)
    # Non-None tunable values override the matching command-line hyperparameters.
    for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
        val = getattr(tunables, name, None)
        if val is not None:
            hyp_params[name] = val

# Iterable datasets are not batched or shuffled by the loader itself.
if isinstance(train_ds, torch.utils.data.IterableDataset):
    dl_batch_size, dl_shuffle = None, False
    pin_memory = False
else:
    dl_batch_size, dl_shuffle = batch_size, True
    pin_memory = True

# Both loaders are flattened, shuffled with a 1024-element buffer, re-batched to
# batch_size, and given an explicit length of total_samples // batch_size.
val_loader = wds.WebLoader(val_ds,
                           batch_size=dl_batch_size,
                           num_workers=num_workers,
                           drop_last=False,
                           pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)
train_loader = wds.WebLoader(train_ds,
                             batch_size=dl_batch_size,
                             num_workers=num_workers,
                             drop_last=False,
                             shuffle=dl_shuffle,
                             pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)

model_kwargs = dict(dataset=train_ds)
if tunables is not None:
    model_kwargs['tunables'] = tunables
model = parse_and_call('model', task.make_model, task_args, model_kwargs)

task = TrainingTask(model, model_hparams=hyp_params)

trainer = pl.Trainer(strategy=hyp_params['strategy'],
                     max_epochs=hyp_params['epochs'],
                     accelerator="gpu",
                     profiler="simple",
                     precision=hyp_params['precision'],
                     gradient_clip_val=hyp_params['clip_gradient_norm'],
                     accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
                     val_check_interval=args.pop("validate_every_n_steps"),
                     enable_checkpointing=True,
                     logger=wandb_logger,
                     callbacks=[ckpt_callback, lr_monitor_callback])

if isinstance(wandb_logger.experiment.config, wandb.sdk.wandb_config.Config):
    wandb_logger.experiment.config.update(hyp_params)

# `resume_from` is always present in args (default None), so test its value
# rather than key membership.
kwargs = {}
if args['resume_from'] is not None:
    kwargs['ckpt_path'] = args['resume_from']
trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **kwargs)