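"""KernelPredictor for the UnivNet vocoder.

Predicts the kernels and biases of the location-variable convolution (LVC)
layers from a conditioning sequence, typically a log-mel-spectrogram.
"""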
import torch
from torch import nn
from torch.nn import Module
from torch.nn.utils import parametrize


class KernelPredictor(Module):
    def __init__(
        self,
        cond_channels: int,
        conv_in_channels: int,
        conv_out_channels: int,
        conv_layers: int,
        conv_kernel_size: int = 3,
        kpnet_hidden_channels: int = 64,
        kpnet_conv_size: int = 3,
        kpnet_dropout: float = 0.0,
        lReLU_slope: float = 0.1,
    ):
        r"""Initializes a KernelPredictor object.
        KernelPredictor is a class that predicts the kernel size for the convolutional layers in the UnivNet model.
        The kernels of the LVC layers are predicted using a kernel predictor that takes the log-mel-spectrogram as the input.

        Args:
            cond_channels (int): The number of channels for the conditioning sequence.
            conv_in_channels (int): The number of channels for the input sequence.
            conv_out_channels (int): The number of channels for the output sequence.
            conv_layers (int): The number of layers in the model.
            conv_kernel_size (int, optional): The kernel size for the convolutional layers. Defaults to 3.
            kpnet_hidden_channels (int, optional): The number of hidden channels in the kernel predictor network. Defaults to 64.
            kpnet_conv_size (int, optional): The kernel size for the kernel predictor network. Defaults to 3.
            kpnet_dropout (float, optional): The dropout rate for the kernel predictor network. Defaults to 0.0.
            lReLU_slope (float, optional): The slope for the leaky ReLU activation function. Defaults to 0.1.
        """
        super().__init__()

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        kpnet_kernel_channels = (
            conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        )  # l_w: total number of predicted kernel elements across all LVC layers

        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b: total number of predicted bias elements

        padding = (kpnet_conv_size - 1) // 2  # "same" padding for an odd kernel size

        self.input_conv = nn.Sequential(
            nn.utils.parametrizations.weight_norm(
                nn.Conv1d(
                    cond_channels,
                    kpnet_hidden_channels,
                    5,
                    padding=2,
                    bias=True,
                ),
            ),
            nn.LeakyReLU(lReLU_slope),
        )

        # Three stacked residual blocks: dropout -> conv -> LReLU -> conv -> LReLU.
        self.residual_convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Dropout(kpnet_dropout),
                    nn.utils.parametrizations.weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        ),
                    ),
                    nn.LeakyReLU(lReLU_slope),
                    nn.utils.parametrizations.weight_norm(
                        nn.Conv1d(
                            kpnet_hidden_channels,
                            kpnet_hidden_channels,
                            kpnet_conv_size,
                            padding=padding,
                            bias=True,
                        ),
                    ),
                    nn.LeakyReLU(lReLU_slope),
                )
                for _ in range(3)
            ],
        )

        # Project the hidden features onto the flattened kernel / bias channels.
        self.kernel_conv = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_kernel_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            ),
        )
        self.bias_conv = nn.utils.parametrizations.weight_norm(
            nn.Conv1d(
                kpnet_hidden_channels,
                kpnet_bias_channels,
                kpnet_conv_size,
                padding=padding,
                bias=True,
            ),
        )

    def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        r"""Computes the forward pass of the model.

        Args:
            c (Tensor): The conditioning sequence (batch, cond_channels, cond_length).

        Returns:
            tuple[Tensor, Tensor]: The predicted kernels with shape
            (batch, conv_layers, conv_in_channels, conv_out_channels,
            conv_kernel_size, cond_length), and the predicted biases with shape
            (batch, conv_layers, conv_out_channels, cond_length).
        """
        batch, _, cond_length = c.shape
        c = self.input_conv(c.to(dtype=self.kernel_conv.weight.dtype))  # match the conv weight dtype (e.g. under AMP)
        for residual_conv in self.residual_convs:
            c = c + residual_conv(c)  # residual connection around each block
        k = self.kernel_conv(c)
        b = self.bias_conv(c)
        kernels = k.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_in_channels,
            self.conv_out_channels,
            self.conv_kernel_size,
            cond_length,
        )
        bias = b.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_out_channels,
            cond_length,
        )

        return kernels, bias

    def remove_weight_norm(self):
        r"""Removes weight normalization from the input, kernel, bias, and residual convolutional layers."""
        parametrize.remove_parametrizations(self.input_conv[0], "weight")
        parametrize.remove_parametrizations(self.kernel_conv, "weight")
        parametrize.remove_parametrizations(self.bias_conv, "weight")

        for block in self.residual_convs:
            parametrize.remove_parametrizations(block[1], "weight") # type: ignore
            parametrize.remove_parametrizations(block[3], "weight") # type: ignore
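

if __name__ == "__main__":
    # Minimal smoke test. The hyperparameters below are illustrative
    # (a sketch, not the values of any particular UnivNet checkpoint);
    # cond_channels=100 assumes a 100-bin log-mel-spectrogram.
    predictor = KernelPredictor(
        cond_channels=100,
        conv_in_channels=32,
        conv_out_channels=64,
        conv_layers=4,
    )
    mel = torch.randn(2, 100, 128)  # (batch, cond_channels, cond_length)
    kernels, bias = predictor(mel)
    print(kernels.shape)  # torch.Size([2, 4, 32, 64, 3, 128])
    print(bias.shape)     # torch.Size([2, 4, 64, 128])

    # Weight norm is typically removed once training is finished.
    predictor.remove_weight_norm()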