from datetime import datetime
import logging
import os

from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import find_usable_cuda_devices  # type: ignore
from lightning.pytorch.strategies import DDPStrategy
import torch

from models.config import (
    HifiGanConfig,
    PreprocessingConfig,
)
from models.vocoder.hifigan import HifiGan
from models.vocoder.hifigan.generator import Generator

# Get the current date and time
now = datetime.now()

# Format the current date and time as a string
timestamp = now.strftime("%Y%m%d_%H%M%S")

# Create a logger
logger = logging.getLogger("my_logger")

# Set the level of the logger to ERROR
logger.setLevel(logging.ERROR)

# Ensure the log directory exists; logging.FileHandler fails if it does not
os.makedirs("logs", exist_ok=True)

# Create a file handler that logs error messages to a file with the current timestamp in its name
handler = logging.FileHandler(f"logs/error_{timestamp}.log")

# Create a formatter and add it to the handler
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(handler)

print("usable_cuda_devices: ", find_usable_cuda_devices())

# Allow TensorFloat32 matrix multiplications to improve training throughput on supported GPUs
torch.set_float32_matmul_precision("high")

default_root_dir = "logs_hifi_vocoder"

# ckpt_acoustic = "./checkpoints/epoch=301-step=124630.ckpt"
# ckpt_vocoder = "./checkpoints/vocoder.ckpt"

trainer = Trainer(
    accelerator="cuda",
    devices=-1,
    strategy=DDPStrategy(
        gradient_as_bucket_view=True,
        find_unused_parameters=True,
    ),
    # Save checkpoints to the `default_root_dir` directory
    default_root_dir=default_root_dir,
    enable_checkpointing=True,
    max_epochs=-1,
    log_every_n_steps=10,
)

model = HifiGan()

# Preload the model weights from the pretrained Spanish HiFi-GAN checkpoint
hifigan_state_dict = torch.load("./checkpoints/hifigan_spanish.pth")
model.load_state_dict(hifigan_state_dict)

# Discriminator checkpoint state load
# disc_checkpoint_path = "checkpoints/do_02500000"
# disc_checkpoint = torch.load(disc_checkpoint_path)
# model.discriminator.MPD.load_state_dict(disc_checkpoint["mpd"])
# model.discriminator.MSD.load_state_dict(disc_checkpoint["msd"])

# Generator checkpoint state load
# gen_checkpoint_path = "checkpoints/generator_v1"
# gen_checkpoint = torch.load(gen_checkpoint_path)
# model.generator.load_state_dict(gen_checkpoint["generator"])

# Reset the parameters of the generator
# preprocess_config = PreprocessingConfig(
#     "multilingual",
#     sampling_rate=44100,
# )
# model.generator = Generator(
#     h=HifiGanConfig(),
#     p=preprocess_config,
# )

train_dataloader = model.train_dataloader(
    root="/dev/shm/",
    # NOTE: Preload the cached dataset into RAM
    cache_dir="/dev/shm/",
    cache=True,
)

trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    # val_dataloaders=val_dataloader,
    # Resume training state from a checkpoint file
    # ckpt_path=ckpt_acoustic,
)
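
# A minimal sketch of loading the trained vocoder back for inference once
# training has produced a checkpoint. This is an illustration, not part of the
# original pipeline: it assumes HifiGan is a LightningModule (it is passed to
# Trainer.fit above), and the checkpoint path below is hypothetical; use the
# file Lightning actually wrote under `default_root_dir`
# (e.g. logs_hifi_vocoder/lightning_logs/version_*/checkpoints/).
#
# vocoder = HifiGan.load_from_checkpoint(
#     "logs_hifi_vocoder/lightning_logs/version_0/checkpoints/last.ckpt"
# )
# vocoder.eval()
# vocoder.freeze()  # sets eval mode and disables gradients for inference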