from typing import Tuple

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from models.config import (
    AcousticModelConfigType,
    PreprocessingConfig,
)
from models.helpers import tools
from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE
from models.tts.delightful_tts.conv_blocks import CoordConv1d


class ReferenceEncoder(Module):
    r"""A class to define the reference encoder.

    As in the Tacotron model, the reference encoder extracts high-level features
    from a reference mel-spectrogram. It consists of a stack of convolutional
    blocks (`CoordConv1d` for the first one and `nn.Conv1d` for the rest), each
    followed by instance normalization, and a final GRU layer.

    `CoordConv1d` is used at the first layer to better preserve positional
    information, following the paper:
    [Robust and fine-grained prosody control of end-to-end speech synthesis](https://arxiv.org/pdf/1811.02122.pdf)

    Applies a multi-layer gated recurrent unit (GRU) RNN to the convolved input sequence.

    Args:
        preprocess_config (PreprocessingConfig): Configuration object with preprocessing parameters.
        model_config (AcousticModelConfigType): Configuration object with acoustic model parameters.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors.
        _First_: The sequence tensor produced by the last GRU layer after padding has been removed.
        _Second_: The GRU's final hidden state tensor.
        _Third_: The mask tensor over the downsampled time axis, which contains `True` at positions that were padded in the input x.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        n_mel_channels = preprocess_config.stft.n_mel_channels
        ref_enc_filters = model_config.reference_encoder.ref_enc_filters
        ref_enc_size = model_config.reference_encoder.ref_enc_size
        ref_enc_strides = model_config.reference_encoder.ref_enc_strides
        ref_enc_gru_size = model_config.reference_encoder.ref_enc_gru_size

        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels, *ref_enc_filters]
        strides = [1, *ref_enc_strides]

        # Use CoordConv1d at the first layer to better preserve positional
        # information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[0 + 1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            ),
            *[
                nn.Conv1d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=ref_enc_size,
                    stride=strides[i],
                    padding=ref_enc_size // 2,
                )
                for i in range(1, K)
            ],
        ]

        # Define convolution layers (ModuleList)
        self.convs = nn.ModuleList(convs)

        # One instance-normalization layer per convolution
        self.norms = nn.ModuleList(
            [
                nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True)
                for i in range(K)
            ],
        )

        # Define GRU layer
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(
        self,
        x: torch.Tensor,
        mel_lens: torch.Tensor,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass of the ReferenceEncoder.

        Args:
            x (torch.Tensor): A 3-dimensional tensor containing the input sequences; its size is [N, n_mels, timesteps].
            mel_lens (torch.Tensor): A 1-dimensional tensor containing the lengths of each sequence in x. Its length is N.
            leaky_relu_slope (float): The slope of the leaky ReLU function.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors.
            _First_: The sequence tensor produced by the last GRU layer after padding has been removed.
            _Second_: The GRU's final hidden state tensor.
            _Third_: The mask tensor over the downsampled time axis, which contains `True` at positions that were padded in the input x.
        """
        # Build the padding mask and zero out the padded frames
        mel_masks = tools.get_mask_from_lengths(mel_lens).unsqueeze(1)
        mel_masks = mel_masks.to(x.device)
        x = x.masked_fill(mel_masks, 0)

        for conv, norm in zip(self.convs, self.norms):
            x = x.float()
            x = conv(x)
            # [N, ref_enc_filters[i], T_i] -- the time axis shrinks with each strided layer
            x = F.leaky_relu(x, leaky_relu_slope)
            x = norm(x)

        # Downsample the sequence lengths to match the stride-reduced time
        # axis, then rebuild the mask and re-zero the padded frames
        for _ in range(2):
            mel_lens = tools.stride_lens_downsampling(mel_lens)

        mel_masks = tools.get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))

        # Pack the padded batch so the GRU skips the padded frames
        packed_sequence = torch.nn.utils.rnn.pack_padded_sequence(
            x,
            lengths=mel_lens.cpu().int(),
            batch_first=True,
            enforce_sorted=False,
        )

        self.gru.flatten_parameters()
        # out --- [N, Ty, E//2], memory --- [1, N, E//2]
        out, memory = self.gru(packed_sequence)
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        return out, memory, mel_masks

    def calculate_channels(
        self,
        L: int,
        kernel_size: int,
        stride: int,
        pad: int,
        n_convs: int,
    ) -> int:
        r"""Calculate the size of the time axis after applying the convolutions.

        Args:
            L (int): The original size.
            kernel_size (int): The kernel size used in the convolutions.
            stride (int): The stride used in the convolutions.
            pad (int): The padding used in the convolutions.
            n_convs (int): The number of convolutions.

        Returns:
            int: The size after the convolutions.
        """
        # Apply the standard convolution output-size formula once per layer
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
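

# A minimal smoke-test sketch of how this encoder might be wired up. The
# SimpleNamespace configs below are hypothetical stand-ins that only mimic the
# attributes read above (stft.n_mel_channels and reference_encoder.*); the real
# PreprocessingConfig / AcousticModelConfigType classes in models.config may
# carry more fields. The filter/stride values are illustrative, not the
# project's defaults, chosen so the time axis is halved exactly twice to match
# the two stride_lens_downsampling calls in forward().
if __name__ == "__main__":
    from types import SimpleNamespace

    preprocess_config = SimpleNamespace(stft=SimpleNamespace(n_mel_channels=80))
    model_config = SimpleNamespace(
        reference_encoder=SimpleNamespace(
            ref_enc_filters=[32, 32, 64, 64, 128, 128],
            ref_enc_size=3,
            ref_enc_strides=[1, 2, 1, 2, 1],
            ref_enc_gru_size=32,
        ),
    )

    encoder = ReferenceEncoder(preprocess_config, model_config)

    # Batch of 2 mel-spectrograms, [N, n_mels, T], with different true lengths
    mels = torch.randn(2, 80, 120)
    mel_lens = torch.tensor([120, 96])

    out, memory, mel_masks = encoder(mels, mel_lens)
    # Expected shapes (under the assumptions above):
    # out:       [2, 30, 32] -- GRU outputs with padding restored
    # memory:    [1, 2, 32]  -- final GRU hidden state
    # mel_masks: [2, 30]     -- True at padded positions
    print(out.shape, memory.shape, mel_masks.shape)

    # calculate_channels mirrors the conv arithmetic: two stride-2 layers with
    # kernel 3 and padding 1 take a length-120 sequence down to 30
    print(encoder.calculate_channels(120, 3, 2, 1, 2))  # 30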