# HiFi-GAN vocoder training script (PyTorch Lightning).
from datetime import datetime
import logging
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import find_usable_cuda_devices # type: ignore
from lightning.pytorch.strategies import DDPStrategy
import torch
from models.config import (
HifiGanConfig,
PreprocessingConfig,
)
from models.vocoder.hifigan import HifiGan
from models.vocoder.hifigan.generator import Generator
from pathlib import Path

# Timestamp string (e.g. "20240101_123456") used to give each run its own
# error-log file.
now = datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S")

# Dedicated error-only logger for this training script.
logger = logging.getLogger("my_logger")
logger.setLevel(logging.ERROR)

# logging.FileHandler raises FileNotFoundError if the target directory is
# missing, so make sure "logs/" exists before attaching the handler.
Path("logs").mkdir(exist_ok=True)

# Write error records to a timestamped file so runs don't clobber each other.
handler = logging.FileHandler(f"logs/error_{timestamp}.log")
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
# Report which CUDA devices Lightning considers usable before building the trainer.
print("usable_cuda_devices: ", find_usable_cuda_devices())

# "high" float32 matmul precision enables TensorFloat-32 kernels on supported
# GPUs, trading a little accuracy for throughput.
torch.set_float32_matmul_precision("high")

# All Lightning logs and checkpoints for this run land under this directory.
default_root_dir = "logs_hifi_vocoder"

# ckpt_acoustic="./checkpoints/epoch=301-step=124630.ckpt"
# ckpt_vocoder="./checkpoints/vocoder.ckpt"

# Multi-GPU DDP strategy; find_unused_parameters=True is needed here,
# presumably because not every sub-module participates in every backward
# pass (typical of alternating GAN updates) — confirm against the model.
ddp_strategy = DDPStrategy(
    gradient_as_bucket_view=True,
    find_unused_parameters=True,
)

# Train on every visible CUDA device, indefinitely (max_epochs=-1), with
# checkpointing enabled and metrics logged every 10 steps.
trainer = Trainer(
    accelerator="cuda",
    devices=-1,
    strategy=ddp_strategy,
    default_root_dir=default_root_dir,
    enable_checkpointing=True,
    max_epochs=-1,
    log_every_n_steps=10,
)
model = HifiGan()
# Preload pretrained HiFi-GAN weights into the LightningModule.
# map_location="cpu" makes the load work even when the checkpoint contains
# CUDA tensors and the current host has no (or different) GPUs; Lightning
# moves the module to the training device at fit() time.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files (consider weights_only=True on torch >= 2.0).
hifigan_state_dict = torch.load(
    "./checkpoints/hifigan_spanish.pth", map_location="cpu"
)
model.load_state_dict(hifigan_state_dict)
# --- Optional: discriminator state preload (disabled) ---
# disc_checkpoint_path = "checkpoints/do_02500000"
# disc_checkpoint = torch.load(disc_checkpoint_path)
# model.discriminator.MPD.load_state_dict(disc_checkpoint["mpd"])
# model.discriminator.MSD.load_state_dict(disc_checkpoint["msd"])
# --- Optional: generator state preload (disabled) ---
# gen_checkpoint_path = "checkpoints/generator_v1"
# gen_checkpoint = torch.load(gen_checkpoint_path)
# model.generator.load_state_dict(gen_checkpoint["generator"])
# --- Optional: reset the generator parameters from config (disabled) ---
# preprocess_config = PreprocessingConfig(
#     "multilingual",
#     sampling_rate=44100,
# )
# model.generator = Generator(
#     h=HifiGanConfig(),
#     p=preprocess_config,
# )
# len(gen_checkpoint)
# Build the training dataloader. Both the data root and the cache directory
# point at /dev/shm so the cached dataset is effectively served from RAM.
loader = model.train_dataloader(
    root="/dev/shm/",
    cache_dir="/dev/shm/",
    cache=True,
)

# Launch training. Validation is disabled and no checkpoint is resumed;
# re-enable the commented arguments to change that.
trainer.fit(
    model=model,
    train_dataloaders=loader,
    # val_dataloaders=val_dataloader,
    # Resume training states from the checkpoint file
    # ckpt_path=ckpt_acoustic,
)