from typing import Dict, Tuple, Union

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from models.config import (
    SUPPORTED_LANGUAGES,
    AcousticModelConfigType,
    PreprocessingConfig,
    symbols,
)
from models.helpers import (
    positional_encoding,
    tools,
)
from models.tts.delightful_tts.attention import Conformer
from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE
from models.tts.delightful_tts.reference_encoder import (
    PhonemeLevelProsodyEncoder,
    UtteranceLevelProsodyEncoder,
)

from .aligner import Aligner
from .energy_adaptor import EnergyAdaptor
from .length_adaptor import LengthAdaptor
from .phoneme_prosody_predictor import PhonemeProsodyPredictor
from .pitch_adaptor_conv import PitchAdaptorConv

class AcousticModel(Module):
    r"""The DelightfulTTS AcousticModel class is a PyTorch module for the acoustic model of a text-to-speech (TTS) system.

    The acoustic model predicts mel spectrograms from phoneme sequences.
    It comprises several sub-modules, including an encoder, a decoder, and various prosody encoders and predictors.
    In addition, pitch, energy, and length adaptors and an aligner are instantiated.

    Args:
        preprocess_config (PreprocessingConfig): Configuration used for preprocessing the data.
        model_config (AcousticModelConfigType): Configuration object containing the model parameters.
        n_speakers (int): Total number of speakers in the dataset.
        leaky_relu_slope (float, optional): Slope for the leaky ReLU activation. Defaults to LEAKY_RELU_SLOPE.

    Note:
        For implementation details of the sub-modules, please refer to their respective modules.
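
    Example:
        A minimal usage sketch. The config objects and the way the input index
        tensors are built here are assumptions for illustration, not defaults
        shipped with this module::

            model = AcousticModel(
                preprocess_config=preprocess_config,
                model_config=model_config,
                n_speakers=2477,
            )
            # Inference: phoneme ids, per-token speaker ids and language ids,
            # each assumed to be of shape (1, T).
            mel = model(x=tokens, speakers=speaker_ids, langs=lang_ids)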
""" | |

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
        n_speakers: int,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()
        self.emb_dim = model_config.encoder.n_hidden

        self.encoder = Conformer(
            dim=model_config.encoder.n_hidden,
            n_layers=model_config.encoder.n_layers,
            n_heads=model_config.encoder.n_heads,
            embedding_dim=model_config.speaker_embed_dim + model_config.lang_embed_dim,
            p_dropout=model_config.encoder.p_dropout,
            kernel_size_conv_mod=model_config.encoder.kernel_size_conv_mod,
            with_ff=model_config.encoder.with_ff,
        )

        self.pitch_adaptor_conv = PitchAdaptorConv(
            channels_in=model_config.encoder.n_hidden,
            channels_hidden=model_config.variance_adaptor.n_hidden,
            channels_out=1,
            kernel_size=model_config.variance_adaptor.kernel_size,
            emb_kernel_size=model_config.variance_adaptor.emb_kernel_size,
            dropout=model_config.variance_adaptor.p_dropout,
            leaky_relu_slope=leaky_relu_slope,
        )

        self.energy_adaptor = EnergyAdaptor(
            channels_in=model_config.encoder.n_hidden,
            channels_hidden=model_config.variance_adaptor.n_hidden,
            channels_out=1,
            kernel_size=model_config.variance_adaptor.kernel_size,
            emb_kernel_size=model_config.variance_adaptor.emb_kernel_size,
            dropout=model_config.variance_adaptor.p_dropout,
            leaky_relu_slope=leaky_relu_slope,
        )

        self.length_regulator = LengthAdaptor(model_config)

        self.utterance_prosody_encoder = UtteranceLevelProsodyEncoder(
            preprocess_config,
            model_config,
        )
        self.utterance_prosody_predictor = PhonemeProsodyPredictor(
            model_config=model_config,
            phoneme_level=False,
        )
        self.phoneme_prosody_encoder = PhonemeLevelProsodyEncoder(
            preprocess_config,
            model_config,
        )
        self.phoneme_prosody_predictor = PhonemeProsodyPredictor(
            model_config=model_config,
            phoneme_level=True,
        )

        self.u_bottle_out = nn.Linear(
            model_config.reference_encoder.bottleneck_size_u,
            model_config.encoder.n_hidden,
        )
        self.u_norm = nn.LayerNorm(
            model_config.reference_encoder.bottleneck_size_u,
            elementwise_affine=False,
        )
        self.p_bottle_out = nn.Linear(
            model_config.reference_encoder.bottleneck_size_p,
            model_config.encoder.n_hidden,
        )
        self.p_norm = nn.LayerNorm(
            model_config.reference_encoder.bottleneck_size_p,
            elementwise_affine=False,
        )

        self.aligner = Aligner(
            d_enc_in=model_config.encoder.n_hidden,
            d_dec_in=preprocess_config.stft.n_mel_channels,
            d_hidden=model_config.encoder.n_hidden,
        )

        self.decoder = Conformer(
            dim=model_config.decoder.n_hidden,
            n_layers=model_config.decoder.n_layers,
            n_heads=model_config.decoder.n_heads,
            embedding_dim=model_config.speaker_embed_dim + model_config.lang_embed_dim,
            p_dropout=model_config.decoder.p_dropout,
            kernel_size_conv_mod=model_config.decoder.kernel_size_conv_mod,
            with_ff=model_config.decoder.with_ff,
        )

        self.src_word_emb = Parameter(
            tools.initialize_embeddings(
                (len(symbols), model_config.encoder.n_hidden),
            ),
        )

        self.to_mel = nn.Linear(
            model_config.decoder.n_hidden,
            preprocess_config.stft.n_mel_channels,
        )

        # NOTE: here you can manage the speaker embeddings; could they be used for voice export?
        # NOTE: the flexibility of the model is bound by the n_speakers parameter; maybe there is another way?
        # NOTE: LibriTTS has 2477 speakers; more can be added by extending the speaker_embed matrix.
        # Need to think about it more.
        self.speaker_embed = Parameter(
            tools.initialize_embeddings(
                (n_speakers, model_config.speaker_embed_dim),
            ),
        )
        self.lang_embed = Parameter(
            tools.initialize_embeddings(
                (len(SUPPORTED_LANGUAGES), model_config.lang_embed_dim),
            ),
        )

    def get_embeddings(
        self,
        token_idx: torch.Tensor,
        speaker_idx: torch.Tensor,
        src_mask: torch.Tensor,
        lang_idx: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Compute the token embeddings and the combined speaker and language embeddings
        for the given token, speaker, and language indices, zeroing out padded positions.

        Args:
            token_idx (torch.Tensor): Tensor of token indices.
            speaker_idx (torch.Tensor): Tensor of speaker identities.
            src_mask (torch.Tensor): Mask tensor for the source sequences.
            lang_idx (torch.Tensor): Tensor of language indices.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Token embeddings tensor,
            and combined speaker and language embeddings tensor.
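
        Example:
            A shape-only sketch; the batch size, sequence length, and the
            `model` instance are assumptions for illustration::

                # token_idx, speaker_idx, lang_idx: (B, T); src_mask: (B, T) bool
                token_emb, emb = model.get_embeddings(token_idx, speaker_idx, src_mask, lang_idx)
                # token_emb: (B, T, encoder.n_hidden)
                # emb:       (B, T, speaker_embed_dim + lang_embed_dim)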
""" | |
        token_embeddings = F.embedding(token_idx, self.src_word_emb)
        # NOTE: here you can manage the speaker embeddings; could they be used for voice export?
        speaker_embeds = F.embedding(speaker_idx, self.speaker_embed)
        lang_embeds = F.embedding(lang_idx, self.lang_embed)

        # Merge the speaker and language embeddings
        embeddings = torch.cat([speaker_embeds, lang_embeds], dim=2)

        # Zero out the masked (padded) positions in both embedding tensors
        embeddings = embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0)
        token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0)

        return token_embeddings, embeddings

    def prepare_for_export(self) -> None:
        r"""Prepare the model for export.

        This method is called when the model is about to be exported, for example for
        deployment or serialization. It removes components that are not needed during
        inference: the phoneme-level and utterance-level prosody encoders, which are
        only used during training.

        Returns:
            None
        """
        del self.phoneme_prosody_encoder
        del self.utterance_prosody_encoder

    # NOTE: freeze/unfreeze params changed because of a conflict with the Lightning module
    def freeze_params(self) -> None:
        r"""Freeze the trainable parameters in the model.

        Frozen parameters are no longer updated by gradient descent. This is typically
        done to keep parts of the model fixed while training other parts. This method
        freezes all parameters and then selectively unfreezes the speaker embeddings
        so that they can still be updated during training.

        Returns:
            None
        """
        for par in self.parameters():
            par.requires_grad = False
        self.speaker_embed.requires_grad = True

    # NOTE: freeze/unfreeze params changed because of a conflict with the Lightning module
    def unfreeze_params(self, freeze_text_embed: bool, freeze_lang_embed: bool) -> None:
        r"""Unfreeze the trainable parameters in the model, allowing them to be updated during training.

        This method is typically used to make previously frozen parameters trainable again.
        It unfreezes all parameters and then selectively keeps the text embeddings and
        language embeddings frozen, if requested.

        Args:
            freeze_text_embed (bool): Flag to indicate if text embeddings should remain frozen.
            freeze_lang_embed (bool): Flag to indicate if language embeddings should remain frozen.

        Returns:
            None
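
        Example:
            A typical fine-tuning flow; the call order shown here is an
            assumption, not prescribed by this module::

                model.freeze_params()    # everything frozen except the speaker embeddings
                model.unfreeze_params(freeze_text_embed=True, freeze_lang_embed=True)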
""" | |
        # Make all model parameters trainable
        for par in self.parameters():
            par.requires_grad = True

        # If freeze_text_embed is True, keep the source word embeddings frozen
        if freeze_text_embed:
            # NOTE: self.src_word_emb is a Parameter, not a Module, so it has no
            # parameters() method; set requires_grad on it directly.
            self.src_word_emb.requires_grad = False

        # If freeze_lang_embed is True, keep the language embeddings frozen
        if freeze_lang_embed:
            self.lang_embed.requires_grad = False

    def average_utterance_prosody(
        self,
        u_prosody_pred: torch.Tensor,
        src_mask: torch.Tensor,
    ) -> torch.Tensor:
        r"""Average the predicted utterance prosody over the non-masked sequence length.

        This method averages the output of the utterance prosody predictor over the
        true sequence lengths (non-masked elements), collapsing the time dimension
        to a singleton.

        Args:
            u_prosody_pred (torch.Tensor): Predicted utterance prosody of shape (batch_size, T, n_features).
            src_mask (torch.Tensor): Mask of shape (batch_size, T) where padded positions are True.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, 1, n_features) containing the average utterance prosody over the non-masked sequence length.
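
        Example:
            A shape-only sketch; the sizes are assumptions for illustration::

                pred = torch.randn(2, 10, 256)                # (B, T, n_features)
                mask = torch.zeros(2, 10, dtype=torch.bool)   # no padding
                avg = model.average_utterance_prosody(pred, mask)
                # avg.shape == (2, 1, 256)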
""" | |
        # Compute the true sequence lengths by negating the mask and summing over the time dimension
        lengths = ((~src_mask) * 1.0).sum(1)

        # Sum u_prosody_pred over the time dimension and divide by the sequence lengths;
        # broadcasting handles the feature dimension (n_features).
        return u_prosody_pred.sum(1, keepdim=True) / lengths.view(-1, 1, 1)

    def forward_train(
        self,
        x: torch.Tensor,
        speakers: torch.Tensor,
        src_lens: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        pitches: torch.Tensor,
        langs: torch.Tensor,
        attn_priors: Union[torch.Tensor, None],
        energies: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        r"""Forward pass during the training phase.

        For a given phoneme sequence, speaker identities, sequence lengths, mels,
        mel lengths, pitches, languages, attention priors, and energies, the forward
        pass processes these inputs through the defined architecture.

        Args:
            x (torch.Tensor): Tensor of phoneme sequences.
            speakers (torch.Tensor): Tensor of speaker identities.
            src_lens (torch.Tensor): Long tensor of source sequence lengths.
            mels (torch.Tensor): Tensor of mel spectrograms.
            mel_lens (torch.Tensor): Long tensor of mel sequence lengths.
            pitches (torch.Tensor): Tensor of pitch values.
            langs (torch.Tensor): Tensor of language identities.
            attn_priors (Union[torch.Tensor, None]): Prior attention values, or None.
            energies (torch.Tensor): Tensor of energy values.

        Returns:
            Dict[str, torch.Tensor]: The prediction outputs as a dictionary.
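
        Example:
            An illustrative call; the tensor shapes and the keys used
            downstream are assumptions for illustration::

                outputs = model.forward_train(
                    x, speakers, src_lens, mels, mel_lens,
                    pitches, langs, attn_priors, energies,
                )
                # outputs["y_pred"]: predicted mel spectrogram, (B, n_mel_channels, T_mel)
                # outputs["attn_hard_dur"]: per-phoneme durations from the aligner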
""" | |
        # Generate masks for the padded positions in the source and mel sequences
        src_mask = tools.get_mask_from_lengths(src_lens)
        mel_mask = tools.get_mask_from_lengths(mel_lens)

        x, embeddings = self.get_embeddings(
            token_idx=x,
            speaker_idx=speakers,
            src_mask=src_mask,
            lang_idx=langs,
        )

        encoding = positional_encoding(
            self.emb_dim,
            max(x.shape[1], int(mel_lens.max().item())),
        )

        x = x.to(src_mask.device)
        encoding = encoding.to(src_mask.device)
        embeddings = embeddings.to(src_mask.device)

        x = self.encoder(x, src_mask, embeddings=embeddings, encoding=encoding)

        u_prosody_ref = self.u_norm(
            self.utterance_prosody_encoder(mels=mels, mel_lens=mel_lens),
        )
        u_prosody_pred = self.u_norm(
            self.average_utterance_prosody(
                u_prosody_pred=self.utterance_prosody_predictor(x=x, mask=src_mask),
                src_mask=src_mask,
            ),
        )

        p_prosody_ref = self.p_norm(
            self.phoneme_prosody_encoder(
                x=x,
                src_mask=src_mask,
                mels=mels,
                mel_lens=mel_lens,
                encoding=encoding,
            ),
        )
        p_prosody_pred = self.p_norm(
            self.phoneme_prosody_predictor(
                x=x,
                mask=src_mask,
            ),
        )

        x = x + self.u_bottle_out(u_prosody_pred)
        x = x + self.p_bottle_out(p_prosody_pred)

        # Save the residual for later use
        x_res = x

        attn_logprob, attn_soft, attn_hard, attn_hard_dur = self.aligner(
            enc_in=x_res.permute((0, 2, 1)),
            dec_in=mels,
            enc_len=src_lens,
            dec_len=mel_lens,
            enc_mask=src_mask,
            attn_prior=attn_priors,
        )
        attn_hard_dur = attn_hard_dur.to(src_mask.device)

        x, pitch_prediction, avg_pitch_target = (
            self.pitch_adaptor_conv.add_pitch_embedding_train(
                x=x,
                target=pitches,
                dr=attn_hard_dur,
                mask=src_mask,
            )
        )

        energies = energies.to(src_mask.device)
        x, energy_pred, avg_energy_target = (
            self.energy_adaptor.add_energy_embedding_train(
                x=x,
                target=energies,
                dr=attn_hard_dur,
                mask=src_mask,
            )
        )

        x, log_duration_prediction, embeddings = self.length_regulator.upsample_train(
            x=x,
            x_res=x_res,
            duration_target=attn_hard_dur,
            src_mask=src_mask,
            embeddings=embeddings,
        )

        # Decode the encoder output into the predicted mel spectrogram
        decoder_output = self.decoder(
            x,
            mel_mask,
            embeddings=embeddings,
            encoding=encoding,
        )
        y_pred = self.to_mel(decoder_output)
        y_pred = y_pred.permute((0, 2, 1))

        return {
            "y_pred": y_pred,
            "pitch_prediction": pitch_prediction,
            "pitch_target": avg_pitch_target,
            "energy_pred": energy_pred,
            "energy_target": avg_energy_target,
            "log_duration_prediction": log_duration_prediction,
            "u_prosody_pred": u_prosody_pred,
            "u_prosody_ref": u_prosody_ref,
            "p_prosody_pred": p_prosody_pred,
            "p_prosody_ref": p_prosody_ref,
            "attn_logprob": attn_logprob,
            "attn_soft": attn_soft,
            "attn_hard": attn_hard,
            "attn_hard_dur": attn_hard_dur,
        }

    def forward(
        self,
        x: torch.Tensor,
        speakers: torch.Tensor,
        langs: torch.Tensor,
        d_control: float = 1.0,
    ) -> torch.Tensor:
        r"""Forward pass during model inference.

        The forward pass receives a phoneme sequence, speaker identities, languages,
        and a duration control factor, runs them through the architecture, and returns
        the predicted mel spectrogram.

        Args:
            x (torch.Tensor): Tensor of phoneme sequences.
            speakers (torch.Tensor): Tensor of speaker identities.
            langs (torch.Tensor): Tensor of language identities.
            d_control (float): Duration control parameter. Defaults to 1.0.

        Returns:
            torch.Tensor: Predicted mel spectrogram.
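
        Example:
            A minimal inference sketch for a single utterance; the way the
            index tensors are built, and the assumption that d_control scales
            the predicted durations inside the LengthAdaptor, are illustrative::

                with torch.no_grad():
                    # x: (1, T) phoneme ids; speakers, langs: (1, T) index tensors
                    mel = model(x=x, speakers=speakers, langs=langs, d_control=1.0)
                # mel: (1, n_mel_channels, T_mel)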
""" | |
        # Generate a mask for the padded positions in the source sequence
        # (inference assumes a single, unpadded sequence in the batch)
        src_mask = tools.get_mask_from_lengths(
            torch.tensor([x.shape[1]], dtype=torch.int64),
        ).to(x.device)

        # Obtain the embeddings for the input
        x, embeddings = self.get_embeddings(
            token_idx=x,
            speaker_idx=speakers,
            src_mask=src_mask,
            lang_idx=langs,
        )

        # Generate positional encodings
        encoding = positional_encoding(
            self.emb_dim,
            x.shape[1],
        ).to(x.device)

        # Process the embeddings through the encoder
        x = self.encoder(x, src_mask, embeddings=embeddings, encoding=encoding)

        # Predict prosody at the utterance level and the phoneme level
        u_prosody_pred = self.u_norm(
            self.average_utterance_prosody(
                u_prosody_pred=self.utterance_prosody_predictor(x=x, mask=src_mask),
                src_mask=src_mask,
            ),
        )
        p_prosody_pred = self.p_norm(
            self.phoneme_prosody_predictor(
                x=x,
                mask=src_mask,
            ),
        )

        x = x + self.u_bottle_out(u_prosody_pred)
        x = x + self.p_bottle_out(p_prosody_pred)
        x_res = x

        x, _ = self.pitch_adaptor_conv.add_pitch_embedding(
            x=x,
            mask=src_mask,
        )
        x, _ = self.energy_adaptor.add_energy_embedding(
            x=x,
            mask=src_mask,
        )

        x, _, embeddings = self.length_regulator.upsample(
            x=x,
            x_res=x_res,
            src_mask=src_mask,
            control=d_control,
            embeddings=embeddings,
        )

        mel_mask = tools.get_mask_from_lengths(
            torch.tensor([x.shape[1]], dtype=torch.int64),
        ).to(x.device)

        # Extend the positional encoding if the upsampled sequence is longer than the source
        if x.shape[1] > encoding.shape[1]:
            encoding = positional_encoding(self.emb_dim, x.shape[1]).to(x.device)

        decoder_output = self.decoder(
            x,
            mel_mask,
            embeddings=embeddings,
            encoding=encoding,
        )
        x = self.to_mel(decoder_output)
        x = x.permute((0, 2, 1))
        return x