File size: 1,547 Bytes
1cc747d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import os
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
class IntervalModelCheckpoint(Callback):
    """
    Save a checkpoint at chosen training steps, instead of Lightning's default
    that checkpoints based on validation loss.

    At each requested step this callback runs a validation pass, saves a
    checkpoint whose filename encodes the step and the resulting ``val_loss``,
    and additionally overwrites ``best.ckpt`` whenever that ``val_loss`` is the
    lowest seen so far by this callback.
    """

    def __init__(
        self,
        dirpath,
        save_intervals,
    ):
        """
        Args:
            dirpath: directory the checkpoint files are written into.
            save_intervals: iterable of 1-based step numbers (compared against
                ``trainer.global_step + 1``) after which to evaluate and save.
        """
        super().__init__()
        self.dirpath = dirpath
        # Materialize once: O(1) membership tests on every batch, and a one-shot
        # iterator (e.g. a generator) is not consumed by the first lookup.
        self.save_intervals = set(save_intervals)
        # Start at +inf so the first evaluated checkpoint always becomes best.ckpt.
        self.best_val_loss = float("inf")
        # Last step already handled. Several batches can share one global_step
        # (e.g. with gradient accumulation); save only once per step.
        self._last_saved_step = None

    def on_batch_end(self, trainer: pl.Trainer, _):
        """Check whether to evaluate and save a checkpoint after a train batch.

        Raises:
            KeyError: if ``val_loss`` is absent from ``trainer.callback_metrics``
                after the evaluation run.
        """
        step = trainer.global_step + 1
        if step not in self.save_intervals or step == self._last_saved_step:
            return
        self._last_saved_step = step

        trainer.run_evaluation()
        metrics = trainer.callback_metrics
        if "val_loss" not in metrics:
            raise KeyError(
                "IntervalModelCheckpoint requires 'val_loss' in "
                "trainer.callback_metrics; log it during validation"
            )
        # float() accepts a Python number or a single-element tensor, so the
        # stored best value is a plain float rather than a tensor reference.
        val_loss = float(metrics["val_loss"])

        filename = f"steps={step:05d}-val_loss={val_loss:0.8f}.ckpt"
        trainer.save_checkpoint(os.path.join(self.dirpath, filename))

        if val_loss < self.best_val_loss:
            trainer.save_checkpoint(os.path.join(self.dirpath, 'best.ckpt'))
            self.best_val_loss = val_loss
|