from typing import List from lightning.pytorch.core import LightningModule import torch from torch.optim import AdamW, Optimizer, swa_utils from torch.optim.lr_scheduler import ExponentialLR from torch.utils.data import DataLoader from models.config import ( AcousticENModelConfig, AcousticFinetuningConfig, AcousticPretrainingConfig, AcousticTrainingConfig, VocoderFinetuningConfig, VocoderModelConfig, VocoderPretrainingConfig, VoicoderTrainingConfig, get_lang_map, lang2id, ) from models.config import ( PreprocessingConfigUnivNet as PreprocessingConfig, ) from models.helpers.dataloaders import train_dataloader from models.helpers.tools import get_mask_from_lengths # Models from models.tts.delightful_tts.acoustic_model import AcousticModel from models.vocoder.univnet.discriminator import Discriminator from models.vocoder.univnet.generator import Generator from training.loss import FastSpeech2LossGen, UnivnetLoss from training.preprocess.normalize_text import NormalizeText # Updated version of the tokenizer from training.preprocess.tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA class DelightfulUnivnet(LightningModule): r"""DEPRECATED: This idea is basically wrong. The model should synthesis pretty well mel spectrograms and then use them to generate the waveform based on the good quality mel-spec. Trainer for the acoustic model. Args: fine_tuning (bool, optional): Whether to use fine-tuning mode or not. Defaults to False. lang (str): Language of the dataset. n_speakers (int): Number of speakers in the dataset.generation during training. batch_size (int): The batch size. acc_grad_steps (int): The number of gradient accumulation steps. swa_steps (int): The number of steps for the SWA update. """ def __init__( self, fine_tuning: bool = True, lang: str = "en", n_speakers: int = 5392, batch_size: int = 12, acc_grad_steps: int = 5, swa_steps: int = 1000, ): super().__init__() # Switch to manual optimization self.automatic_optimization = False self.acc_grad_steps = acc_grad_steps self.swa_steps = swa_steps self.lang = lang self.fine_tuning = fine_tuning self.batch_size = batch_size lang_map = get_lang_map(lang) normilize_text_lang = lang_map.nemo self.tokenizer = TokenizerIPA(lang) self.normilize_text = NormalizeText(normilize_text_lang) # Acoustic model self.train_config_acoustic: AcousticTrainingConfig if self.fine_tuning: self.train_config_acoustic = AcousticFinetuningConfig() else: self.train_config_acoustic = AcousticPretrainingConfig() self.preprocess_config = PreprocessingConfig("english_only") self.model_config_acoustic = AcousticENModelConfig() # TODO: fix the arguments! self.acoustic_model = AcousticModel( preprocess_config=self.preprocess_config, model_config=self.model_config_acoustic, # NOTE: this parameter may be hyperparameter that you can define based on the demands n_speakers=n_speakers, ) # Initialize SWA self.swa_averaged_acoustic = swa_utils.AveragedModel(self.acoustic_model) # NOTE: in case of training from 0 bin_warmup should be True! self.loss_acoustic = FastSpeech2LossGen(bin_warmup=False) # Vocoder models self.model_config_vocoder = VocoderModelConfig() self.train_config: VoicoderTrainingConfig = ( VocoderFinetuningConfig() if fine_tuning else VocoderPretrainingConfig() ) self.univnet = Generator( model_config=self.model_config_vocoder, preprocess_config=self.preprocess_config, ) self.swa_averaged_univnet = swa_utils.AveragedModel(self.univnet) self.discriminator = Discriminator(model_config=self.model_config_vocoder) self.swa_averaged_discriminator = swa_utils.AveragedModel(self.discriminator) self.loss_univnet = UnivnetLoss() def forward( self, text: str, speaker_idx: torch.Tensor, lang: str = "en" ) -> torch.Tensor: r"""Performs a forward pass through the AcousticModel. This code must be run only with the loaded weights from the checkpoint! Args: text (str): The input text. speaker_idx (torch.Tensor): The index of the speaker. lang (str): The language. Returns: torch.Tensor: The output of the AcousticModel. """ normalized_text = self.normilize_text(text) _, phones = self.tokenizer(normalized_text) # Convert to tensor x = torch.tensor( phones, dtype=torch.int, device=speaker_idx.device, ).unsqueeze(0) speakers = speaker_idx.repeat(x.shape[1]).unsqueeze(0) langs = ( torch.tensor( [lang2id[lang]], dtype=torch.int, device=speaker_idx.device, ) .repeat(x.shape[1]) .unsqueeze(0) ) y_pred = self.acoustic_model.forward( x=x, speakers=speakers, langs=langs, ) mel_lens = torch.tensor( [y_pred.shape[2]], dtype=torch.int32, device=y_pred.device, ) wav = self.univnet.infer(y_pred, mel_lens) return wav # TODO: don't forget about torch.no_grad() ! # default used by the Trainer # trainer = Trainer(inference_mode=True) # Use `torch.no_grad` instead # trainer = Trainer(inference_mode=False) def training_step(self, batch: List, batch_idx: int): r"""Performs a training step for the model. Args: batch (List): The batch of data for training. The batch should contain: - ids: List of indexes. - raw_texts: Raw text inputs. - speakers: Speaker identities. - texts: Text inputs. - src_lens: Lengths of the source sequences. - mels: Mel spectrogram targets. - pitches: Pitch targets. - pitches_stat: Statistics of the pitches. - mel_lens: Lengths of the mel spectrograms. - langs: Language identities. - attn_priors: Prior attention weights. - wavs: Waveform targets. - energies: Energy targets. batch_idx (int): Index of the batch. Returns: - 'loss': The total loss for the training step. """ ( _, _, speakers, texts, src_lens, mels, pitches, _, mel_lens, langs, attn_priors, audio, energies, ) = batch ##################################### ## Acoustic model train step ## ##################################### outputs = self.acoustic_model.forward_train( x=texts, speakers=speakers, src_lens=src_lens, mels=mels, mel_lens=mel_lens, pitches=pitches, langs=langs, attn_priors=attn_priors, energies=energies, ) y_pred = outputs["y_pred"] log_duration_prediction = outputs["log_duration_prediction"] p_prosody_ref = outputs["p_prosody_ref"] p_prosody_pred = outputs["p_prosody_pred"] pitch_prediction = outputs["pitch_prediction"] energy_pred = outputs["energy_pred"] energy_target = outputs["energy_target"] src_mask = get_mask_from_lengths(src_lens) mel_mask = get_mask_from_lengths(mel_lens) ( acc_total_loss, acc_mel_loss, acc_ssim_loss, acc_duration_loss, acc_u_prosody_loss, acc_p_prosody_loss, acc_pitch_loss, acc_ctc_loss, acc_bin_loss, acc_energy_loss, ) = self.loss_acoustic.forward( src_masks=src_mask, mel_masks=mel_mask, mel_targets=mels, mel_predictions=y_pred, log_duration_predictions=log_duration_prediction, u_prosody_ref=outputs["u_prosody_ref"], u_prosody_pred=outputs["u_prosody_pred"], p_prosody_ref=p_prosody_ref, p_prosody_pred=p_prosody_pred, pitch_predictions=pitch_prediction, p_targets=outputs["pitch_target"], durations=outputs["attn_hard_dur"], attn_logprob=outputs["attn_logprob"], attn_soft=outputs["attn_soft"], attn_hard=outputs["attn_hard"], src_lens=src_lens, mel_lens=mel_lens, energy_pred=energy_pred, energy_target=energy_target, step=self.trainer.global_step, ) self.log( "acc_total_loss", acc_total_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "acc_mel_loss", acc_mel_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "acc_ssim_loss", acc_ssim_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "acc_duration_loss", acc_duration_loss, sync_dist=True, batch_size=self.batch_size, ) self.log( "acc_u_prosody_loss", acc_u_prosody_loss, sync_dist=True, batch_size=self.batch_size, ) self.log( "acc_p_prosody_loss", acc_p_prosody_loss, sync_dist=True, batch_size=self.batch_size, ) self.log( "acc_pitch_loss", acc_pitch_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "acc_ctc_loss", acc_ctc_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "acc_bin_loss", acc_bin_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "acc_energy_loss", acc_energy_loss, sync_dist=True, batch_size=self.batch_size, ) ##################################### ## Univnet model train step ## ##################################### fake_audio = self.univnet.forward(y_pred) res_fake, period_fake = self.discriminator(fake_audio.detach()) res_real, period_real = self.discriminator(audio) ( voc_total_loss_gen, voc_total_loss_disc, voc_stft_loss, voc_score_loss, voc_esr_loss, voc_snr_loss, ) = self.loss_univnet.forward( audio, fake_audio, res_fake, period_fake, res_real, period_real, ) self.log( "voc_total_loss_gen", voc_total_loss_gen, sync_dist=True, batch_size=self.batch_size, ) self.log( "voc_total_loss_disc", voc_total_loss_disc, sync_dist=True, batch_size=self.batch_size, ) self.log( "voc_stft_loss", voc_stft_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "voc_score_loss", voc_score_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "voc_esr_loss", voc_esr_loss, sync_dist=True, batch_size=self.batch_size ) self.log( "voc_snr_loss", voc_snr_loss, sync_dist=True, batch_size=self.batch_size ) # Manual optimizer # Access your optimizers optimizers = self.optimizers() schedulers = self.lr_schedulers() #################################### # Acoustic model manual optimizer ## #################################### opt_acoustic: Optimizer = optimizers[0] # type: ignore sch_acoustic: ExponentialLR = schedulers[0] # type: ignore opt_univnet: Optimizer = optimizers[0] # type: ignore sch_univnet: ExponentialLR = schedulers[0] # type: ignore opt_discriminator: Optimizer = optimizers[1] # type: ignore sch_discriminator: ExponentialLR = schedulers[1] # type: ignore # Backward pass for the acoustic model # NOTE: the loss is divided by the accumulated gradient steps self.manual_backward(acc_total_loss / self.acc_grad_steps, retain_graph=True) # Perform manual optimization univnet self.manual_backward( voc_total_loss_gen / self.acc_grad_steps, retain_graph=True ) self.manual_backward( voc_total_loss_disc / self.acc_grad_steps, retain_graph=True ) # accumulate gradients of N batches if (batch_idx + 1) % self.acc_grad_steps == 0: # Acoustic model optimizer step # clip gradients self.clip_gradients( opt_acoustic, gradient_clip_val=0.5, gradient_clip_algorithm="norm" ) # optimizer step opt_acoustic.step() # Scheduler step sch_acoustic.step() # zero the gradients opt_acoustic.zero_grad() # Univnet model optimizer step # clip gradients self.clip_gradients( opt_univnet, gradient_clip_val=0.5, gradient_clip_algorithm="norm" ) self.clip_gradients( opt_discriminator, gradient_clip_val=0.5, gradient_clip_algorithm="norm" ) # optimizer step opt_univnet.step() opt_discriminator.step() # Scheduler step sch_univnet.step() sch_discriminator.step() # zero the gradients opt_univnet.zero_grad() opt_discriminator.zero_grad() # Update SWA model every swa_steps if self.trainer.global_step % self.swa_steps == 0: self.swa_averaged_acoustic.update_parameters(self.acoustic_model) self.swa_averaged_univnet.update_parameters(self.univnet) self.swa_averaged_discriminator.update_parameters(self.discriminator) def on_train_epoch_end(self): r"""Updates the averaged model after each optimizer step with SWA.""" self.swa_averaged_acoustic.update_parameters(self.acoustic_model) self.swa_averaged_univnet.update_parameters(self.univnet) self.swa_averaged_discriminator.update_parameters(self.discriminator) def configure_optimizers(self): r"""Configures the optimizer used for training. Returns tuple: A tuple containing three dictionaries. Each dictionary contains the optimizer and learning rate scheduler for one of the models. """ #################################### # Acoustic model optimizer config ## #################################### # Compute the gamma and initial learning rate based on the current step lr_decay = self.train_config_acoustic.optimizer_config.lr_decay default_lr = self.train_config_acoustic.optimizer_config.learning_rate init_lr = ( default_lr if self.trainer.global_step == 0 else default_lr * (lr_decay**self.trainer.global_step) ) optimizer_acoustic = AdamW( self.acoustic_model.parameters(), lr=init_lr, betas=self.train_config_acoustic.optimizer_config.betas, eps=self.train_config_acoustic.optimizer_config.eps, weight_decay=self.train_config_acoustic.optimizer_config.weight_decay, ) scheduler_acoustic = ExponentialLR(optimizer_acoustic, gamma=lr_decay) #################################### # Univnet model optimizer config ## #################################### optim_univnet = AdamW( self.univnet.parameters(), self.train_config.learning_rate, betas=(self.train_config.adam_b1, self.train_config.adam_b2), ) scheduler_univnet = ExponentialLR( optim_univnet, gamma=self.train_config.lr_decay, last_epoch=-1, ) #################################### # Discriminator optimizer config ## #################################### optim_discriminator = AdamW( self.discriminator.parameters(), self.train_config.learning_rate, betas=(self.train_config.adam_b1, self.train_config.adam_b2), ) scheduler_discriminator = ExponentialLR( optim_discriminator, gamma=self.train_config.lr_decay, last_epoch=-1, ) return ( {"optimizer": optimizer_acoustic, "lr_scheduler": scheduler_acoustic}, {"optimizer": optim_univnet, "lr_scheduler": scheduler_univnet}, {"optimizer": optim_discriminator, "lr_scheduler": scheduler_discriminator}, ) def on_train_end(self): # Update SWA models after training swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_acoustic) swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_univnet) swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_discriminator) def train_dataloader( self, num_workers: int = 5, root: str = "datasets_cache/LIBRITTS", cache: bool = True, cache_dir: str = "datasets_cache", mem_cache: bool = False, url: str = "train-960", ) -> DataLoader: r"""Returns the training dataloader, that is using the LibriTTS dataset. Args: num_workers (int): The number of workers. root (str): The root directory of the dataset. cache (bool): Whether to cache the preprocessed data. cache_dir (str): The directory for the cache. mem_cache (bool): Whether to use memory cache. url (str): The URL of the dataset. Returns: Tupple[DataLoader, DataLoader]: The training and validation dataloaders. """ return train_dataloader( batch_size=self.batch_size, num_workers=num_workers, root=root, cache=cache, cache_dir=cache_dir, mem_cache=mem_cache, url=url, lang=self.lang, )