|
import os
|
|
from lightning.pytorch.loggers import CSVLogger
|
|
from lightning.pytorch import loggers as pl_loggers
|
|
from lightning.pytorch.callbacks import LearningRateMonitor
|
|
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
|
|
|
|
def add_callbacks(args):
    """Assemble the training callbacks and loggers rooted at the save path.

    Args:
        args: Configuration object exposing ``savedmodel_path`` (root output
            directory) and ``every_n_train_steps`` (checkpoint interval
            forwarded to ``ModelCheckpoint``).

    Returns:
        dict: ``{"callbacks": [...], "loggers": [...]}`` where ``callbacks``
        holds the checkpoint callback followed by the learning-rate monitor,
        and ``loggers`` holds the CSV logger followed by the TensorBoard
        logger.
    """
    root_dir = args.savedmodel_path
    # Make sure the root output directory exists before anything writes to it.
    os.makedirs(root_dir, exist_ok=True)

    # Checkpoints live in their own sub-directory of the root.
    ckpt_dir = os.path.join(root_dir, "checkpoints")
    ckpt_cb = ModelCheckpoint(
        dirpath=ckpt_dir,
        filename="{epoch}-{step}",
        # -1 disables top-k pruning (per Lightning's ModelCheckpoint docs).
        save_top_k=-1,
        every_n_train_steps=args.every_n_train_steps,
        save_last=False,
        save_weights_only=False,
    )

    lr_cb = LearningRateMonitor(logging_interval='step')

    # Both loggers share one "logs" directory and differ only by `name`.
    logs_dir = os.path.join(root_dir, "logs")
    tensorboard = pl_loggers.TensorBoardLogger(save_dir=logs_dir, name="tensorboard")
    csv_log = CSVLogger(save_dir=logs_dir, name="csvlog")

    return {
        "callbacks": [ckpt_cb, lr_cb],
        "loggers": [csv_log, tensorboard],
    }
|
|
|