from argparse import ArgumentParser

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner.tuning import Tuner

from models.tts.delightful_tts import DelightfulTTS
from models.vocoder.univnet import UnivNet


def train():
    parser = ArgumentParser()

    # Trainer arguments
    parser.add_argument("--devices", type=int, default=None)
    parser.add_argument("--default_root_dir", type=str, default="logs/acoustic")

    parser.add_argument("--limit_train_batches", type=int, default=None)
    parser.add_argument("--max_epochs", type=int, default=None)

    # By default, gradients are accumulated over 3 batches
    parser.add_argument("--accumulate_grad_batches", type=int, default=3)
    parser.add_argument("--accelerator", type=str, default="cuda")
    parser.add_argument("--ckpt_acoustic", type=str, default="./checkpoints/am_pitche_stats_with_vocoder.ckpt")
    parser.add_argument("--ckpt_vocoder", type=str, default="./checkpoints/vocoder.ckpt")

    # Optimizers
    # FIXME: this is not working, it raises an error:
    # "Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment"
    # Stochastic Weight Averaging (SWA)
    # NOTE: argparse's `type=bool` treats any non-empty string (including "False")
    # as True, so these toggles are defined as `store_true` flags instead.
    parser.add_argument("--swa", action="store_true")
    # Learning rate finder
    parser.add_argument("--lr_find", action="store_true")
    # Batch size scaling
    parser.add_argument("--batch_size_scaling", action="store_true")

    # Parse the user inputs and defaults (returns an argparse.Namespace)
    args = parser.parse_args()
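
    # Example invocation of this training entry point (the script name and flag
    # values below are illustrative, not project-mandated defaults):
    #   python train.py --accelerator cuda --devices 1 --max_epochs 100 \
    #       --accumulate_grad_batches 3 --lr_find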

    tensorboard = TensorBoardLogger(save_dir=args.default_root_dir)

    callbacks = []

    if args.swa:
        callbacks.append(
            # Stochastic Weight Averaging (SWA) can make your models generalize
            # better at virtually no additional cost.
            # This can be used with both non-trained and trained models.
            # The SWA procedure smooths the loss landscape thus making it
            # harder to end up in a local minimum during optimization.
            StochasticWeightAveraging(swa_lrs=1e-2),
        )

    trainer = Trainer(
        logger=tensorboard,
        # Save checkpoints to the `default_root_dir` directory
        default_root_dir=args.default_root_dir,
        # Use the requested number of devices, or let Lightning choose automatically
        devices=args.devices if args.devices is not None else "auto",
        limit_train_batches=args.limit_train_batches,
        max_epochs=args.max_epochs,
        accelerator=args.accelerator,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=callbacks,
    )

    # Create a Tuner
    tuner = Tuner(trainer)

    # Load the pretrained weights for the vocoder
    vocoder_module = UnivNet.load_from_checkpoint(
        args.ckpt_vocoder,
    )
    module = DelightfulTTS.load_from_checkpoint(
        args.ckpt_acoustic,
        vocoder_module=vocoder_module,
    )

    train_dataloader = module.train_dataloader()

    if args.lr_find:
        # finds learning rate automatically
        # sets hparams.lr or hparams.learning_rate to that learning rate
        tuner.lr_find(module)

    if args.batch_size_scaling:
        # Auto-scale batch size by growing it exponentially (default)
        tuner.scale_batch_size(module, init_val=2, mode="power")

    trainer.fit(model=module, train_dataloaders=train_dataloader)


if __name__ == "__main__":
    # NOTE: it is good practice to implement the CLI in a function and call it in the main if block
    train()