from typing import Tuple | |
import torch | |
from torch import nn | |
from torch.nn import Module | |
import torch.nn.functional as F | |
from models.config import ( | |
AcousticModelConfigType, | |
PreprocessingConfig, | |
) | |
from models.helpers import tools | |
from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE | |
from models.tts.delightful_tts.conv_blocks import CoordConv1d | |
class ReferenceEncoder(Module):
    r"""Reference encoder that extracts a high-level prosody representation
    from a reference mel-spectrogram, similar to the Tacotron reference encoder.

    The input is passed through a stack of 1-D convolutional blocks
    (`CoordConv1d` for the first layer, `nn.Conv1d` for the rest), each
    followed by a leaky-ReLU activation and instance normalization, and the
    result is then fed through a GRU. The `CoordConv1d` at the first layer
    better preserves positional information, paper:
    [Robust and fine-grained prosody control of end-to-end speech synthesis](https://arxiv.org/pdf/1811.02122.pdf)

    Args:
        preprocess_config (PreprocessingConfig): Configuration object with
            preprocessing parameters (provides ``stft.n_mel_channels``).
        model_config (AcousticModelConfigType): Configuration object with
            acoustic model parameters (provides the ``reference_encoder``
            sub-config).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors. _First_: The sequence tensor
        produced by the last GRU layer after padding has been removed. _Second_: The GRU's final hidden state tensor.
        _Third_: The mask tensor, which has the same shape as x, and contains `True` at positions where the input x
        has been masked.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        n_mel_channels = preprocess_config.stft.n_mel_channels
        ref_enc_filters = model_config.reference_encoder.ref_enc_filters
        ref_enc_size = model_config.reference_encoder.ref_enc_size
        ref_enc_strides = model_config.reference_encoder.ref_enc_strides
        ref_enc_gru_size = model_config.reference_encoder.ref_enc_gru_size

        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        # Channel sizes and strides per layer; layer 0 consumes the mel input.
        filters = [self.n_mel_channels, *ref_enc_filters]
        strides = [1, *ref_enc_strides]

        # Use CoordConv1d at the first layer to better preserve positional
        # information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            ),
            *[
                nn.Conv1d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=ref_enc_size,
                    stride=strides[i],
                    padding=ref_enc_size // 2,
                )
                for i in range(1, K)
            ],
        ]
        # Define convolution layers (ModuleList)
        self.convs = nn.ModuleList(convs)
        # One affine InstanceNorm1d per convolution layer.
        self.norms = nn.ModuleList(
            [
                nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True)
                for i in range(K)
            ],
        )
        # Define GRU layer; consumes the channel dimension of the last conv.
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(
        self,
        x: torch.Tensor,
        mel_lens: torch.Tensor,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass of the ReferenceEncoder.

        Args:
            x (torch.Tensor): A 3-dimensional tensor containing the input sequences, its size is [N, n_mels, timesteps].
            mel_lens (torch.Tensor): A 1-dimensional tensor containing the lengths of each sequence in x. Its length is N.
            leaky_relu_slope (float): The slope of the leaky relu function.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors. _First_: The sequence tensor
            produced by the last GRU layer after padding has been removed. _Second_: The GRU's final hidden state tensor.
            _Third_: The mask tensor, which has the same shape as x, and contains `True` at positions where the input x
            has been masked.
        """
        # Zero out the padded region before convolving.
        mel_masks = tools.get_mask_from_lengths(mel_lens).unsqueeze(1)
        mel_masks = mel_masks.to(x.device)
        x = x.masked_fill(mel_masks, 0)

        # Ensure float dtype once; conv/norm outputs stay float afterwards
        # (hoisted out of the loop — the per-iteration cast was a no-op
        # after the first pass).
        x = x.float()
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, leaky_relu_slope)  # [N, 128, Ty//2^K, n_mels//2^K]
            x = norm(x)

        # NOTE(review): the hard-coded 2 assumes exactly two of the conv
        # strides downsample the time axis — confirm against the configured
        # ref_enc_strides.
        for _ in range(2):
            mel_lens = tools.stride_lens_downsampling(mel_lens)

        # Bug fix: keep the recomputed (downsampled) mask on x's device,
        # consistent with the explicit .to(x.device) on the initial mask;
        # otherwise masked_fill can fail with a device mismatch when
        # mel_lens lives on a different device than x.
        mel_masks = tools.get_mask_from_lengths(mel_lens).to(x.device)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))  # [N, T', C] for the batch-first GRU

        # Pack so the GRU skips padded timesteps; lengths must live on CPU.
        packed_sequence = torch.nn.utils.rnn.pack_padded_sequence(
            x,
            lengths=mel_lens.cpu().int(),
            batch_first=True,
            enforce_sorted=False,
        )
        self.gru.flatten_parameters()
        # memory --- [N, Ty, E//2], out --- [1, N, E//2]
        out, memory = self.gru(packed_sequence)
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        return out, memory, mel_masks

    def calculate_channels(
        self,
        L: int,
        kernel_size: int,
        stride: int,
        pad: int,
        n_convs: int,
    ) -> int:
        r"""Calculate the temporal length of the output after a stack of
        1-D convolutions.

        (Despite the method name, this computes the *length* of the time
        axis, not a channel count: the standard
        ``(L - kernel_size + 2 * pad) // stride + 1`` formula applied
        ``n_convs`` times.)

        Args:
            L (int): The original length.
            kernel_size (int): The kernel size used in the convolutions.
            stride (int): The stride used in the convolutions.
            pad (int): The padding used in the convolutions.
            n_convs (int): The number of convolutions.

        Returns:
            int: The length after the convolutions.
        """
        # Apply the conv output-size formula once per convolution.
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L