File size: 6,915 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch
from torch import nn
from torch.nn.utils import parametrize


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Sum two tensors and apply a tanh/sigmoid gate split along dim 1.

    Args:
        input_a: first tensor; indexed with three dimensions.
        input_b: second tensor, added elementwise to ``input_a``.
        n_channels: tensor whose first element is the split position on
            dim 1. Entries before it go through ``tanh``, the remaining
            entries go through ``sigmoid``.

    Returns:
        The elementwise product of the ``tanh`` part and the ``sigmoid`` part.
    """
    split_at = n_channels[0]
    pre_activation = input_a + input_b
    # Filter half (tanh) gated by the remaining half (sigmoid).
    return torch.tanh(pre_activation[:, :split_at, :]) * torch.sigmoid(pre_activation[:, split_at:, :])


class WN(torch.nn.Module):
    """Wavenet layers with weight norm and no input conditioning.

         |-----------------------------------------------------------------------------|
         |                                    |-> tanh    -|                           |
    res -|- conv1d(dilation) -> dropout -> + -|            * -> conv1d1x1 -> split -|- + -> res
    g -------------------------------------|  |-> sigmoid -|                        |
    o --------------------------------------------------------------------------- + --------- o

    Args:
        in_channels (int): number of input channels.
        hidden_channes (int): number of hidden channels.
        kernel_size (int): filter kernel size for the first conv layer.
        dilation_rate (int): dilations rate to increase dilation per layer.
            If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers.
        num_layers (int): number of wavenet layers.
        c_in_channels (int): number of channels of conditioning input.
        dropout_p (float): dropout rate.
        weight_norm (bool): enable/disable weight norm for convolution layers.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_layers,
        c_in_channels=0,
        dropout_p=0,
        weight_norm=True,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        assert hidden_channels % 2 == 0
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_layers = num_layers
        self.c_in_channels = c_in_channels
        self.dropout_p = dropout_p

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.dropout = nn.Dropout(dropout_p)

        # init conditioning layer
        if c_in_channels > 0:
            cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
            self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
        # intermediate layers
        for i in range(num_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            if i == 0:
                in_layer = torch.nn.Conv1d(
                    in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
                )
            else:
                in_layer = torch.nn.Conv1d(
                    hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
                )
            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            if i < num_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)
        # setup weight norm
        if not weight_norm:
            self.remove_weight_norm()

    def forward(self, x, x_mask=None, g=None, **kwargs):  # pylint: disable=unused-argument
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])
        x_mask = 1.0 if x_mask is None else x_mask
        if g is not None:
            g = self.cond_layer(g)
        for i in range(self.num_layers):
            x_in = self.in_layers[i](x)
            x_in = self.dropout(x_in)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)
            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.c_in_channels != 0:
            parametrize.remove_parametrizations(self.cond_layer, "weight")
        for l in self.in_layers:
            parametrize.remove_parametrizations(l, "weight")
        for l in self.res_skip_layers:
            parametrize.remove_parametrizations(l, "weight")


class WNBlocks(nn.Module):
    """Sequence of ``WN`` blocks applied one after another.

    Note: Each block is built with the same ``dilation_rate`` and
        ``num_layers``, so the dilation schedule restarts in every block.

    Args:
        in_channels (int): number of input channels of the first block.
        hidden_channels (int): number of hidden channels; also used as the
            input channel count of every block after the first.
        kernel_size (int): filter kernel size passed to each block.
        dilation_rate (int): dilations rate to increase dilation per layer.
            If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers.
        num_blocks (int): number of wavenet blocks.
        num_layers (int): number of wavenet layers in each block.
        c_in_channels (int): number of channels of conditioning input.
        dropout_p (float): dropout rate.
        weight_norm (bool): enable/disable weight norm for convolution layers.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_blocks,
        num_layers,
        c_in_channels=0,
        dropout_p=0,
        weight_norm=True,
    ):
        super().__init__()
        # Only the first block consumes the raw input channels.
        blocks = [
            WN(
                in_channels=hidden_channels if block_idx > 0 else in_channels,
                hidden_channels=hidden_channels,
                kernel_size=kernel_size,
                dilation_rate=dilation_rate,
                num_layers=num_layers,
                c_in_channels=c_in_channels,
                dropout_p=dropout_p,
                weight_norm=weight_norm,
            )
            for block_idx in range(num_blocks)
        ]
        self.wn_blocks = nn.ModuleList(blocks)

    def forward(self, x, x_mask=None, g=None):
        """Feed ``x`` through every block in order, passing ``x_mask`` and ``g`` to each."""
        hidden = x
        for block in self.wn_blocks:
            hidden = block(hidden, x_mask, g)
        return hidden