from datetime import datetime
import logging
import os
import sys

from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import find_usable_cuda_devices  # type: ignore
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.tuner.tuning import Tuner
import torch

from models.tts.delightful_tts.delightful_tts import DelightfulTTS

# Rank of this node in the cluster: 0 on the master node, 1..num_nodes-1 on the
# worker nodes. Launch the same script on every node with its own node_rank.
node_rank = 0
num_nodes = 2

# Set up the training cluster
os.environ["MASTER_PORT"] = "12355"
# Change this to the IP address of the master node
os.environ["MASTER_ADDR"] = "10.148.0.6"
os.environ["WORLD_SIZE"] = f"{num_nodes}"
os.environ["NODE_RANK"] = f"{node_rank}"

# Get the current date and time
now = datetime.now()

# Format the current date and time as a string
timestamp = now.strftime("%Y%m%d_%H%M%S")

# Create a logger
logger = logging.getLogger("my_logger")

# Set the level of the logger to ERROR
logger.setLevel(logging.ERROR)

# Create a file handler that writes error messages to a timestamped file,
# making sure the target directory exists first
os.makedirs("logs", exist_ok=True)
handler = logging.FileHandler(f"logs/error_{timestamp}.log")

# Create a formatter and add it to the handler
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(handler)


print("usable_cuda_devices: ", find_usable_cuda_devices())

# Trade some float32 matmul precision for speed (enables TF32 kernels on supported GPUs)
torch.set_float32_matmul_precision("high")

default_root_dir = "logs"

ckpt_acoustic = "./checkpoints/epoch=301-step=124630.ckpt"

ckpt_vocoder = "./checkpoints/vocoder.ckpt"

try:
    trainer = Trainer(
        accelerator="cuda",
        devices=-1,
        num_nodes=num_nodes,
        strategy=DDPStrategy(
            gradient_as_bucket_view=True,
            find_unused_parameters=True,
        ),
        # Save checkpoints to the `default_root_dir` directory
        default_root_dir=default_root_dir,
        enable_checkpointing=True,
        accumulate_grad_batches=5,
        max_epochs=-1,
        log_every_n_steps=10,
        gradient_clip_val=0.5,
    )

    # To train from scratch instead of fine-tuning, construct a fresh model:
    # model = DelightfulTTS()
    model = DelightfulTTS.load_from_checkpoint(ckpt_acoustic, strict=False)

    tuner = Tuner(trainer)
    # Run the learning-rate finder; with the default update_attr=True the suggested
    # learning rate is written back to the model's lr/learning_rate attribute.
    tuner.lr_find(model)
    # ValueError: Tuning the batch size is currently not supported with distributed strategies.
    # tuner.scale_batch_size(model, mode="binsearch")
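    # Sketch (assumption): to apply the suggestion explicitly instead of relying on
    # update_attr, read it from the LRFinder object returned by lr_find. The hparam
    # name "learning_rate" is hypothetical and depends on how DelightfulTTS is defined:
    # lr_finder = tuner.lr_find(model, update_attr=False)
    # if lr_finder is not None:
    #     model.hparams.learning_rate = lr_finder.suggestion()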

    train_dataloader = model.train_dataloader(
        # NOTE: /dev/shm is a RAM-backed tmpfs, so the cached dataset is preloaded into RAM
        cache_dir="/dev/shm/",
        cache=True,
    )

    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        # Resume training states from the checkpoint file
        # ckpt_path=ckpt_acoustic,
    )

except Exception as e:
    # Log the error message together with the full traceback
    logger.exception(f"An error occurred: {e}")
    sys.exit(1)