from typing import Tuple

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from models.config import (
    AcousticModelConfigType,
    PreprocessingConfig,
)
from models.helpers import tools
from models.tts.delightful_tts.constants import LEAKY_RELU_SLOPE
from models.tts.delightful_tts.conv_blocks import CoordConv1d


class ReferenceEncoder(Module):
    r"""A class to define the reference encoder.
    Similar to Tacotron model, the reference encoder is used to extract the high-level features from the reference

    It consists of a number of convolutional blocks (`CoordConv1d` for the first one and `nn.Conv1d` for the rest),
    then followed by instance normalization and GRU layers.
    The `CoordConv1d` at the first layer to better preserve positional information, paper:
    [Robust and fine-grained prosody control of end-to-end speech synthesis](https://arxiv.org/pdf/1811.02122.pdf)

    Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    Args:
        preprocess_config (PreprocessingConfig): Configuration object with preprocessing parameters.
        model_config (AcousticModelConfigType): Configuration object with acoustic model parameters.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors. _First_: The sequence tensor
            produced by the last GRU layer after padding has been removed. _Second_: The GRU's final hidden state tensor.
            _Third_: The mask tensor, which has the same shape as x, and contains `True` at positions where the input x
            has been masked.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        n_mel_channels = preprocess_config.stft.n_mel_channels
        ref_enc_filters = model_config.reference_encoder.ref_enc_filters
        ref_enc_size = model_config.reference_encoder.ref_enc_size
        ref_enc_strides = model_config.reference_encoder.ref_enc_strides
        ref_enc_gru_size = model_config.reference_encoder.ref_enc_gru_size

        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels, *ref_enc_filters]
        strides = [1, *ref_enc_strides]
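        # Layer i maps filters[i] -> filters[i + 1] with stride strides[i]; prepending the mel
        # channel count and a stride of 1 means the first (CoordConv1d) layer keeps the time resolution.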

        # Use CoordConv1d at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[0 + 1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            ),
            *[
                nn.Conv1d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=ref_enc_size,
                    stride=strides[i],
                    padding=ref_enc_size // 2,
                )
                for i in range(1, K)
            ],
        ]
        # Define convolution layers (ModuleList)
        self.convs = nn.ModuleList(convs)

        self.norms = nn.ModuleList(
            [
                nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True)
                for i in range(K)
            ],
        )

        # Define GRU layer
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(
        self,
        x: torch.Tensor,
        mel_lens: torch.Tensor,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass of the ReferenceEncoder.

        Args:
            x (torch.Tensor): A 3-dimensional tensor containing the input sequences; its size is [N, n_mels, timesteps].
            mel_lens (torch.Tensor): A 1-dimensional tensor containing the lengths of each sequence in x. Its length is N.
            leaky_relu_slope (float): The negative slope of the leaky ReLU activation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors. _First_: The padded
                sequence tensor produced by the GRU. _Second_: The GRU's final hidden state tensor.
                _Third_: The mask tensor of shape [N, T'], where T' is the downsampled time dimension, containing
                `True` at positions that correspond to padding in the input x.
        """
        mel_masks = tools.get_mask_from_lengths(mel_lens).unsqueeze(1)
        mel_masks = mel_masks.to(x.device)

        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = x.float()
            x = conv(x)
            x = F.leaky_relu(x, leaky_relu_slope)  # [N, ref_enc_filters[i], T_i]
            x = norm(x)

        # Downsample the lengths to track the temporal reduction of the strided convolutions.
        # Note: the hard-coded two iterations assume the convolution stack downsamples the time axis twice.
        for _ in range(2):
            mel_lens = tools.stride_lens_downsampling(mel_lens)

        mel_masks = tools.get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))

        packed_sequence = torch.nn.utils.rnn.pack_padded_sequence(
            x,
            lengths=mel_lens.cpu().int(),
            batch_first=True,
            enforce_sorted=False,
        )

        # Compact the GRU weights into contiguous memory before the call.
        self.gru.flatten_parameters()
        # out --- GRU output sequence [N, T', E//2]; memory --- final hidden state [1, N, E//2]
        out, memory = self.gru(packed_sequence)
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        return out, memory, mel_masks

    def calculate_channels(
        self,
        L: int,
        kernel_size: int,
        stride: int,
        pad: int,
        n_convs: int,
    ) -> int:
        r"""Calculate the number of channels after applying convolutions.

        Args:
            L (int): The original size.
            kernel_size (int): The kernel size used in the convolutions.
            stride (int): The stride used in the convolutions.
            pad (int): The padding used in the convolutions.
            n_convs (int): The number of convolutions.

        Returns:
            int: The size after the convolutions.
        """
        # Loop through each convolution
        for _ in range(n_convs):
            # Calculate the size after each convolution
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
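

# Illustrative usage sketch (not part of the model): feeds a random batch through an already
# constructed ReferenceEncoder. The batch size, time length, and per-utterance lengths below
# are hypothetical values chosen only to show the expected tensor shapes.
def _example_usage(encoder: ReferenceEncoder) -> None:
    n_mels = encoder.n_mel_channels
    x = torch.randn(2, n_mels, 120)      # [N, n_mels, timesteps]
    mel_lens = torch.tensor([120, 95])   # number of valid (non-padded) frames per utterance
    out, memory, mel_masks = encoder(x, mel_lens)
    # out:       [N, T', gru_size]  padded GRU output sequence
    # memory:    [1, N, gru_size]   final GRU hidden state
    # mel_masks: [N, T']            True at positions that correspond to padding in x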