from typing import List from lightning.pytorch.core import LightningModule import torch from torch import Tensor from torch.optim import AdamW from torch.optim.lr_scheduler import ExponentialLR from torch.utils.data import DataLoader from models.config import ( AcousticFinetuningConfig, AcousticModelConfigType, AcousticMultilingualModelConfig, AcousticPretrainingConfig, AcousticTrainingConfig, PreprocessingConfig, get_lang_map, lang2id, ) from models.helpers.tools import get_mask_from_lengths from training.datasets.hifi_libri_dataset import ( speakers_hifi_ids, speakers_libri_ids, train_dataloader, ) from training.loss import FastSpeech2LossGen from training.preprocess.normalize_text import NormalizeText # Updated version of the tokenizer from training.preprocess.tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA from .acoustic_model import AcousticModel MEL_SPEC_EVERY_N_STEPS = 1000 AUDIO_EVERY_N_STEPS = 100 class DelightfulTTS(LightningModule): r"""Trainer for the acoustic model. Args: preprocess_config PreprocessingConfig: The preprocessing configuration. model_config AcousticModelConfigType: The model configuration. fine_tuning (bool, optional): Whether to use fine-tuning mode or not. Defaults to False. bin_warmup (bool, optional): Whether to use binarization warmup for the loss or not. Defaults to True. lang (str): Language of the dataset. n_speakers (int): Number of speakers in the dataset.generation during training. batch_size (int): The batch size. """ def __init__( self, preprocess_config: PreprocessingConfig, model_config: AcousticModelConfigType = AcousticMultilingualModelConfig(), fine_tuning: bool = False, bin_warmup: bool = True, lang: str = "en", n_speakers: int = 5392, batch_size: int = 19, ): super().__init__() self.lang = lang self.lang_id = lang2id[self.lang] self.fine_tuning = fine_tuning self.batch_size = batch_size lang_map = get_lang_map(lang) normilize_text_lang = lang_map.nemo self.tokenizer = TokenizerIPA(lang) self.normilize_text = NormalizeText(normilize_text_lang) self.train_config_acoustic: AcousticTrainingConfig if self.fine_tuning: self.train_config_acoustic = AcousticFinetuningConfig() else: self.train_config_acoustic = AcousticPretrainingConfig() self.preprocess_config = preprocess_config # TODO: fix the arguments! self.acoustic_model = AcousticModel( preprocess_config=self.preprocess_config, model_config=model_config, # NOTE: this parameter may be hyperparameter that you can define based on the demands n_speakers=n_speakers, ) # NOTE: in case of training from 0 bin_warmup should be True! self.loss_acoustic = FastSpeech2LossGen( bin_warmup=bin_warmup, ) def forward( self, text: str, speaker_idx: Tensor, ) -> Tensor: r"""Performs a forward pass through the AcousticModel. This code must be run only with the loaded weights from the checkpoint! Args: text (str): The input text. speaker_idx (Tensor): The index of the speaker Returns: Tensor: The generated waveform with hifi-gan. """ normalized_text = self.normilize_text(text) _, phones = self.tokenizer(normalized_text) # Convert to tensor x = torch.tensor( phones, dtype=torch.int, device=speaker_idx.device, ).unsqueeze(0) speakers = speaker_idx.repeat(x.shape[1]).unsqueeze(0) langs = ( torch.tensor( [self.lang_id], dtype=torch.int, device=speaker_idx.device, ) .repeat(x.shape[1]) .unsqueeze(0) ) mel_pred = self.acoustic_model.forward( x=x, speakers=speakers, langs=langs, ) return mel_pred def training_step(self, batch: List, _: int): r"""Performs a training step for the model. Args: batch (List): The batch of data for training. The batch should contain: - ids: List of indexes. - raw_texts: Raw text inputs. - speakers: Speaker identities. - texts: Text inputs. - src_lens: Lengths of the source sequences. - mels: Mel spectrogram targets. - pitches: Pitch targets. - pitches_stat: Statistics of the pitches. - mel_lens: Lengths of the mel spectrograms. - langs: Language identities. - attn_priors: Prior attention weights. - wavs: Waveform targets. - energies: Energy targets. batch_idx (int): Index of the batch. Returns: - 'loss': The total loss for the training step. """ ( _, _, speakers, texts, src_lens, mels, pitches, _, mel_lens, langs, attn_priors, _, energies, ) = batch outputs = self.acoustic_model.forward_train( x=texts, speakers=speakers, src_lens=src_lens, mels=mels, mel_lens=mel_lens, pitches=pitches, langs=langs, attn_priors=attn_priors, energies=energies, ) y_pred = outputs["y_pred"] log_duration_prediction = outputs["log_duration_prediction"] p_prosody_ref = outputs["p_prosody_ref"] p_prosody_pred = outputs["p_prosody_pred"] pitch_prediction = outputs["pitch_prediction"] energy_pred = outputs["energy_pred"] energy_target = outputs["energy_target"] src_mask = get_mask_from_lengths(src_lens) mel_mask = get_mask_from_lengths(mel_lens) ( total_loss, mel_loss, ssim_loss, duration_loss, u_prosody_loss, p_prosody_loss, pitch_loss, ctc_loss, bin_loss, energy_loss, ) = self.loss_acoustic.forward( src_masks=src_mask, mel_masks=mel_mask, mel_targets=mels, mel_predictions=y_pred, log_duration_predictions=log_duration_prediction, u_prosody_ref=outputs["u_prosody_ref"], u_prosody_pred=outputs["u_prosody_pred"], p_prosody_ref=p_prosody_ref, p_prosody_pred=p_prosody_pred, pitch_predictions=pitch_prediction, p_targets=outputs["pitch_target"], durations=outputs["attn_hard_dur"], attn_logprob=outputs["attn_logprob"], attn_soft=outputs["attn_soft"], attn_hard=outputs["attn_hard"], src_lens=src_lens, mel_lens=mel_lens, energy_pred=energy_pred, energy_target=energy_target, step=self.trainer.global_step, ) self.log( "train_total_loss", total_loss, sync_dist=True, batch_size=self.batch_size, ) self.log("train_mel_loss", mel_loss, sync_dist=True, batch_size=self.batch_size) self.log( "train_ssim_loss", ssim_loss, sync_dist=True, batch_size=self.batch_size, ) self.log( "train_duration_loss", duration_loss, sync_dist=True, batch_size=self.batch_size, ) self.log( "train_u_prosody_loss", u_prosody_loss, sync_dist=True, batch_size=self.batch_size, ) self.log( "train_p_prosody_loss", p_prosody_loss, sync_dist=True, batch_size=self.batch_size, ) self.log( "train_pitch_loss", pitch_loss, sync_dist=True, batch_size=self.batch_size, ) self.log("train_ctc_loss", ctc_loss, sync_dist=True, batch_size=self.batch_size) self.log("train_bin_loss", bin_loss, sync_dist=True, batch_size=self.batch_size) self.log( "train_energy_loss", energy_loss, sync_dist=True, batch_size=self.batch_size, ) return total_loss def configure_optimizers(self): r"""Configures the optimizer used for training. Returns tuple: A tuple containing three dictionaries. Each dictionary contains the optimizer and learning rate scheduler for one of the models. """ lr_decay = self.train_config_acoustic.optimizer_config.lr_decay default_lr = self.train_config_acoustic.optimizer_config.learning_rate init_lr = ( default_lr if self.trainer.global_step == 0 else default_lr * (lr_decay**self.trainer.global_step) ) optimizer_acoustic = AdamW( self.acoustic_model.parameters(), lr=init_lr, betas=self.train_config_acoustic.optimizer_config.betas, eps=self.train_config_acoustic.optimizer_config.eps, weight_decay=self.train_config_acoustic.optimizer_config.weight_decay, ) scheduler_acoustic = ExponentialLR(optimizer_acoustic, gamma=lr_decay) return { "optimizer": optimizer_acoustic, "lr_scheduler": scheduler_acoustic, } def train_dataloader( self, root: str = "datasets_cache", cache: bool = True, cache_dir: str = "/dev/shm", include_libri: bool = False, libri_speakers: List[str] = speakers_libri_ids, hifi_speakers: List[str] = speakers_hifi_ids, ) -> DataLoader: r"""Returns the training dataloader, that is using the LibriTTS dataset. Args: root (str): The root directory of the dataset. cache (bool): Whether to cache the preprocessed data. cache_dir (str): The directory for the cache. Defaults to "/dev/shm". include_libri (bool): Whether to include the LibriTTS dataset or not. libri_speakers (List[str]): The list of LibriTTS speakers to include. hifi_speakers (List[str]): The list of HiFi-GAN speakers to include. Returns: Tupple[DataLoader, DataLoader]: The training and validation dataloaders. """ return train_dataloader( batch_size=self.batch_size, num_workers=self.preprocess_config.workers, sampling_rate=self.preprocess_config.sampling_rate, root=root, cache=cache, cache_dir=cache_dir, lang=self.lang, include_libri=include_libri, libri_speakers=libri_speakers, hifi_speakers=hifi_speakers, )