# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb.

# %% auto 0
__all__ = ['SimpleVisual', 'validate', 'train']

# %% ../nbs/B1. Training.ipynb 2
import io
import time
import random
from pathlib import Path

from fastprogress import progress_bar, master_bar
import fastprogress

import numpy as np
import pylab as plt
import math

import IPython

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torch.profiler import record_function

import webdataset as wds

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision('medium')
# %% ../nbs/B1. Training.ipynb 3
class SimpleVisual:
    def __init__(self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps
        self.epochs = total_steps // masterbar.main_bar.total

        gs = plt.GridSpec(2, 1, height_ratios=[3,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        self.lr_p.tick_params('x', labelbottom=False)
        self.graph_out = None

        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []
    def show(self):
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "time"], table=True)
        self.graph_out = IPython.display.display(self.graph_fig, display_id=True, clear=True)

    def hide(self):
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))
    def plot(self):
        loss_p, lr_p = self.loss_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(0, self.total_steps)
        loss_p.set_yscale('log')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)
    def add_data(self, it, lr, train_loss, val_loss):
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.lr_history.append(lr)
        self.plot()
    def add_table_row(self, it, avg_train_loss, val_loss):
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)

    def on_iter(self, bar, it, avg_train_loss, val_loss):
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
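# A custom progress visualizer can be passed to `train` below via `visual_class`; `train`
# only relies on the interface used here: the constructor takes (model, masterbar,
# total_steps) and the instance provides show(), hide(), add_data(it, lr, train_loss,
# val_loss), add_table_row(it, avg_train_loss, val_loss) and
# on_iter(bar, it, avg_train_loss, val_loss). A headless stand-in (e.g. for scripted runs)
# could subclass SimpleVisual and turn the plotting methods into no-ops.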
# %% ../nbs/B1. Training.ipynb 4
# FIXME: we need to keep this synchronised with the validation code below...
def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"):
    if isinstance(val, torch.utils.data.IterableDataset):
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)

    with torch.no_grad():
        val_loss = 0
        val_samples = 0
        for args in val_loader:
            args = [x.to(device, non_blocking=True) for x in args]
            with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                ps, loss = model(*args)
                N = args[0].shape[0]
                val_loss += loss.mean().item() * N
                val_samples += N
        val_loss = val_loss / val_samples

    return val_loss
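# Minimal usage sketch (not part of the original notebook): assuming `MyModel` is a module
# whose forward pass returns `(predictions, loss)` and `val_ds` yields tuples of tensors
# (map-style dataset or a webdataset IterableDataset), a one-off evaluation on a single
# GPU could look like the following. `MyModel` and `val_ds` are hypothetical names used
# only for illustration:
#
#     my_model = MyModel().to('cuda').eval()
#     loss = validate(my_model, val_ds, half=True, bs=16, dl_workers=4)
#     print(f"validation loss: {loss:.4f}")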
# %% ../nbs/B1. Training.ipynb 5
def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
          weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
          dl_workers=8, visual_class=SimpleVisual, profiler=None,
          run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
          device="cuda", trainable_params=None):
    if chkpt_every_iters is None:
        chkpt_every_iters = table_row_every_iters

    mb = master_bar(range(epochs))
    if isinstance(train, torch.utils.data.IterableDataset):
        pct_start = min(0.3, warmup_steps / (epochs * (train.total_samples//bs)))
        visual = visual_class(model, mb, epochs * train.total_samples)
#         pct_start = min(0.3, warmup_steps / (epochs * len(train)))
#         visual = visual_class(model, mb, epochs*len(train)*bs)
    else:
        pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
        visual = visual_class(model, mb, epochs*len(train))
    model.visual = visual

    Path(checkpoint_path).mkdir(exist_ok=True)

    if isinstance(train, torch.utils.data.IterableDataset):
#         train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
#         val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
        train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs, partial=False)
        val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
            .unbatched().shuffle(1024).batched(bs)
    else:
        train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
        val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)

    val_loss = torch.nan
    avg_train_loss = torch.nan

    if hasattr(model, 'setup'):
        model.setup(device)

    try:
        scheduler = None

        if trainable_params is None: trainable_params = model.parameters()
        all_params = set(trainable_params)
        customized_params = set()
        groups = []
        group_map = {}
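        # Per-module optimizer hooks: any submodule that defines a `no_weight_decay`
        # attribute is placed in a parameter group with weight_decay=0, and any submodule
        # with an `lr_scale` attribute gets its learning rate multiplied by that factor.
        # For example (hypothetical), a model could set in its __init__:
        #     self.embedding.no_weight_decay = True
        #     self.embedding.lr_scale = 0.5
        # and the embedding parameters would end up in their own AdamW group below.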
        for name,m in model.named_modules():
            if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
                m_trainable = [x for x in m.parameters() if x in all_params]
                if not m_trainable: continue
                customized_params |= set(m_trainable)
                m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
                m_lr = lr * getattr(m, 'lr_scale', 1)
                group = group_map.get((m_wd, m_lr), None)
                if not group:
                    group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
                    groups.append(group)
                    group_map[(m_wd, m_lr)] = group
                group['params'] += m_trainable
                group['names'].append(name)

        other_params = all_params - customized_params
        if other_params:
            groups = groups + [
                {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay},
            ]
        optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
        model._optimizer = optimizer

        scaler = torch.cuda.amp.GradScaler(enabled=half)

        # Note: the scheduler (and the progress bar below) assume the training dataset
        # exposes `total_samples`, as the webdataset pipelines used here do.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, pct_start=pct_start, steps_per_epoch=math.ceil(train.total_samples/bs), epochs=epochs,
            max_lr=[pg.get('lr', lr) for pg in groups],
            final_div_factor=25)
        it = 0
        # run the first validation early (after 50 samples) so the plot and table have data
        next_val_it = it + 50
        next_chkpt_it = chkpt_every_iters
        next_table_it = table_row_every_iters

        visual.show()

        running_loss = [0]

        for epoch in mb:
            bar = progress_bar(train_loader, total=train.total_samples//bs, parent=mb)
            for args in bar:
                with record_function("forward"):
                    args = [x.to(device, non_blocking=True) for x in args]

                    # zero the parameter gradients
                    optimizer.zero_grad(set_to_none=True)

                    with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                        ps, loss = model(*args)
                        loss = loss.mean()

                with record_function("backward"):
                    scaler.scale(loss).backward()

                    if clip_gradient_norm:
                        scaler.unscale_(optimizer)
                        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)

                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()

                    if profiler is not None: profiler.step()

                with record_function("running_loss"):
                    # average the last few batch losses to smooth the reported training loss
                    running_loss.append(loss.item())
                    running_loss = running_loss[-5:]
                    avg_train_loss = sum(running_loss)/len(running_loss)

                if it >= next_chkpt_it:
                    with record_function("checkpoint"):
                        next_chkpt_it += chkpt_every_iters
                        torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')

                if it >= next_val_it:
                    next_val_it += run_valid_every_iters
                    with record_function("validation"):
                        with record_function("model.eval"):
                            model.eval()
                        with torch.no_grad():
                            val_loss = 0
                            val_samples = 0
                            for args in val_loader:
                                args = [x.to(device, non_blocking=True) for x in args]
                                with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
                                    ps, loss = model(*args)
                                    N = args[0].shape[0]
                                    val_loss += loss.mean().item() * N
                                    val_samples += N
                            val_loss = val_loss / val_samples
                        with record_function("model.train"):
                            model.train()
                    with record_function("plotting"):
                        visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)

                if it >= next_table_it:
                    visual.add_table_row(it, avg_train_loss, val_loss)
                    next_table_it += table_row_every_iters

                # `it` counts samples seen, so all the *_every_iters thresholds are in samples
                it += bs
                visual.on_iter(bar, it, avg_train_loss, val_loss)
    except KeyboardInterrupt:
        mb.write("interrupted")
        mb.show()
    finally:
        visual.add_table_row(it, avg_train_loss, val_loss)
        mb.show()
        visual.hide()
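# End-to-end usage sketch (not part of the original notebook; `make_webdataset_pipelines`,
# `MyModel` and the `total_samples` value below are hypothetical stand-ins for the real
# WhisperSpeech datasets and models):
#
#     train_ds, val_ds = make_webdataset_pipelines(...)   # each yields tuples of tensors
#     train_ds.total_samples = 1_000_000                   # required for the LR schedule
#     model = MyModel()                                     # forward(*batch) -> (preds, loss)
#     train('checkpoints', model, train_ds, val_ds,
#           half=True, bs=16, lr=1e-4, epochs=10, warmup_steps=10000,
#           clip_gradient_norm=1.0, device='cuda')
#
# Checkpoints are written to `checkpoints/{samples_seen:08d}.pt` every `chkpt_every_iters`
# samples, and a KeyboardInterrupt stops training cleanly through the finally block above.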