# whisperspeech/train_multi.py
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B2. Training (Lightning).ipynb.
# %% auto 0
__all__ = []
# %% ../nbs/B2. Training (Lightning).ipynb 2
import io
import time
import random
from pathlib import Path
from fastprogress import progress_bar, master_bar
import fastprogress
import wandb
import numpy as np
import pylab as plt
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.profiler import record_function
# %% ../nbs/B2. Training (Lightning).ipynb 3
import lightning.pytorch as pl
import math
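# TrainingTask wraps a task-specific model (anything whose forward(*batch) returns a
# (logits, loss) pair) in a LightningModule: it builds an AdamW optimizer with
# per-module weight-decay/learning-rate groups, attaches a OneCycleLR schedule, and
# logs train/val/test losses plus any extra metrics the model exposes via get_metrics().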
class TrainingTask(pl.LightningModule):
    def __init__(self, model, model_hparams=None):
        super().__init__()
        self.model = model
        self.model_hparams = model_hparams

    def on_fit_start(self):
        # let the wrapped model do device-dependent setup if it defines a `setup` hook
        if getattr(self.model, 'setup', None):
            self.model.setup(self.device)
    def configure_optimizers(self):
        """ Initialize AdamW optimizer"""
        lr = self.model_hparams['lr0']
        weight_decay = self.model_hparams['weight_decay']

        # split the parameters into groups: modules that opt out of weight decay
        # or ask for a scaled learning rate get their own parameter group
        all_params = set(self.model.parameters())
        customized_params = set()
        groups = []
        group_map = {}
        for name, m in self.model.named_modules():
            if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
                customized_params |= set(m.parameters())
                m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
                m_lr = lr * getattr(m, 'lr_scale', 1)
                group = group_map.get((m_wd, m_lr), None)
                if group is None:
                    group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                    groups.append(group)
                    group_map[(m_wd, m_lr)] = group
                group['params'] += m.parameters()
                group['names'].append(name)

        other_params = all_params - customized_params

        param_groups = groups + [
            {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay},
        ]
        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)

        # modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
        def num_steps_per_epoch() -> int:
            """Get the number of optimizer steps per epoch"""
            # Accessing _data_source is flaky and might break
            dataset = self.trainer.fit_loop._data_source.dataloader()
            dataset_size = len(dataset)
            # math.ceil so we always overestimate (underestimating throws exceptions)
            num_steps = math.ceil(dataset_size / self.trainer.accumulate_grad_batches)
            return num_steps

        total_steps = self.model_hparams['epochs'] * num_steps_per_epoch()
        self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)

        print(f"{self.model_hparams['epochs']=} epochs x {num_steps_per_epoch()=} steps")

        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            pct_start=self.model_hparams['pct_start'],
            max_lr=[pg.get('lr', lr) for pg in param_groups],
            steps_per_epoch=num_steps_per_epoch(),
            epochs=int(self.model_hparams['epochs']),
            final_div_factor=25,
        )

        # step the scheduler after every optimizer step, not once per epoch
        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]
    def training_step(self, train_batch, batch_idx):
        train_logits, train_loss = self.model.forward(*train_batch)
        self.log("train_loss", train_loss, sync_dist=True)
        return train_loss

    def validation_step(self, val_batch, batch_idx):
        val_logits, val_loss = self.model.forward(*val_batch)
        self.log("val_loss", val_loss, sync_dist=True)
        return val_loss

    def on_validation_epoch_end(self):
        if hasattr(self.model, 'get_metrics'):
            self.log_dict({'metrics/' + k: v for k, v in self.model.get_metrics().items()}, sync_dist=True)

    def test_step(self, val_batch, batch_idx):
        test_logits, test_loss = self.model.forward(*val_batch)
        self.log("test_loss", test_loss, sync_dist=True)
        return test_loss
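# The parameter grouping in configure_optimizers above honours two optional attributes
# on submodules. A minimal sketch of how a module would opt in (the module below is
# hypothetical -- only the attribute names are what the code actually looks for):
#
#     class PositionalEmbedding(nn.Module):
#         no_weight_decay = True   # exclude this module's parameters from weight decay
#         lr_scale = 0.1           # train them at 10% of the base learning rate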
# %% ../nbs/B2. Training (Lightning).ipynb 4
from fastcore.script import anno_parser
import shlex
# watch out: Python objects can only be passed in through the `kwargs` dict (as keyword
# arguments, not positionally) -- everything that comes in through `args` is a command line string
def parse_and_call(name, fun, args, kwargs={}, log_to_wandb=True):
    # parse the command line strings in `args` against `fun`'s annotated signature,
    # merge in the Python-valued `kwargs`, optionally record everything in the
    # module-level wandb_logger and finally call `fun`
    p = anno_parser(fun)
    args = p.parse_args(args).__dict__
    args.pop('xtra'); args.pop('pdb')
    args.update({k: v for k, v in kwargs.items()})
    if log_to_wandb and type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
        wandb_logger.experiment.config[name] = {k: v for k, v in args.items() if k not in ['dataset', 'tunables']}
    return fun(**args)
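# A sketch of how parse_and_call is used below (the task module and its option names
# are assumptions -- the real ones come from whatever --task string is passed in):
#
#     # --task "t2s_up_wds --quantizers 4" is shlex-split into
#     #   task_name = 't2s_up_wds', task_args = ['--quantizers', '4']
#     # task_args is then parsed against the annotated signature of task.make_model,
#     # with extra Python objects (e.g. the dataset) merged in through `kwargs`:
#     model = parse_and_call('model', task.make_model, task_args, dict(dataset=train_ds))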
# %% ../nbs/B2. Training (Lightning).ipynb 8
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, help='Task to train')
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
parser.add_argument('--input-dir', type=str, default='', help='input data path') # fixed in the model for now
parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how many training steps to run between validations')
parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
parser.add_argument('--warmup-steps', type=int, default=10000, help='total number of steps during which the learning rate rises (defaults to 10k updates)')
parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')
args = parser.parse_args().__dict__
task_args: list = shlex.split(args.pop("task"))
task_name, task_args = task_args[0], task_args[1:]
input_args: list = shlex.split(args.pop("input_dir"))
checkpoint_dir: str = args.pop("checkpoint_dir")
num_workers: int = args.pop("workers")
batch_size: int = args.pop("batch_size")
epochs: int = args.pop("epochs")
tunables_args: list = shlex.split(args.pop("tunables"))
hyp_params = {}
hyp_params['batch_size'] = batch_size
hyp_params['warmup_steps'] = args['warmup_steps']
hyp_params['weight_decay'] = args['weight_decay']
hyp_params['clip_gradient_norm'] = args['clip_gradient_norm']
hyp_params['accumulate_grad_batches'] = args['accumulate_grad_batches']
hyp_params['precision'] = args['precision']
hyp_params['lr0'] = args['lr0']
hyp_params['epochs'] = epochs
hyp_params['strategy'] = args['strategy']
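# A minimal example invocation (a sketch -- the task name and dataset location are
# assumptions; they depend on which whisperspeech task modules and webdataset shards
# you have available):
#
#   python train_multi.py \
#       --task "t2s_up_wds" \
#       --input-dir "/data/my-training-shards" \
#       --batch-size 16 --epochs 10 --lr0 1e-4 \
#       --strategy ddp --precision 16-mixed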
# %% ../nbs/B2. Training (Lightning).ipynb 9
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
import datetime
import webdataset as wds
import importlib
torch.set_float32_matmul_precision('medium')
project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
if args['wandb_suffix']:
    project += "-" + args['wandb_suffix']
wandb_logger = WandbLogger(project=project)
ckpt_callback = pl.callbacks.ModelCheckpoint(
    dirpath=f'{task_name}-{epochs}e',
    filename=task_name+"-{epoch}-{step}-{val_loss:.2f}",
    monitor="val_loss",
    save_top_k=4,
    train_time_interval=datetime.timedelta(minutes=5),
)
lr_monitor_callback = LearningRateMonitor(logging_interval='step')
from torch.utils.data import DataLoader
task = importlib.import_module("whisperspeech."+task_name)
train_ds, val_ds = parse_and_call('dataset', task.load_datasets, input_args)
tunables = None
if hasattr(task, "Tunables"):
    import dataclasses
    tunables = parse_and_call('tunables', task.Tunables, tunables_args, log_to_wandb=False)
    if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
        wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)

    # tunables may override the corresponding command line hyperparameters
    for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
        val = getattr(tunables, name, None)
        if val is not None: hyp_params[name] = val
if isinstance(train_ds, torch.utils.data.IterableDataset):
    # IterableDatasets (webdataset) handle their own batching and cannot be shuffled by the DataLoader
    dl_batch_size, dl_shuffle = None, False
    pin_memory = False
else:
    dl_batch_size, dl_shuffle = batch_size, True
    pin_memory = True
val_loader = wds.WebLoader(val_ds,
                           batch_size=dl_batch_size,
                           num_workers=num_workers,
                           drop_last=False,
                           pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)
train_loader = wds.WebLoader(train_ds,
                             batch_size=dl_batch_size,
                             num_workers=num_workers,
                             drop_last=False,
                             shuffle=dl_shuffle,
                             pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)
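# Note on the loader pipelines above: whatever batches come out of the DataLoader
# workers are unbatched, shuffled through a 1024-sample buffer (mixing samples across
# workers), and re-batched to the requested batch size; with_length() gives the loaders
# an explicit length so Lightning (and the OneCycleLR step count in configure_optimizers)
# knows how many steps an epoch has.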
model_kwargs = dict(dataset=train_ds)
if tunables is not None: model_kwargs['tunables'] = tunables
model = parse_and_call('model', task.make_model, task_args, model_kwargs)
task = TrainingTask(model, model_hparams=hyp_params)
trainer = pl.Trainer(strategy=hyp_params['strategy'],
                     max_epochs=hyp_params['epochs'],
                     accelerator="gpu",
                     profiler="simple",
                     precision=hyp_params['precision'],
                     gradient_clip_val=hyp_params['clip_gradient_norm'],
                     accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
                     val_check_interval=args.pop("validate_every_n_steps"),
                     enable_checkpointing=True,
                     logger=wandb_logger,
                     callbacks=[ckpt_callback, lr_monitor_callback])
if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
    wandb_logger.experiment.config.update(hyp_params)
kwargs = {}
if args['resume_from']:
    kwargs['ckpt_path'] = args['resume_from']
trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **kwargs)
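# With strategy='ddp', Lightning launches one process per visible GPU when this script is
# run directly (e.g. `CUDA_VISIBLE_DEVICES=0,1 python train_multi.py ...`); the
# sync_dist=True flags in TrainingTask average the logged losses across ranks.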