from typing import List, Optional, Tuple from lightning.pytorch.core import LightningModule import torch from torch.optim import AdamW, Optimizer, swa_utils from torch.optim.lr_scheduler import ExponentialLR from torch.utils.data import DataLoader from models.config import ( PreprocessingConfigUnivNet as PreprocessingConfig, ) from models.config import ( VocoderFinetuningConfig, VocoderModelConfig, VocoderPretrainingConfig, VoicoderTrainingConfig, ) from models.helpers.dataloaders import train_dataloader from training.loss import UnivnetLoss from .discriminator import Discriminator from .generator import Generator class UnivNet(LightningModule): r"""Univnet module. This module contains the `Generator` and `Discriminator` models, and handles training and optimization. """ def __init__( self, fine_tuning: bool = False, lang: str = "en", acc_grad_steps: int = 10, batch_size: int = 6, root: str = "datasets_cache/LIBRITTS", checkpoint_path_v1: Optional[str] = "vocoder_pretrained.pt", ): r"""Initializes the `VocoderModule`. Args: fine_tuning (bool, optional): Whether to use fine-tuning mode or not. Defaults to False. lang (str): Language of the dataset. acc_grad_steps (int): Accumulated gradient steps. batch_size (int): The batch size. root (str, optional): The root directory for the dataset. Defaults to "datasets_cache/LIBRITTS". checkpoint_path_v1 (str, optional): The path to the checkpoint for the model. If provided, the model weights will be loaded from this checkpoint. Defaults to None. """ super().__init__() # Switch to manual optimization self.automatic_optimization = False self.acc_grad_steps = acc_grad_steps self.batch_size = batch_size self.lang = lang self.root = root model_config = VocoderModelConfig() preprocess_config = PreprocessingConfig("english_only") self.univnet = Generator( model_config=model_config, preprocess_config=preprocess_config, ) self.discriminator = Discriminator(model_config=model_config) # Initialize SWA self.swa_averaged_univnet = swa_utils.AveragedModel(self.univnet) self.swa_averaged_discriminator = swa_utils.AveragedModel(self.discriminator) self.loss = UnivnetLoss() self.train_config: VoicoderTrainingConfig = ( VocoderFinetuningConfig() if fine_tuning else VocoderPretrainingConfig() ) # NOTE: this code is used only for the v0.1.0 checkpoint. # In the future, this code will be removed! self.checkpoint_path_v1 = checkpoint_path_v1 if checkpoint_path_v1 is not None: generator, discriminator, _, _ = self.get_weights_v1(checkpoint_path_v1) self.univnet.load_state_dict(generator, strict=False) self.discriminator.load_state_dict(discriminator, strict=False) def get_weights_v1(self, checkpoint_path: str) -> Tuple[dict, dict, dict, dict]: r"""NOTE: this method is used only for the v0.1.0 checkpoint. Prepares the weights for the model. This is required for the model to be loaded from the checkpoint. Args: checkpoint_path (str): The path to the checkpoint. Returns: Tuple[dict, dict, dict, dict]: The weights for the generator and discriminator. """ ckpt_acoustic = torch.load(checkpoint_path, map_location=torch.device("cpu")) return ( ckpt_acoustic["generator"], ckpt_acoustic["discriminator"], ckpt_acoustic["optim_g"], ckpt_acoustic["optim_d"], ) def forward(self, y_pred: torch.Tensor) -> torch.Tensor: r"""Performs a forward pass through the UnivNet model. Args: y_pred (torch.Tensor): The predicted mel spectrogram. Returns: torch.Tensor: The output of the UnivNet model. """ mel_lens = torch.tensor( [y_pred.shape[2]], dtype=torch.int32, device=y_pred.device, ) wav_prediction = self.univnet.infer(y_pred, mel_lens) return wav_prediction[0, 0] def training_step(self, batch: List, batch_idx: int): r"""Performs a training step for the model. Args: batch (List): The batch of data for training. The batch should contain the mel spectrogram, its length, the audio, and the speaker ID. batch_idx (int): Index of the batch. Returns: dict: A dictionary containing the total loss for the generator and logs for tensorboard. """ ( _, _, _, _, _, mels, _, _, _, _, _, wavs, _, ) = batch # Access your optimizers optimizers = self.optimizers() schedulers = self.lr_schedulers() opt_univnet: Optimizer = optimizers[0] # type: ignore sch_univnet: ExponentialLR = schedulers[0] # type: ignore opt_discriminator: Optimizer = optimizers[1] # type: ignore sch_discriminator: ExponentialLR = schedulers[1] # type: ignore audio = wavs fake_audio = self.univnet(mels) res_fake, period_fake = self.discriminator(fake_audio.detach()) res_real, period_real = self.discriminator(audio) ( total_loss_gen, total_loss_disc, stft_loss, score_loss, esr_loss, snr_loss, ) = self.loss.forward( audio, fake_audio, res_fake, period_fake, res_real, period_real, ) self.log( "total_loss_gen", total_loss_gen, sync_dist=True, batch_size=self.batch_size, ) self.log( "total_loss_disc", total_loss_disc, sync_dist=True, batch_size=self.batch_size, ) self.log("stft_loss", stft_loss, sync_dist=True, batch_size=self.batch_size) self.log("esr_loss", esr_loss, sync_dist=True, batch_size=self.batch_size) self.log("snr_loss", snr_loss, sync_dist=True, batch_size=self.batch_size) self.log("score_loss", score_loss, sync_dist=True, batch_size=self.batch_size) # Perform manual optimization self.manual_backward(total_loss_gen / self.acc_grad_steps, retain_graph=True) self.manual_backward(total_loss_disc / self.acc_grad_steps, retain_graph=True) # accumulate gradients of N batches if (batch_idx + 1) % self.acc_grad_steps == 0: # clip gradients self.clip_gradients( opt_univnet, gradient_clip_val=0.5, gradient_clip_algorithm="norm", ) self.clip_gradients( opt_discriminator, gradient_clip_val=0.5, gradient_clip_algorithm="norm", ) # optimizer step opt_univnet.step() opt_discriminator.step() # Scheduler step sch_univnet.step() sch_discriminator.step() # zero the gradients opt_univnet.zero_grad() opt_discriminator.zero_grad() def configure_optimizers(self): r"""Configures the optimizers and learning rate schedulers for the `UnivNet` and `Discriminator` models. This method creates an `AdamW` optimizer and an `ExponentialLR` scheduler for each model. The learning rate, betas, and decay rate for the optimizers and schedulers are taken from the training configuration. Returns tuple: A tuple containing two dictionaries. Each dictionary contains the optimizer and learning rate scheduler for one of the models. Examples ```python vocoder_module = VocoderModule() optimizers = vocoder_module.configure_optimizers() print(optimizers) ( {"optimizer": , "lr_scheduler": }, {"optimizer": , "lr_scheduler": } ) ``` """ optim_univnet = AdamW( self.univnet.parameters(), self.train_config.learning_rate, betas=(self.train_config.adam_b1, self.train_config.adam_b2), ) scheduler_univnet = ExponentialLR( optim_univnet, gamma=self.train_config.lr_decay, last_epoch=-1, ) optim_discriminator = AdamW( self.discriminator.parameters(), self.train_config.learning_rate, betas=(self.train_config.adam_b1, self.train_config.adam_b2), ) scheduler_discriminator = ExponentialLR( optim_discriminator, gamma=self.train_config.lr_decay, last_epoch=-1, ) # NOTE: this code is used only for the v0.1.0 checkpoint. # In the future, this code will be removed! if self.checkpoint_path_v1 is not None: _, _, optim_g, optim_d = self.get_weights_v1(self.checkpoint_path_v1) optim_univnet.load_state_dict(optim_g) optim_discriminator.load_state_dict(optim_d) return ( {"optimizer": optim_univnet, "lr_scheduler": scheduler_univnet}, {"optimizer": optim_discriminator, "lr_scheduler": scheduler_discriminator}, ) def on_train_epoch_end(self): r"""Updates the averaged model after each optimizer step with SWA.""" self.swa_averaged_univnet.update_parameters(self.univnet) self.swa_averaged_discriminator.update_parameters(self.discriminator) def on_train_end(self): # Update SWA model after training swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_univnet) swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_discriminator) def train_dataloader( self, num_workers: int = 5, root: str = "datasets_cache/LIBRITTS", cache: bool = True, cache_dir: str = "datasets_cache", mem_cache: bool = False, url: str = "train-clean-360", ) -> DataLoader: r"""Returns the training dataloader, that is using the LibriTTS dataset. Args: num_workers (int): The number of workers. root (str): The root directory of the dataset. cache (bool): Whether to cache the preprocessed data. cache_dir (str): The directory for the cache. mem_cache (bool): Whether to use memory cache. url (str): The URL of the dataset. Returns: DataLoader: The training and validation dataloaders. """ return train_dataloader( batch_size=self.batch_size, num_workers=num_workers, root=root, cache=cache, cache_dir=cache_dir, mem_cache=mem_cache, url=url, lang=self.lang, )