File size: 1,080 Bytes
6551065 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import os
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks import ModelCheckpoint
def add_callbacks(args):
    """Create the Lightning callbacks and loggers used for a training run.

    Everything is rooted at ``args.savedmodel_path``, which is created if it
    does not already exist:

    * checkpoints are written to ``<savedmodel_path>/checkpoints``;
    * TensorBoard and CSV logs are written under ``<savedmodel_path>/logs``.

    Args:
        args: Configuration object; this function reads its
            ``savedmodel_path`` and ``every_n_train_steps`` attributes.

    Returns:
        dict: ``{"callbacks": [ModelCheckpoint, LearningRateMonitor],
        "loggers": [CSVLogger, TensorBoardLogger]}``.
    """
    root = args.savedmodel_path
    os.makedirs(root, exist_ok=True)

    ckpt_dir = os.path.join(root, "checkpoints")
    logs_dir = os.path.join(root, "logs")

    # save_top_k=-1 keeps every checkpoint instead of pruning to the best k;
    # a checkpoint is taken every `every_n_train_steps` training steps.
    checkpointer = ModelCheckpoint(
        dirpath=ckpt_dir,
        filename="{epoch}-{step}",
        save_top_k=-1,
        every_n_train_steps=args.every_n_train_steps,
        save_last=False,
        save_weights_only=False,
    )
    # Record the learning rate at every step rather than once per epoch.
    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Both loggers share the same "logs" root and are told apart by name.
    tensorboard_log = pl_loggers.TensorBoardLogger(save_dir=logs_dir, name="tensorboard")
    csv_log = CSVLogger(save_dir=logs_dir, name="csvlog")

    return {
        "callbacks": [checkpointer, lr_monitor],
        "loggers": [csv_log, tensorboard_log],
    }
|