import math
from typing import Optional

import torch
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import weight_norm
from torch.utils.checkpoint import checkpoint

from rvc.lib.algorithm.commons import init_weights
from rvc.lib.algorithm.generators.hifigan import SineGenerator
from rvc.lib.algorithm.residuals import LRELU_SLOPE, ResBlock


class SourceModuleHnNSF(torch.nn.Module):
    """

    Source Module for generating harmonic and noise components for audio synthesis.



    This module generates a harmonic source signal using sine waves and adds

    optional noise. It's often used in neural vocoders as a source of excitation.



    Args:

        sample_rate (int): Sampling rate of the audio in Hz.

        harmonic_num (int, optional): Number of harmonic overtones to generate above the fundamental frequency (F0). Defaults to 0.

        sine_amp (float, optional): Amplitude of the sine wave components. Defaults to 0.1.

        add_noise_std (float, optional): Standard deviation of the additive white Gaussian noise. Defaults to 0.003.

        voiced_threshod (float, optional): Threshold for the fundamental frequency (F0) to determine if a frame is voiced. If F0 is below this threshold, it's considered unvoiced. Defaults to 0.

    """

    def __init__(
        self,
        sample_rate: int,
        harmonic_num: int = 0,
        sine_amp: float = 0.1,
        add_noise_std: float = 0.003,
        voiced_threshold: float = 0,
    ):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        self.l_sin_gen = SineGenerator(
            sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold
        )
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x: torch.Tensor, upsample_factor: int = 1):
        # generate sine waves for the fundamental and each harmonic
        sine_wavs, _, _ = self.l_sin_gen(x, upsample_factor)
        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
        # merge the harmonic_num + 1 sine channels into a single excitation channel
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        return sine_merge, None, None
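
# Usage sketch for SourceModuleHnNSF on its own. The frame-level F0 layout
# (batch, frames) is an assumption inferred from how SineGenerator is driven
# here; the numbers are illustrative:
#
#   source = SourceModuleHnNSF(sample_rate=40000, harmonic_num=8)
#   f0 = torch.full((1, 100), 220.0)                  # 100 frames at 220 Hz
#   excitation, _, _ = source(f0, upsample_factor=400)
#   # excitation: merged harmonic excitation at the audio sample rate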


class HiFiGANNSFGenerator(torch.nn.Module):
    """

    Generator module based on the Neural Source Filter (NSF) architecture.



    This generator synthesizes audio by first generating a source excitation signal

    (harmonic and noise) and then filtering it through a series of upsampling and

    residual blocks. Global conditioning can be applied to influence the generation.



    Args:

        initial_channel (int): Number of input channels to the initial convolutional layer.

        resblock_kernel_sizes (list): List of kernel sizes for the residual blocks.

        resblock_dilation_sizes (list): List of lists of dilation rates for the residual blocks, corresponding to each kernel size.

        upsample_rates (list): List of upsampling factors for each upsampling layer.

        upsample_initial_channel (int): Number of output channels from the initial convolutional layer, which is also the input to the first upsampling layer.

        upsample_kernel_sizes (list): List of kernel sizes for the transposed convolutional layers used for upsampling.

        gin_channels (int): Number of input channels for the global conditioning. If 0, no global conditioning is used.

        sr (int): Sampling rate of the audio.

        checkpointing (bool, optional): Whether to use gradient checkpointing to save memory during training. Defaults to False.

    """

    def __init__(
        self,
        initial_channel: int,
        resblock_kernel_sizes: list,
        resblock_dilation_sizes: list,
        upsample_rates: list,
        upsample_initial_channel: int,
        upsample_kernel_sizes: list,
        gin_channels: int,
        sr: int,
        checkpointing: bool = False,
    ):
        super(HiFiGANNSFGenerator, self).__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.checkpointing = checkpointing
        # frame-level F0 upsampler (not referenced in forward below; m_source
        # receives the raw F0 together with self.upp instead)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
        # harmonic_num=0: the source produces only the fundamental sine
        self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)

        self.conv_pre = torch.nn.Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )

        self.ups = torch.nn.ModuleList()
        self.noise_convs = torch.nn.ModuleList()

        # each upsampling stage halves the channel count
        channels = [
            upsample_initial_channel // (2 ** (i + 1))
            for i in range(len(upsample_rates))
        ]
        # stride_f0s[i]: factor by which the sample-rate excitation must be
        # downsampled to match the time resolution of stage i's output, i.e.
        # the product of the upsampling factors still to come after stage i
        stride_f0s = [
            math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1
            for i in range(len(upsample_rates))
        ]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # even rates use the usual symmetric padding; odd rates get extra
            # padding plus output_padding (see the reference table below)
            if u % 2 == 0:
                padding = (k - u) // 2
            else:
                padding = u // 2 + u % 2

            self.ups.append(
                weight_norm(
                    torch.nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        channels[i],
                        k,
                        u,
                        padding=padding,
                        output_padding=u % 2,
                    )
                )
            )
            """ handling odd upsampling rates

            #  s   k   p

            # 40  80  20

            # 32  64  16

            #  4   8   2

            #  2   3   1

            # 63 125  31

            #  9  17   4

            #  3   5   1

            #  1   1   0

            """
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2
            padding = 0 if stride == 1 else (kernel - stride) // 2

            self.noise_convs.append(
                torch.nn.Conv1d(
                    1,
                    channels[i],
                    kernel_size=kernel,
                    stride=stride,
                    padding=padding,
                )
            )

        self.resblocks = torch.nn.ModuleList(
            [
                ResBlock(channels[i], k, d)
                for i in range(len(self.ups))
                for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ]
        )

        self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

        # total upsampling factor: feature frames -> audio samples
        self.upp = math.prod(upsample_rates)
        self.lrelu_slope = LRELU_SLOPE

    def forward(
        self, x: torch.Tensor, f0: torch.Tensor, g: Optional[torch.Tensor] = None
    ):
        # generate the harmonic excitation at the full audio sample rate
        har_source, _, _ = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)

        # conv_pre allocates a new tensor, so the in-place ops below are safe
        x = self.conv_pre(x)

        if g is not None:
            # in-place add of the global conditioning
            x += self.cond(g)

        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
            # in-place call
            x = torch.nn.functional.leaky_relu_(x, self.lrelu_slope)

            # Apply upsampling layer
            if self.training and self.checkpointing:
                x = checkpoint(ups, x, use_reentrant=False)
            else:
                x = ups(x)

            # Add noise excitation
            x += noise_convs(har_source)

            # Apply residual blocks (multi-receptive-field fusion: average the
            # outputs of the num_kernels parallel ResBlocks for this stage)
            def resblock_forward(x, blocks):
                return sum(block(x) for block in blocks) / len(blocks)

            blocks = self.resblocks[i * self.num_kernels : (i + 1) * self.num_kernels]

            # Checkpoint or regular computation for ResBlocks
            if self.training and self.checkpointing:
                x = checkpoint(resblock_forward, x, blocks, use_reentrant=False)
            else:
                x = resblock_forward(x, blocks)
        # final activation and projection, both applied in-place
        x = torch.nn.functional.leaky_relu_(x)
        x = torch.tanh_(self.conv_post(x))

        return x

    def remove_weight_norm(self):
        # weight_norm was applied through the parametrization API, so it has
        # to be removed with remove_parametrizations rather than the legacy
        # torch.nn.utils.remove_weight_norm helper
        for l in self.ups:
            parametrize.remove_parametrizations(l, "weight")
        for l in self.resblocks:
            l.remove_weight_norm()

    def __prepare_scriptable__(self):
        # TorchScript cannot handle parametrizations, so strip weight norm
        # (idempotently) before scripting
        for l in self.ups:
            if parametrize.is_parametrized(l, "weight"):
                parametrize.remove_parametrizations(l, "weight")
        for l in self.resblocks:
            l.remove_weight_norm()
        return self
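

# Minimal smoke test, runnable as a script from the repo root. The
# hyperparameters are illustrative (a 40 kHz, 400x-upsampling setup), not the
# repo's canonical configuration, and the frame-level F0 layout (batch, frames)
# is the same assumption as in the usage sketch above.
if __name__ == "__main__":
    gen = HiFiGANNSFGenerator(
        initial_channel=192,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 10, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[20, 20, 4, 4],
        gin_channels=256,
        sr=40000,
    )
    x = torch.randn(1, 192, 50)      # (batch, initial_channel, frames)
    f0 = torch.full((1, 50), 220.0)  # frame-level F0 in Hz (assumed layout)
    g = torch.randn(1, 256, 1)       # global conditioning embedding
    with torch.no_grad():
        audio = gen(x, f0, g)
    print(audio.shape)  # expected: torch.Size([1, 1, 20000]) = 50 frames * 400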