from argparse import ArgumentParser

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner.tuning import Tuner

from models.tts.delightful_tts import DelightfulTTS
from models.vocoder.univnet import UnivNet
def train():
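    """Build the Trainer, load DelightfulTTS and UnivNet from checkpoints, and run training."""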
    parser = ArgumentParser()

    # Trainer arguments
    parser.add_argument("--devices", type=int, default=None)
    parser.add_argument("--default_root_dir", type=str, default="logs/acoustic")
    parser.add_argument("--limit_train_batches", type=int, default=None)
    parser.add_argument("--max_epochs", type=int, default=None)

    # By default, gradients are accumulated over 3 batches
    parser.add_argument("--accumulate_grad_batches", type=int, default=3)
    parser.add_argument("--accelerator", type=str, default="cuda")

    parser.add_argument("--ckpt_acoustic", type=str, default="./checkpoints/am_pitche_stats_with_vocoder.ckpt")
    parser.add_argument("--ckpt_vocoder", type=str, default="./checkpoints/vocoder.ckpt")
    # Optimizers
    # FIXME: this is not working, found an error:
    # "Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment"
    # Note: `type=bool` does not behave as expected with argparse (any non-empty
    # string parses as True), so these options are plain on/off flags instead.

    # Stochastic Weight Averaging (SWA)
    parser.add_argument("--swa", action="store_true")

    # Learning rate finder
    parser.add_argument("--lr_find", action="store_true")

    # Batch size scaling
    parser.add_argument("--batch_size_scaling", action="store_true")
    # Parse the user inputs and defaults (returns an argparse.Namespace)
    args = parser.parse_args()
    tensorboard = TensorBoardLogger(save_dir=args.default_root_dir)

    callbacks = []

    if args.swa:
        callbacks.append(
            # Stochastic Weight Averaging (SWA) can make your models generalize
            # better at virtually no additional cost.
            # This can be used with both non-trained and trained models.
            # The SWA procedure smooths the loss landscape, thus making it
            # harder to end up in a local minimum during optimization.
            StochasticWeightAveraging(swa_lrs=1e-2),
        )
    trainer = Trainer(
        logger=tensorboard,
        # Save checkpoints to the `default_root_dir` directory
        default_root_dir=args.default_root_dir,
        limit_train_batches=args.limit_train_batches,
        max_epochs=args.max_epochs,
        accelerator=args.accelerator,
        # Pass the parsed --devices value through; fall back to auto-selection
        devices=args.devices if args.devices is not None else "auto",
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=callbacks,
    )
    # Create a Tuner
    tuner = Tuner(trainer)

    # Load the pretrained weights for the vocoder
    vocoder_module = UnivNet.load_from_checkpoint(
        args.ckpt_vocoder,
    )

    module = DelightfulTTS.load_from_checkpoint(
        args.ckpt_acoustic,
        vocoder_module=vocoder_module,
    )
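    # DelightfulTTS is a LightningModule that exposes its own train_dataloader()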
    train_dataloader = module.train_dataloader()

    if args.lr_find:
        # Finds a learning rate automatically and
        # sets hparams.lr or hparams.learning_rate to that learning rate
        tuner.lr_find(module)

    if args.batch_size_scaling:
        # Auto-scale batch size by growing it exponentially (default)
        tuner.scale_batch_size(module, init_val=2, mode="power")

    trainer.fit(model=module, train_dataloaders=train_dataloader)
if __name__ == "__main__":
    train()
# note: it is good practice to implement the CLI in a function and call it in the main if block
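# Example invocation (assuming this file is saved as train.py; the checkpoint
# paths are the defaults defined above, adjust them to your own setup):
#   python train.py --accelerator cuda --devices 1 --max_epochs 1000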