File size: 9,099 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""CBHG related modules."""

import torch
import torch.nn.functional as F

from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask


class CBHGLoss(torch.nn.Module):
    """Loss module pairing L1 and MSE criteria for CBHG outputs."""

    def __init__(self, use_masking=True):
        """Set up the CBHG loss.

        Args:
            use_masking (bool): If True, frames beyond each sequence length are
                dropped before the losses are computed.

        """
        super(CBHGLoss, self).__init__()
        self.use_masking = use_masking

    def forward(self, cbhg_outs, spcs, olens):
        """Compute the L1 and MSE losses between predictions and targets.

        Args:
            cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
            spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
            olens (LongTensor): Batch of the lengths of each sequence (B,).

        Returns:
            Tensor: L1 loss value
            Tensor: Mean square error loss value.

        """
        prediction, target = cbhg_outs, spcs

        if self.use_masking:
            # One mask, broadcast over the feature axis, selects the valid
            # frames of both tensors so they stay aligned element by element.
            non_pad = make_non_pad_mask(olens).unsqueeze(-1).to(target.device)
            target = torch.masked_select(target, non_pad)
            prediction = torch.masked_select(prediction, non_pad)

        l1_value = F.l1_loss(prediction, target)
        mse_value = F.mse_loss(prediction, target)
        return l1_value, mse_value


class CBHG(torch.nn.Module):
    """CBHG module to convert log Mel-filterbanks to linear spectrogram.

    This is a module of CBHG introduced
    in `Tacotron: Towards End-to-End Speech Synthesis`_.
    The CBHG converts the sequence of log Mel-filterbanks into linear spectrogram.

    The pipeline built here is: 1d convolution bank -> max pooling ->
    convolutional projections with a residual connection to the inputs ->
    highway network -> bidirectional GRU -> linear output layer.

    .. _`Tacotron: Towards End-to-End Speech Synthesis`:
         https://arxiv.org/abs/1703.10135

    """

    def __init__(
        self,
        idim,
        odim,
        conv_bank_layers=8,
        conv_bank_chans=128,
        conv_proj_filts=3,
        conv_proj_chans=256,
        highway_layers=4,
        highway_units=128,
        gru_units=256,
    ):
        """Initialize CBHG module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            conv_bank_layers (int, optional): The number of convolution bank layers.
            conv_bank_chans (int, optional): The number of channels in convolution bank.
            conv_proj_filts (int, optional):
                Kernel size of convolutional projection layer.
            conv_proj_chans (int, optional):
                The number of channels in convolutional projection layer.
            highway_layers (int, optional): The number of highway network layers.
            highway_units (int, optional): The number of highway network units.
            gru_units (int, optional): The number of GRU units (for both directions).
                Each direction gets ``gru_units // 2`` units, so the value
                should be even for the output layer to match the GRU output.

        """
        super(CBHG, self).__init__()
        self.idim = idim
        self.odim = odim
        self.conv_bank_layers = conv_bank_layers
        self.conv_bank_chans = conv_bank_chans
        self.conv_proj_filts = conv_proj_filts
        self.conv_proj_chans = conv_proj_chans
        self.highway_layers = highway_layers
        self.highway_units = highway_units
        self.gru_units = gru_units

        # define 1d convolution bank (one branch per kernel size 1..K)
        self.conv_bank = torch.nn.ModuleList()
        for k in range(1, self.conv_bank_layers + 1):
            # Pad k - 1 zeros in total so a kernel of size k keeps the time
            # length unchanged; even kernels get the extra zero on the right.
            if k % 2 != 0:
                padding = (k - 1) // 2
            else:
                padding = ((k - 1) // 2, (k - 1) // 2 + 1)
            self.conv_bank.append(
                torch.nn.Sequential(
                    torch.nn.ConstantPad1d(padding, 0.0),
                    torch.nn.Conv1d(
                        idim, self.conv_bank_chans, k, stride=1, padding=0, bias=True
                    ),
                    torch.nn.BatchNorm1d(self.conv_bank_chans),
                    torch.nn.ReLU(),
                )
            )

        # define max pooling (need padding for one-side to keep same length)
        self.max_pool = torch.nn.Sequential(
            torch.nn.ConstantPad1d((0, 1), 0.0), torch.nn.MaxPool1d(2, stride=1)
        )

        # define 1d convolution projection
        # (maps the concatenated bank back to idim for the residual connection)
        self.projections = torch.nn.Sequential(
            torch.nn.Conv1d(
                self.conv_bank_chans * self.conv_bank_layers,
                self.conv_proj_chans,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.conv_proj_chans),
            torch.nn.ReLU(),
            torch.nn.Conv1d(
                self.conv_proj_chans,
                self.idim,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.idim),
        )

        # define highway network
        # (first entry is a linear layer adjusting idim -> highway_units)
        self.highways = torch.nn.ModuleList()
        self.highways.append(torch.nn.Linear(idim, self.highway_units))
        for _ in range(self.highway_layers):
            self.highways.append(HighwayNet(self.highway_units))

        # define bidirectional GRU
        self.gru = torch.nn.GRU(
            self.highway_units,
            gru_units // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # define final projection
        self.output = torch.nn.Linear(gru_units, odim, bias=True)

    def forward(self, xs, ilens):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input sequence (B,).
                A plain sequence of ints is also accepted and converted
                to a tensor.

        Returns:
            Tensor: Batch of the padded sequence of outputs (B, Tmax, odim).
            LongTensor: Batch of lengths of each output sequence (B,).

        """
        # The conversion must happen before sorting, because the sort helper
        # relies on the tensor sort API.
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)

        xs = xs.transpose(1, 2)  # (B, idim, Tmax)
        convs = torch.cat(
            [bank(xs) for bank in self.conv_bank], dim=1
        )  # (B, #CH * #BANK, Tmax)
        convs = self.max_pool(convs)
        convs = self.projections(convs).transpose(1, 2)  # (B, Tmax, idim)
        xs = xs.transpose(1, 2) + convs  # residual connection

        # dimension adjustment layer followed by the highway layers
        for layer in self.highways:
            xs = layer(xs)

        # sort by length
        xs, ilens, sort_idx = self._sort_by_length(xs, ilens)

        # total_length needs for DataParallel
        # (see https://github.com/pytorch/pytorch/pull/6327)
        total_length = xs.size(1)
        xs = pack_padded_sequence(xs, ilens.cpu(), batch_first=True)
        self.gru.flatten_parameters()
        xs, _ = self.gru(xs)
        xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)

        # revert sorting by length
        xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)

        xs = self.output(xs)  # (B, Tmax, odim)

        return xs, ilens

    def inference(self, x):
        """Inference.

        Args:
            x (Tensor): The sequences of inputs (T, idim).

        Returns:
            Tensor: The sequence of outputs (T, odim).

        """
        assert len(x.size()) == 2
        xs = x.unsqueeze(0)
        # Build the length directly as an integer tensor so that it never
        # passes through the (possibly low-precision) floating dtype of x.
        ilens = torch.tensor([x.size(0)], dtype=torch.long, device=x.device)

        return self.forward(xs, ilens)[0][0]

    def _sort_by_length(self, xs, ilens):
        """Sort the batch by decreasing length.

        Args:
            xs (Tensor): Batch of sequences, batch axis first.
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Batch reordered by decreasing length.
            LongTensor: Lengths in decreasing order.
            LongTensor: Permutation applied, needed to undo the sorting.

        """
        sort_ilens, sort_idx = ilens.sort(0, descending=True)
        return xs[sort_idx], sort_ilens, sort_idx

    def _revert_sort_by_length(self, xs, ilens, sort_idx):
        """Restore the batch order that was in place before sorting.

        Args:
            xs (Tensor): Sorted batch of sequences, batch axis first.
            ilens (LongTensor): Sorted batch of lengths (B,).
            sort_idx (LongTensor): Permutation returned by ``_sort_by_length``.

        Returns:
            Tensor: Batch of sequences in the original order.
            LongTensor: Batch of lengths in the original order.

        """
        # Sorting the permutation yields its inverse.
        _, revert_idx = sort_idx.sort(0)
        # The index is moved to each tensor's own device because xs and ilens
        # may live on different devices at this point.
        return xs[revert_idx.to(xs.device)], ilens[revert_idx.to(ilens.device)]


class HighwayNet(torch.nn.Module):
    """Single highway layer.

    Implements the gated mix of a transformed input and the untouched input
    described in `Highway Networks`_.

    .. _`Highway Networks`: https://arxiv.org/abs/1505.00387

    """

    def __init__(self, idim):
        """Build the transform and gate branches.

        Args:
            idim (int): Dimension of the inputs (and of the outputs).

        """
        super(HighwayNet, self).__init__()
        self.idim = idim

        # Transform branch: affine map followed by ReLU.
        transform_linear = torch.nn.Linear(idim, idim)
        self.projection = torch.nn.Sequential(transform_linear, torch.nn.ReLU())

        # Gate branch: affine map squashed into (0, 1) by a sigmoid.
        gate_linear = torch.nn.Linear(idim, idim)
        self.gate = torch.nn.Sequential(gate_linear, torch.nn.Sigmoid())

    def forward(self, x):
        """Apply the highway layer.

        Args:
            x (Tensor): Batch of inputs (B, ..., idim).

        Returns:
            Tensor: Batch of outputs, which are the same shape as inputs (B, ..., idim).

        """
        transformed = self.projection(x)
        carry_gate = self.gate(x)
        # The gate weights the transformed signal; its complement weights the
        # input that is carried through unchanged.
        passthrough = (1.0 - carry_gate) * x
        return carry_gate * transformed + passthrough