# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb.
# %% auto 0
__all__ = ['SimpleVisual', 'validate', 'train']
# %% ../nbs/B1. Training.ipynb 2
import io
import time
import random
from pathlib import Path
from fastprogress import progress_bar, master_bar
import fastprogress
import numpy as np
import pylab as plt
import math
import IPython
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.profiler import record_function
import webdataset as wds
# performance-oriented defaults: cuDNN autotuning and TF32/reduced-precision matmuls
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('medium')
# %% ../nbs/B1. Training.ipynb 3
class SimpleVisual:
    """Live training dashboard for notebooks: a fastprogress results table plus
    log-scale loss and learning-rate plots, updated in place during training."""
    def __init__(self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps
        self.epochs = total_steps // masterbar.main_bar.total

        gs = plt.GridSpec(2, 1, height_ratios=[3,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        self.lr_p.tick_params('x', labelbottom=False)
        self.graph_out = None

        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []

    def show(self):
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "time"], table=True)
        self.graph_out = IPython.display.display(self.graph_fig, display_id=True, clear=True)

    def hide(self):
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))

    def plot(self):
        loss_p, lr_p = self.loss_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(0, self.total_steps)
        loss_p.set_yscale('log')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)

    def add_data(self, it, lr, train_loss, val_loss):
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.lr_history.append(lr)
        self.plot()

    def add_table_row(self, it, avg_train_loss, val_loss):
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)

    def on_iter(self, bar, it, avg_train_loss, val_loss):
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
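# A minimal usage sketch (hypothetical; in practice `train()` below constructs
# and drives SimpleVisual for you):
#
#   mb = master_bar(range(10))                                 # 10 epochs
#   visual = SimpleVisual(model, mb, total_steps=10 * 100000)  # counted in samples, not batches
#   visual.show()
#   visual.add_data(it=1600, lr=[1e-4], train_loss=2.31, val_loss=2.45)
#   visual.hide()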
# %% ../nbs/B1. Training.ipynb 4
# FIXME: we need to keep this synchronised with the inline validation code in `train` below...
def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"):
    """Compute the average per-sample loss of `model` over the `val` dataset."""
    if isinstance(val, torch.utils.data.IterableDataset):
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)

    with torch.no_grad():
        val_loss = 0
        val_samples = 0
        for args in val_loader:
            args = [x.to(device, non_blocking=True) for x in args]
            with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                ps, loss = model(*args)
            N = args[0].shape[0]
            val_loss += loss.mean().item() * N
            val_samples += N
        val_loss = val_loss / val_samples
    return val_loss
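# Hypothetical usage sketch — `model` is assumed to follow the convention used
# throughout this file: its forward pass takes the unpacked batch tensors and
# returns a `(predictions, loss)` tuple.
#
#   loss = validate(model.to('cuda'), val_ds, bs=32)
#   print(f"validation loss: {loss:.5f}")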
# %% ../nbs/B1. Training.ipynb 5
def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
          weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
          dl_workers=8, visual_class=SimpleVisual, profiler=None,
          run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
          device="cuda", trainable_params=None):
    if chkpt_every_iters is None:
        chkpt_every_iters = table_row_every_iters

    mb = master_bar(range(epochs))
    if isinstance(train, torch.utils.data.IterableDataset):
        # IterableDatasets have no len(), so they are expected to expose a `total_samples` attribute
        steps_per_epoch = train.total_samples // bs
        pct_start = min(0.3, warmup_steps / (epochs * steps_per_epoch))
        visual = visual_class(model, mb, epochs * train.total_samples)
        # pct_start = min(0.3, warmup_steps / (epochs * len(train)))
        # visual = visual_class(model, mb, epochs*len(train)*bs)
    else:
        steps_per_epoch = math.ceil(len(train) / bs)
        pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
        visual = visual_class(model, mb, epochs * len(train))
    model.visual = visual

    Path(checkpoint_path).mkdir(exist_ok=True)

    if isinstance(train, torch.utils.data.IterableDataset):
        # train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
        # val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
        train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs, partial=False)
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)

    val_loss = torch.nan
    avg_train_loss = torch.nan

    if hasattr(model, 'setup'):
        model.setup(device)
    try:
        scheduler = None

        if trainable_params is None: trainable_params = model.parameters()
        all_params = set(trainable_params)
        customized_params = set()
        groups = []
        group_map = {}
        # modules can opt out of weight decay (by defining `no_weight_decay`) or
        # scale the base learning rate (via an `lr_scale` attribute); collect their
        # parameters into per-(weight_decay, lr) optimizer groups
        for name, m in model.named_modules():
            if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
                m_trainable = [x for x in m.parameters() if x in all_params]
                if not m_trainable: continue
                customized_params |= set(m_trainable)
                m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
                m_lr = lr * getattr(m, 'lr_scale', 1)
                group = group_map.get((m_wd, m_lr), None)
                if group is None:
                    group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                    groups.append(group)
                    group_map[(m_wd, m_lr)] = group
                group['params'] += m_trainable
                group['names'].append(name)
        # everything else trains with the default weight decay and learning rate
        other_params = all_params - customized_params
        if other_params:
            groups = groups + [
                {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay},
            ]

        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
        model._optimizer = optimizer

        scaler = torch.cuda.amp.GradScaler(enabled=half)

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, pct_start=pct_start, steps_per_epoch=steps_per_epoch, epochs=epochs,
            max_lr=[pg.get('lr', lr) for pg in groups],
            final_div_factor=25)
        it = 0
        next_val_it = it + 50
        next_chkpt_it = chkpt_every_iters
        next_table_it = table_row_every_iters

        visual.show()

        running_loss = [0]

        for epoch in mb:
            bar = progress_bar(train_loader, total=steps_per_epoch, parent=mb)
            for args in bar:
                with record_function("forward"):
                    args = [x.to(device, non_blocking=True) for x in args]

                    # zero the parameter gradients
                    optimizer.zero_grad(set_to_none=True)

                    with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                        ps, loss = model(*args)
                        loss = loss.mean()

                with record_function("backward"):
                    scaler.scale(loss).backward()

                    if clip_gradient_norm:
                        scaler.unscale_(optimizer)
                        # since the gradients of the optimizer's assigned params are now unscaled, clip as usual:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)

                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()

                if profiler is not None: profiler.step()

                with record_function("running_loss"):
                    # smooth the reported training loss over the last 5 batches
                    running_loss.append(loss.item())
                    running_loss = running_loss[-5:]
                    avg_train_loss = sum(running_loss) / len(running_loss)

                if it >= next_chkpt_it:
                    with record_function("checkpoint"):
                        next_chkpt_it += chkpt_every_iters
                        torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')

                if it >= next_val_it:
                    next_val_it += run_valid_every_iters
                    with record_function("validation"):
                        with record_function("model.eval"):
                            model.eval()
                        with torch.no_grad():
                            val_loss = 0
                            val_samples = 0
                            for args in val_loader:
                                args = [x.to(device, non_blocking=True) for x in args]
                                with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                                    ps, loss = model(*args)
                                N = args[0].shape[0]
                                val_loss += loss.mean().item() * N
                                val_samples += N
                            val_loss = val_loss / val_samples
                        with record_function("model.train"):
                            model.train()
                    with record_function("plotting"):
                        visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)

                if it >= next_table_it:
                    visual.add_table_row(it, avg_train_loss, val_loss)
                    next_table_it += table_row_every_iters

                it += bs
                visual.on_iter(bar, it, avg_train_loss, val_loss)
    except KeyboardInterrupt:
        mb.write("interrupted")
        mb.show()
    finally:
        visual.add_table_row(it, avg_train_loss, val_loss)
        mb.show()
        visual.hide()
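# Hypothetical usage sketch (the dataset names are assumptions, not part of this file):
#
#   train('checkpoints/my-run', model, train_ds, val_ds,
#         bs=16, lr=1e-4, epochs=10, warmup_steps=10000,
#         clip_gradient_norm=1.0, device='cuda')
#
# Note that when `train_ds` is an IterableDataset it has to expose a `total_samples`
# attribute, since the OneCycleLR schedule and the progress bars need to know the
# epoch length up front.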