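# Page-scrape header (not part of the program):
#   author avatar caption: "nickovchinnikov's picture"
#   commit message: "Init", commit hash: 9d61c9b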
from typing import List
from lightning.pytorch.core import LightningModule
import torch
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from models.config import (
AcousticFinetuningConfig,
AcousticModelConfigType,
AcousticMultilingualModelConfig,
AcousticPretrainingConfig,
AcousticTrainingConfig,
PreprocessingConfig,
get_lang_map,
lang2id,
)
from models.helpers.tools import get_mask_from_lengths
from training.datasets.hifi_libri_dataset import (
speakers_hifi_ids,
speakers_libri_ids,
train_dataloader,
)
from training.loss import FastSpeech2LossGen
from training.preprocess.normalize_text import NormalizeText
# Updated version of the tokenizer
from training.preprocess.tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA
from .acoustic_model import AcousticModel
# Step intervals, presumably for logging mel spectrograms and audio samples
# during training. Neither constant is referenced in the visible part of this
# file -- TODO confirm where (or whether) they are consumed.
MEL_SPEC_EVERY_N_STEPS = 1000
AUDIO_EVERY_N_STEPS = 100
class DelightfulTTS(LightningModule):
    r"""Lightning trainer for the acoustic model.

    Owns an ``AcousticModel``, its ``FastSpeech2LossGen`` loss, and the text
    front-end (normalizer + IPA tokenizer) used by ``forward``.

    Args:
        preprocess_config (PreprocessingConfig): The preprocessing configuration.
        model_config (AcousticModelConfigType): The model configuration.
        fine_tuning (bool, optional): Whether to use fine-tuning mode or not. When True the
            training config is ``AcousticFinetuningConfig``, otherwise
            ``AcousticPretrainingConfig``. Defaults to False.
        bin_warmup (bool, optional): Whether to use binarization warmup for the loss or not.
            Defaults to True.
        lang (str): Language of the dataset. Defaults to "en".
        n_speakers (int): Number of speakers in the dataset. Defaults to 5392.
        batch_size (int): The batch size. Defaults to 19.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        # NOTE: this default is evaluated once, at definition time, so every
        # instance created without an explicit `model_config` shares the same
        # config object.
        model_config: AcousticModelConfigType = AcousticMultilingualModelConfig(),
        fine_tuning: bool = False,
        bin_warmup: bool = True,
        lang: str = "en",
        n_speakers: int = 5392,
        batch_size: int = 19,
    ):
        super().__init__()

        self.lang = lang
        self.lang_id = lang2id[self.lang]

        self.fine_tuning = fine_tuning
        self.batch_size = batch_size

        # Text front-end: the normalizer is keyed by the `nemo` entry of the
        # language map, the tokenizer by the language code itself.
        nemo_lang = get_lang_map(lang).nemo
        self.tokenizer = TokenizerIPA(lang)
        self.normilize_text = NormalizeText(nemo_lang)

        # Fine-tuning and pre-training use different training configurations.
        self.train_config_acoustic: AcousticTrainingConfig = (
            AcousticFinetuningConfig()
            if self.fine_tuning
            else AcousticPretrainingConfig()
        )

        self.preprocess_config = preprocess_config

        # TODO: fix the arguments!
        self.acoustic_model = AcousticModel(
            preprocess_config=self.preprocess_config,
            model_config=model_config,
            # NOTE: this parameter may be hyperparameter that you can define based on the demands
            n_speakers=n_speakers,
        )

        # NOTE: in case of training from 0 bin_warmup should be True!
        self.loss_acoustic = FastSpeech2LossGen(
            bin_warmup=bin_warmup,
        )

    def forward(
        self,
        text: str,
        speaker_idx: Tensor,
    ) -> Tensor:
        r"""Performs a forward pass through the AcousticModel.

        This code must be run only with the loaded weights from the checkpoint!

        Args:
            text (str): The input text.
            speaker_idx (Tensor): The index of the speaker. Its device is used for
                every tensor created here.

        Returns:
            Tensor: The output of ``AcousticModel.forward`` (held in ``mel_pred``).
                No vocoder is applied in this method.
        """
        normalized_text = self.normilize_text(text)
        _, phones = self.tokenizer(normalized_text)

        device = speaker_idx.device

        # Convert the phone ids to a tensor with a leading batch axis of size 1.
        x = torch.tensor(
            phones,
            dtype=torch.int,
            device=device,
        ).unsqueeze(0)
        # presumably `phones` is a flat list of ids, making x (1, seq_len) -- TODO confirm
        seq_len = x.shape[1]

        # One speaker id and one language id per input position.
        speakers = speaker_idx.repeat(seq_len).unsqueeze(0)
        langs = (
            torch.tensor(
                [self.lang_id],
                dtype=torch.int,
                device=device,
            )
            .repeat(seq_len)
            .unsqueeze(0)
        )

        mel_pred = self.acoustic_model.forward(
            x=x,
            speakers=speakers,
            langs=langs,
        )

        return mel_pred

    def training_step(self, batch: List, _: int):
        r"""Performs a training step for the model.

        Args:
            batch (List): The batch of data for training. The batch should contain,
                in this order (13 items):
                - ids: List of indexes (unused here).
                - raw_texts: Raw text inputs (unused here).
                - speakers: Speaker identities.
                - texts: Text inputs.
                - src_lens: Lengths of the source sequences.
                - mels: Mel spectrogram targets.
                - pitches: Pitch targets.
                - pitches_stat: Statistics of the pitches (unused here).
                - mel_lens: Lengths of the mel spectrograms.
                - langs: Language identities.
                - attn_priors: Prior attention weights.
                - wavs: Waveform targets (unused here).
                - energies: Energy targets.
            _ (int): Index of the batch (unused).

        Returns:
            The total loss for the training step.
        """
        (
            _,
            _,
            speakers,
            texts,
            src_lens,
            mels,
            pitches,
            _,
            mel_lens,
            langs,
            attn_priors,
            _,
            energies,
        ) = batch

        outputs = self.acoustic_model.forward_train(
            x=texts,
            speakers=speakers,
            src_lens=src_lens,
            mels=mels,
            mel_lens=mel_lens,
            pitches=pitches,
            langs=langs,
            attn_priors=attn_priors,
            energies=energies,
        )

        src_mask = get_mask_from_lengths(src_lens)
        mel_mask = get_mask_from_lengths(mel_lens)

        (
            total_loss,
            mel_loss,
            ssim_loss,
            duration_loss,
            u_prosody_loss,
            p_prosody_loss,
            pitch_loss,
            ctc_loss,
            bin_loss,
            energy_loss,
        ) = self.loss_acoustic.forward(
            src_masks=src_mask,
            mel_masks=mel_mask,
            mel_targets=mels,
            mel_predictions=outputs["y_pred"],
            log_duration_predictions=outputs["log_duration_prediction"],
            u_prosody_ref=outputs["u_prosody_ref"],
            u_prosody_pred=outputs["u_prosody_pred"],
            p_prosody_ref=outputs["p_prosody_ref"],
            p_prosody_pred=outputs["p_prosody_pred"],
            pitch_predictions=outputs["pitch_prediction"],
            p_targets=outputs["pitch_target"],
            durations=outputs["attn_hard_dur"],
            attn_logprob=outputs["attn_logprob"],
            attn_soft=outputs["attn_soft"],
            attn_hard=outputs["attn_hard"],
            src_lens=src_lens,
            mel_lens=mel_lens,
            energy_pred=outputs["energy_pred"],
            energy_target=outputs["energy_target"],
            step=self.trainer.global_step,
        )

        # Log every loss component under the same settings; dict order keeps
        # the metrics in the order they were logged originally.
        train_metrics = {
            "train_total_loss": total_loss,
            "train_mel_loss": mel_loss,
            "train_ssim_loss": ssim_loss,
            "train_duration_loss": duration_loss,
            "train_u_prosody_loss": u_prosody_loss,
            "train_p_prosody_loss": p_prosody_loss,
            "train_pitch_loss": pitch_loss,
            "train_ctc_loss": ctc_loss,
            "train_bin_loss": bin_loss,
            "train_energy_loss": energy_loss,
        }
        for metric_name, metric_value in train_metrics.items():
            self.log(
                metric_name,
                metric_value,
                sync_dist=True,
                batch_size=self.batch_size,
            )

        return total_loss

    def configure_optimizers(self) -> dict:
        r"""Configures the optimizer and LR scheduler used for training.

        Returns:
            dict: A dictionary with two entries:
                - "optimizer": ``AdamW`` over the acoustic model parameters.
                - "lr_scheduler": ``ExponentialLR`` with ``gamma=lr_decay``.
        """
        optimizer_config = self.train_config_acoustic.optimizer_config
        lr_decay = optimizer_config.lr_decay
        default_lr = optimizer_config.learning_rate

        # Start from the base LR scaled by the decay for the trainer's current
        # global step; at step 0 the factor is lr_decay ** 0 == 1, so no
        # special case is needed.
        # NOTE(review): this decays per global step, while the scheduler below
        # decays once per scheduler step -- confirm the two intervals agree.
        init_lr = default_lr * (lr_decay**self.trainer.global_step)

        optimizer_acoustic = AdamW(
            self.acoustic_model.parameters(),
            lr=init_lr,
            betas=optimizer_config.betas,
            eps=optimizer_config.eps,
            weight_decay=optimizer_config.weight_decay,
        )

        scheduler_acoustic = ExponentialLR(optimizer_acoustic, gamma=lr_decay)

        return {
            "optimizer": optimizer_acoustic,
            "lr_scheduler": scheduler_acoustic,
        }

    def train_dataloader(
        self,
        root: str = "datasets_cache",
        cache: bool = True,
        cache_dir: str = "/dev/shm",
        include_libri: bool = False,
        libri_speakers: List[str] = speakers_libri_ids,
        hifi_speakers: List[str] = speakers_hifi_ids,
    ) -> DataLoader:
        r"""Returns the training dataloader.

        Delegates to the module-level ``train_dataloader`` imported from
        ``training.datasets.hifi_libri_dataset``, using this module's batch size,
        language, and the worker count and sampling rate from the preprocessing config.

        Args:
            root (str): The root directory of the dataset.
            cache (bool): Whether to cache the preprocessed data.
            cache_dir (str): The directory for the cache. Defaults to "/dev/shm".
            include_libri (bool): Whether to include the LibriTTS dataset or not.
            libri_speakers (List[str]): The list of LibriTTS speakers to include.
            hifi_speakers (List[str]): The list of HiFi speakers to include.

        Returns:
            DataLoader: The training dataloader.
        """
        # Inside a method body this name resolves to the imported module-level
        # function, not to this method.
        return train_dataloader(
            batch_size=self.batch_size,
            num_workers=self.preprocess_config.workers,
            sampling_rate=self.preprocess_config.sampling_rate,
            root=root,
            cache=cache,
            cache_dir=cache_dir,
            lang=self.lang,
            include_libri=include_libri,
            libri_speakers=libri_speakers,
            hifi_speakers=hifi_speakers,
        )