File size: 3,018 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from datetime import datetime
import logging
import os

from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import find_usable_cuda_devices  # type: ignore
from lightning.pytorch.strategies import DDPStrategy
import torch

from models.config import (
    HifiGanConfig,
    PreprocessingConfig,
)
from models.vocoder.hifigan import HifiGan
from models.vocoder.hifigan.generator import Generator

# Per-run timestamp (local time, one-second resolution) so that every launch
# writes to its own error-log file.
now = datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S")

# Dedicated logger that only lets ERROR (and more severe) records through.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.ERROR)

# `logging.FileHandler` opens its target file as soon as it is constructed and
# does not create missing parent directories, so make sure `logs/` exists
# first; otherwise a fresh checkout fails here with FileNotFoundError.
os.makedirs("logs", exist_ok=True)
handler = logging.FileHandler(f"logs/error_{timestamp}.log")

# Record layout: time - logger name - severity - message.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)

# Route the logger's records to the timestamped file.
logger.addHandler(handler)


# Report which CUDA devices Lightning considers usable before training starts.
usable_cuda_devices = find_usable_cuda_devices()
print("usable_cuda_devices: ", usable_cuda_devices)

# Trade a little float32 matmul accuracy for speed: with "high", PyTorch may
# run float32 matrix multiplications through faster reduced-precision kernels
# (e.g. TF32) when the hardware provides them.
torch.set_float32_matmul_precision("high")

# Directory handed to the Trainer as its root for logs and checkpoints.
default_root_dir = "logs_hifi_vocoder"

# ckpt_acoustic="./checkpoints/epoch=301-step=124630.ckpt"

# ckpt_vocoder="./checkpoints/vocoder.ckpt"

# Distributed data-parallel strategy; unused-parameter detection is switched on
# and gradients are kept as views into the communication buckets.
ddp_strategy = DDPStrategy(
    find_unused_parameters=True,
    gradient_as_bucket_view=True,
)

trainer = Trainer(
    # Hardware: every available CUDA device (devices=-1), driven through DDP.
    accelerator="cuda",
    devices=-1,
    strategy=ddp_strategy,
    # Checkpoints are enabled and stored under `default_root_dir`.
    default_root_dir=default_root_dir,
    enable_checkpointing=True,
    # No epoch limit (max_epochs=-1); metrics are logged every 10 steps.
    max_epochs=-1,
    log_every_n_steps=10,
)

model = HifiGan()

# Warm-start the model from the pretrained checkpoint (noted upstream as the
# NVIDIA one). The file is deserialized onto the CPU on purpose: without
# `map_location`, `torch.load` restores each tensor to the device it was saved
# from, so every DDP process running this script would allocate on that same
# GPU (or fail if it is not present). `load_state_dict` then copies the values
# into the model's own parameters.
hifigan_state_dict = torch.load(
    "./checkpoints/hifigan_spanish.pth",
    map_location="cpu",
)
model.load_state_dict(hifigan_state_dict)

# Discriminator checkpoint state load (currently disabled)
# disc_checkpoint_path = "checkpoints/do_02500000"
# disc_checkpoint = torch.load(disc_checkpoint_path)
# model.discriminator.MPD.load_state_dict(disc_checkpoint["mpd"])
# model.discriminator.MSD.load_state_dict(disc_checkpoint["msd"])

# Gen checkpoints state load
# gen_checkpoint_path = "checkpoints/generator_v1"
# gen_checkpoint = torch.load(gen_checkpoint_path)
# model.generator.load_state_dict(gen_checkpoint["generator"])

# Reset the parameters of the generator
# preprocess_config = PreprocessingConfig(
#     "multilingual",
#     sampling_rate=44100,
# )

# model.generator = Generator(
#     h=HifiGanConfig(),
#     p=preprocess_config,
# )

# len(gen_checkpoint)

# The dataset root and its cache both point at /dev/shm/.
# NOTE: per the original author, this preloads the cached dataset into RAM.
shm_dir = "/dev/shm/"
train_dataloader = model.train_dataloader(
    cache=True,
    cache_dir=shm_dir,
    root=shm_dir,
)

# Start training. Currently disabled options, kept for reference:
#   val_dataloaders=val_dataloader
#   ckpt_path=ckpt_acoustic  (resume training states from the checkpoint file)
trainer.fit(model=model, train_dataloaders=train_dataloader)