from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F

from models.config import AcousticModelConfigType

from .variance_predictor import VariancePredictor


class DurationAdaptor(Module):
    """DurationAdaptor is a module that adapts the duration of the input sequence.

    Args:
        model_config (AcousticModelConfigType): Configuration object containing model parameters.
    """

    def __init__(
        self,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()
        # Initialize the duration predictor
        self.duration_predictor = VariancePredictor(
            channels_in=model_config.encoder.n_hidden,
            channels=model_config.variance_adaptor.n_hidden,
            channels_out=1,
            kernel_size=model_config.variance_adaptor.kernel_size,
            p_dropout=model_config.variance_adaptor.p_dropout,
        )

    @staticmethod
    def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
        r"""Convert the padding shape from a list of lists to a flat list.

        Args:
            pad_shape (List[List[int]]): Padding shape as a list of lists.

        Returns:
            List[int]: Padding shape as a flat list.
        """
        pad_list = pad_shape[::-1]
        return [item for sublist in pad_list for item in sublist]

    @staticmethod
    def generate_path(duration: Tensor, mask: Tensor) -> Tensor:
        r"""Generate a path based on the duration and mask.

        Args:
            duration (Tensor): Duration tensor.
            mask (Tensor): Mask tensor.

        Returns:
            Tensor: Path tensor.

        Shapes:
        - duration: :math:`[B, T_en]`
        - mask: :math:`[B, T_en, T_de]`
        - path: :math:`[B, T_en, T_de]`
        """
        b, t_x, t_y = mask.shape
        cum_duration = torch.cumsum(duration, 1)

        cum_duration_flat = cum_duration.view(b * t_x)
        path = DurationAdaptor.sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
        path = path.view(b, t_x, t_y)
        pad_shape = DurationAdaptor.convert_pad_shape([[0, 0], [1, 0], [0, 0]])
        path = path - F.pad(path, pad_shape)[:, :-1]
        path = path * mask
        return path

    # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
    @staticmethod
    def sequence_mask(sequence_length: Tensor, max_len: Optional[int] = None) -> Tensor:
        """Create a sequence mask for filtering padding in a sequence tensor.

        Args:
            sequence_length (torch.Tensor): Sequence lengths.
            max_len (Optional[int]): Maximum sequence length. Defaults to None.

        Returns:
            torch.Tensor: Sequence mask.

        Shapes:
            - mask: :math:`[B, T_max]`
        """
        if max_len is None:
            max_len = int(sequence_length.max())

        seq_range = torch.arange(
            max_len,
            dtype=sequence_length.dtype,
            device=sequence_length.device,
        )
        # B x T_max
        return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

    @staticmethod
    def generate_attn(
        dr: Tensor,
        x_mask: Tensor,
        y_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Generate an attention mask from the linear scale durations.

        Args:
            dr (Tensor): Linear scale durations.
            x_mask (Tensor): Mask for the input (character) sequence.
            y_mask (Optional[Tensor]): Mask for the output (spectrogram) sequence. If None, it is
                computed from the durations. Defaults to None.

        Returns:
            Tensor: Hard (0/1) alignment map between encoder and decoder frames.

        Shapes:
           - dr: :math:`(B, T_{en})`
           - x_mask: :math:`(B, 1, T_{en})`
           - y_mask: :math:`(B, 1, T_{de})`
           - attn: :math:`(B, T_{en}, T_{de})`
        """
        # compute decode mask from the durations
        if y_mask is None:
            y_lengths = dr.sum(1).long()
            y_lengths[y_lengths < 1] = 1
            sequence_mask = DurationAdaptor.sequence_mask(y_lengths, None)
            y_mask = torch.unsqueeze(sequence_mask, 1).to(dr.dtype)

        # compute the attention mask
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = DurationAdaptor.generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
        return attn

    def _expand_encoder_with_durations(
        self,
        encoder_output: Tensor,
        duration_target: Tensor,
        x_mask: Tensor,
        mel_lens: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Expand the encoder output with durations.

        Args:
            encoder_output (Tensor): Encoder output.
            duration_target (Tensor): Target durations.
            x_mask (Tensor): Mask for the input sequence.
            mel_lens (Tensor): Lengths of the mel spectrograms.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Tuple containing the mask for the output sequence, the attention mask, and the expanded encoder output.
        """
        y_mask = torch.unsqueeze(
            DurationAdaptor.sequence_mask(mel_lens, None),
            1,
        ).to(encoder_output.dtype)
        attn = self.generate_attn(duration_target, x_mask, y_mask)

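        # Expand the encoder output in time: with attn of shape [B, T_src, T_mel]
        # and encoder_output of shape [B, T_src, C], the einsum below repeats each
        # encoder frame for its duration, giving a result of shape [B, C, T_mel].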
        encoder_output_ex = torch.einsum(
            "kmn, kmj -> kjn",
            [attn.float(), encoder_output],
        )

        return y_mask, attn, encoder_output_ex

    def forward_train(
        self,
        encoder_output: Tensor,
        encoder_output_res: Tensor,
        duration_target: Tensor,
        src_mask: Tensor,
        mel_lens: Tensor,
    ):
        r"""Forward pass of the DurationAdaptor during training.

        Args:
            encoder_output (Tensor): Encoder output.
            encoder_output_res (Tensor): Residual encoder output; detached and fed to the duration predictor.
            duration_target (Tensor): Target durations.
            src_mask (Tensor): Source mask.
            mel_lens (Tensor): Lengths of the mel spectrograms.

        Returns:
            Tuple: Tuple containing the alignments built from the predicted durations, the predicted log durations, the encoder output expanded by the target durations, and the transposed attention (alignment) matrix.
        """
        log_duration_pred = self.duration_predictor.forward(
            x=encoder_output_res.detach(),
            mask=src_mask,
        )  # [B, C_hidden, T_src] -> [B, T_src]

        y_mask, attn, encoder_output_dr = self._expand_encoder_with_durations(
            encoder_output,
            duration_target,
            x_mask=~src_mask[:, None],
            mel_lens=mel_lens,
        )

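        # The duration predictor works in the log domain; convert the predicted
        # log durations back to the linear scale to build the predicted alignment.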
        duration_pred = torch.exp(log_duration_pred) - 1

        alignments_duration_pred = self.generate_attn(
            duration_pred,
            (~src_mask).unsqueeze(1),
            y_mask,
        )  # [B, T_max, T_max2']

        return (
            alignments_duration_pred,
            log_duration_pred,
            encoder_output_dr,
            attn.transpose(1, 2),
        )

    def forward(self, encoder_output: Tensor, src_mask: Tensor, d_control: float = 1.0):
        r"""Forward pass of the DurationAdaptor.

        Args:
            encoder_output (Tensor): Encoder output.
            src_mask (Tensor): Source mask.
            d_control (float): Duration control. Defaults to 1.0.

        Returns:
            Tuple: Tuple containing the predicted log durations, the encoder output expanded by the predicted durations, the predicted durations, and the transposed attention (alignment) matrix.
        """
        log_duration_pred = self.duration_predictor(
            x=encoder_output.detach(),
            mask=src_mask,
        )  # [B, C_hidden, T_src] -> [B, T_src]

        duration_pred = (
            (torch.exp(log_duration_pred) - 1) * (~src_mask) * d_control
        )  # -> [B, T_src]

        # Clamp predicted durations to a minimum of one frame; torch.where is used
        # instead of the in-place `duration_pred[duration_pred < 1] = 1.0`.
        duration_pred = torch.where(
            duration_pred < 1,
            torch.ones_like(duration_pred),
            duration_pred,
        )  # -> [B, T_src]

        duration_pred = torch.round(duration_pred)  # -> [B, T_src]
        mel_lens = duration_pred.sum(1)  # -> [B,]

        _, attn, encoder_output_dr = self._expand_encoder_with_durations(
            encoder_output,
            duration_pred.squeeze(1),
            ~src_mask[:, None],
            mel_lens,
        )

        return (
            log_duration_pred,
            encoder_output_dr,
            duration_pred,
            attn.transpose(1, 2),
        )
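

# A hedged usage sketch (inference), assuming `model_config` follows
# AcousticModelConfigType, `encoder_output` has shape [B, T_src, C], and
# `src_mask` is a boolean padding mask of shape [B, T_src] (True at padded
# positions):
#
#     adaptor = DurationAdaptor(model_config)
#     log_dur, expanded, dur, attn = adaptor(encoder_output, src_mask, d_control=1.0)
#
# `expanded` is the length-regulated encoder output of shape [B, C, T_mel];
# d_control > 1.0 stretches the predicted durations and slows the speech.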