# PeechTTSv22050 — models/generators/delightful_univnet.py
# Author: nickovchinnikov — commit "Init" (9d61c9b)
from typing import List
from lightning.pytorch.core import LightningModule
import torch
from torch.optim import AdamW, Optimizer, swa_utils
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from models.config import (
AcousticENModelConfig,
AcousticFinetuningConfig,
AcousticPretrainingConfig,
AcousticTrainingConfig,
VocoderFinetuningConfig,
VocoderModelConfig,
VocoderPretrainingConfig,
VoicoderTrainingConfig,
get_lang_map,
lang2id,
)
from models.config import (
PreprocessingConfigUnivNet as PreprocessingConfig,
)
from models.helpers.dataloaders import train_dataloader
from models.helpers.tools import get_mask_from_lengths
# Models
from models.tts.delightful_tts.acoustic_model import AcousticModel
from models.vocoder.univnet.discriminator import Discriminator
from models.vocoder.univnet.generator import Generator
from training.loss import FastSpeech2LossGen, UnivnetLoss
from training.preprocess.normalize_text import NormalizeText
# Updated version of the tokenizer
from training.preprocess.tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA
class DelightfulUnivnet(LightningModule):
    r"""DEPRECATED: This idea is basically wrong. The model should synthesis pretty well mel spectrograms and then use them to generate the waveform based on the good quality mel-spec.

    Joint trainer for the DelightfulTTS acoustic model and the UnivNet
    vocoder (generator + discriminator), trained together with manual
    optimization and Stochastic Weight Averaging (SWA).

    Args:
        fine_tuning (bool, optional): Whether to use fine-tuning mode or not. Defaults to True.
        lang (str): Language of the dataset.
        n_speakers (int): Number of speakers in the dataset.
        batch_size (int): The batch size.
        acc_grad_steps (int): The number of gradient accumulation steps.
        swa_steps (int): The number of global steps between SWA parameter updates.
    """

    def __init__(
        self,
        fine_tuning: bool = True,
        lang: str = "en",
        n_speakers: int = 5392,
        batch_size: int = 12,
        acc_grad_steps: int = 5,
        swa_steps: int = 1000,
    ):
        super().__init__()

        # Switch to manual optimization: three optimizers (acoustic,
        # generator, discriminator) are stepped by hand in `training_step`.
        self.automatic_optimization = False

        self.acc_grad_steps = acc_grad_steps
        self.swa_steps = swa_steps
        self.lang = lang
        self.fine_tuning = fine_tuning
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        normilize_text_lang = lang_map.nemo

        self.tokenizer = TokenizerIPA(lang)
        self.normilize_text = NormalizeText(normilize_text_lang)

        # Acoustic model
        self.train_config_acoustic: AcousticTrainingConfig
        if self.fine_tuning:
            self.train_config_acoustic = AcousticFinetuningConfig()
        else:
            self.train_config_acoustic = AcousticPretrainingConfig()

        self.preprocess_config = PreprocessingConfig("english_only")
        self.model_config_acoustic = AcousticENModelConfig()

        # TODO: fix the arguments!
        self.acoustic_model = AcousticModel(
            preprocess_config=self.preprocess_config,
            model_config=self.model_config_acoustic,
            # NOTE: this parameter may be hyperparameter that you can define based on the demands
            n_speakers=n_speakers,
        )

        # Initialize SWA: keeps a running average of the acoustic weights.
        self.swa_averaged_acoustic = swa_utils.AveragedModel(self.acoustic_model)

        # NOTE: in case of training from 0 bin_warmup should be True!
        self.loss_acoustic = FastSpeech2LossGen(bin_warmup=False)

        # Vocoder models
        self.model_config_vocoder = VocoderModelConfig()

        self.train_config: VoicoderTrainingConfig = (
            VocoderFinetuningConfig() if fine_tuning else VocoderPretrainingConfig()
        )

        self.univnet = Generator(
            model_config=self.model_config_vocoder,
            preprocess_config=self.preprocess_config,
        )
        self.swa_averaged_univnet = swa_utils.AveragedModel(self.univnet)

        self.discriminator = Discriminator(model_config=self.model_config_vocoder)
        self.swa_averaged_discriminator = swa_utils.AveragedModel(self.discriminator)

        self.loss_univnet = UnivnetLoss()
def forward(
self, text: str, speaker_idx: torch.Tensor, lang: str = "en"
) -> torch.Tensor:
r"""Performs a forward pass through the AcousticModel.
This code must be run only with the loaded weights from the checkpoint!
Args:
text (str): The input text.
speaker_idx (torch.Tensor): The index of the speaker.
lang (str): The language.
Returns:
torch.Tensor: The output of the AcousticModel.
"""
normalized_text = self.normilize_text(text)
_, phones = self.tokenizer(normalized_text)
# Convert to tensor
x = torch.tensor(
phones,
dtype=torch.int,
device=speaker_idx.device,
).unsqueeze(0)
speakers = speaker_idx.repeat(x.shape[1]).unsqueeze(0)
langs = (
torch.tensor(
[lang2id[lang]],
dtype=torch.int,
device=speaker_idx.device,
)
.repeat(x.shape[1])
.unsqueeze(0)
)
y_pred = self.acoustic_model.forward(
x=x,
speakers=speakers,
langs=langs,
)
mel_lens = torch.tensor(
[y_pred.shape[2]],
dtype=torch.int32,
device=y_pred.device,
)
wav = self.univnet.infer(y_pred, mel_lens)
return wav
# TODO: don't forget about torch.no_grad() !
# default used by the Trainer
# trainer = Trainer(inference_mode=True)
# Use `torch.no_grad` instead
# trainer = Trainer(inference_mode=False)
def training_step(self, batch: List, batch_idx: int):
r"""Performs a training step for the model.
Args:
batch (List): The batch of data for training. The batch should contain:
- ids: List of indexes.
- raw_texts: Raw text inputs.
- speakers: Speaker identities.
- texts: Text inputs.
- src_lens: Lengths of the source sequences.
- mels: Mel spectrogram targets.
- pitches: Pitch targets.
- pitches_stat: Statistics of the pitches.
- mel_lens: Lengths of the mel spectrograms.
- langs: Language identities.
- attn_priors: Prior attention weights.
- wavs: Waveform targets.
- energies: Energy targets.
batch_idx (int): Index of the batch.
Returns:
- 'loss': The total loss for the training step.
"""
(
_,
_,
speakers,
texts,
src_lens,
mels,
pitches,
_,
mel_lens,
langs,
attn_priors,
audio,
energies,
) = batch
#####################################
## Acoustic model train step ##
#####################################
outputs = self.acoustic_model.forward_train(
x=texts,
speakers=speakers,
src_lens=src_lens,
mels=mels,
mel_lens=mel_lens,
pitches=pitches,
langs=langs,
attn_priors=attn_priors,
energies=energies,
)
y_pred = outputs["y_pred"]
log_duration_prediction = outputs["log_duration_prediction"]
p_prosody_ref = outputs["p_prosody_ref"]
p_prosody_pred = outputs["p_prosody_pred"]
pitch_prediction = outputs["pitch_prediction"]
energy_pred = outputs["energy_pred"]
energy_target = outputs["energy_target"]
src_mask = get_mask_from_lengths(src_lens)
mel_mask = get_mask_from_lengths(mel_lens)
(
acc_total_loss,
acc_mel_loss,
acc_ssim_loss,
acc_duration_loss,
acc_u_prosody_loss,
acc_p_prosody_loss,
acc_pitch_loss,
acc_ctc_loss,
acc_bin_loss,
acc_energy_loss,
) = self.loss_acoustic.forward(
src_masks=src_mask,
mel_masks=mel_mask,
mel_targets=mels,
mel_predictions=y_pred,
log_duration_predictions=log_duration_prediction,
u_prosody_ref=outputs["u_prosody_ref"],
u_prosody_pred=outputs["u_prosody_pred"],
p_prosody_ref=p_prosody_ref,
p_prosody_pred=p_prosody_pred,
pitch_predictions=pitch_prediction,
p_targets=outputs["pitch_target"],
durations=outputs["attn_hard_dur"],
attn_logprob=outputs["attn_logprob"],
attn_soft=outputs["attn_soft"],
attn_hard=outputs["attn_hard"],
src_lens=src_lens,
mel_lens=mel_lens,
energy_pred=energy_pred,
energy_target=energy_target,
step=self.trainer.global_step,
)
self.log(
"acc_total_loss", acc_total_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"acc_mel_loss", acc_mel_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"acc_ssim_loss", acc_ssim_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"acc_duration_loss",
acc_duration_loss,
sync_dist=True,
batch_size=self.batch_size,
)
self.log(
"acc_u_prosody_loss",
acc_u_prosody_loss,
sync_dist=True,
batch_size=self.batch_size,
)
self.log(
"acc_p_prosody_loss",
acc_p_prosody_loss,
sync_dist=True,
batch_size=self.batch_size,
)
self.log(
"acc_pitch_loss", acc_pitch_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"acc_ctc_loss", acc_ctc_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"acc_bin_loss", acc_bin_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"acc_energy_loss",
acc_energy_loss,
sync_dist=True,
batch_size=self.batch_size,
)
#####################################
## Univnet model train step ##
#####################################
fake_audio = self.univnet.forward(y_pred)
res_fake, period_fake = self.discriminator(fake_audio.detach())
res_real, period_real = self.discriminator(audio)
(
voc_total_loss_gen,
voc_total_loss_disc,
voc_stft_loss,
voc_score_loss,
voc_esr_loss,
voc_snr_loss,
) = self.loss_univnet.forward(
audio,
fake_audio,
res_fake,
period_fake,
res_real,
period_real,
)
self.log(
"voc_total_loss_gen",
voc_total_loss_gen,
sync_dist=True,
batch_size=self.batch_size,
)
self.log(
"voc_total_loss_disc",
voc_total_loss_disc,
sync_dist=True,
batch_size=self.batch_size,
)
self.log(
"voc_stft_loss", voc_stft_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"voc_score_loss", voc_score_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"voc_esr_loss", voc_esr_loss, sync_dist=True, batch_size=self.batch_size
)
self.log(
"voc_snr_loss", voc_snr_loss, sync_dist=True, batch_size=self.batch_size
)
# Manual optimizer
# Access your optimizers
optimizers = self.optimizers()
schedulers = self.lr_schedulers()
####################################
# Acoustic model manual optimizer ##
####################################
opt_acoustic: Optimizer = optimizers[0] # type: ignore
sch_acoustic: ExponentialLR = schedulers[0] # type: ignore
opt_univnet: Optimizer = optimizers[0] # type: ignore
sch_univnet: ExponentialLR = schedulers[0] # type: ignore
opt_discriminator: Optimizer = optimizers[1] # type: ignore
sch_discriminator: ExponentialLR = schedulers[1] # type: ignore
# Backward pass for the acoustic model
# NOTE: the loss is divided by the accumulated gradient steps
self.manual_backward(acc_total_loss / self.acc_grad_steps, retain_graph=True)
# Perform manual optimization univnet
self.manual_backward(
voc_total_loss_gen / self.acc_grad_steps, retain_graph=True
)
self.manual_backward(
voc_total_loss_disc / self.acc_grad_steps, retain_graph=True
)
# accumulate gradients of N batches
if (batch_idx + 1) % self.acc_grad_steps == 0:
# Acoustic model optimizer step
# clip gradients
self.clip_gradients(
opt_acoustic, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
)
# optimizer step
opt_acoustic.step()
# Scheduler step
sch_acoustic.step()
# zero the gradients
opt_acoustic.zero_grad()
# Univnet model optimizer step
# clip gradients
self.clip_gradients(
opt_univnet, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
)
self.clip_gradients(
opt_discriminator, gradient_clip_val=0.5, gradient_clip_algorithm="norm"
)
# optimizer step
opt_univnet.step()
opt_discriminator.step()
# Scheduler step
sch_univnet.step()
sch_discriminator.step()
# zero the gradients
opt_univnet.zero_grad()
opt_discriminator.zero_grad()
# Update SWA model every swa_steps
if self.trainer.global_step % self.swa_steps == 0:
self.swa_averaged_acoustic.update_parameters(self.acoustic_model)
self.swa_averaged_univnet.update_parameters(self.univnet)
self.swa_averaged_discriminator.update_parameters(self.discriminator)
def on_train_epoch_end(self):
r"""Updates the averaged model after each optimizer step with SWA."""
self.swa_averaged_acoustic.update_parameters(self.acoustic_model)
self.swa_averaged_univnet.update_parameters(self.univnet)
self.swa_averaged_discriminator.update_parameters(self.discriminator)
def configure_optimizers(self):
r"""Configures the optimizer used for training.
Returns
tuple: A tuple containing three dictionaries. Each dictionary contains the optimizer and learning rate scheduler for one of the models.
"""
####################################
# Acoustic model optimizer config ##
####################################
# Compute the gamma and initial learning rate based on the current step
lr_decay = self.train_config_acoustic.optimizer_config.lr_decay
default_lr = self.train_config_acoustic.optimizer_config.learning_rate
init_lr = (
default_lr
if self.trainer.global_step == 0
else default_lr * (lr_decay**self.trainer.global_step)
)
optimizer_acoustic = AdamW(
self.acoustic_model.parameters(),
lr=init_lr,
betas=self.train_config_acoustic.optimizer_config.betas,
eps=self.train_config_acoustic.optimizer_config.eps,
weight_decay=self.train_config_acoustic.optimizer_config.weight_decay,
)
scheduler_acoustic = ExponentialLR(optimizer_acoustic, gamma=lr_decay)
####################################
# Univnet model optimizer config ##
####################################
optim_univnet = AdamW(
self.univnet.parameters(),
self.train_config.learning_rate,
betas=(self.train_config.adam_b1, self.train_config.adam_b2),
)
scheduler_univnet = ExponentialLR(
optim_univnet,
gamma=self.train_config.lr_decay,
last_epoch=-1,
)
####################################
# Discriminator optimizer config ##
####################################
optim_discriminator = AdamW(
self.discriminator.parameters(),
self.train_config.learning_rate,
betas=(self.train_config.adam_b1, self.train_config.adam_b2),
)
scheduler_discriminator = ExponentialLR(
optim_discriminator,
gamma=self.train_config.lr_decay,
last_epoch=-1,
)
return (
{"optimizer": optimizer_acoustic, "lr_scheduler": scheduler_acoustic},
{"optimizer": optim_univnet, "lr_scheduler": scheduler_univnet},
{"optimizer": optim_discriminator, "lr_scheduler": scheduler_discriminator},
)
def on_train_end(self):
# Update SWA models after training
swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_acoustic)
swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_univnet)
swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_discriminator)
def train_dataloader(
self,
num_workers: int = 5,
root: str = "datasets_cache/LIBRITTS",
cache: bool = True,
cache_dir: str = "datasets_cache",
mem_cache: bool = False,
url: str = "train-960",
) -> DataLoader:
r"""Returns the training dataloader, that is using the LibriTTS dataset.
Args:
num_workers (int): The number of workers.
root (str): The root directory of the dataset.
cache (bool): Whether to cache the preprocessed data.
cache_dir (str): The directory for the cache.
mem_cache (bool): Whether to use memory cache.
url (str): The URL of the dataset.
Returns:
Tupple[DataLoader, DataLoader]: The training and validation dataloaders.
"""
return train_dataloader(
batch_size=self.batch_size,
num_workers=num_workers,
root=root,
cache=cache,
cache_dir=cache_dir,
mem_cache=mem_cache,
url=url,
lang=self.lang,
)