File size: 9,617 Bytes
2f5f13b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import torch
from typing import Optional
from rvc.lib.algorithm.generators.hifigan_mrf import HiFiGANMRFGenerator
from rvc.lib.algorithm.generators.hifigan_nsf import HiFiGANNSFGenerator
from rvc.lib.algorithm.generators.hifigan import HiFiGANGenerator
from rvc.lib.algorithm.generators.refinegan import RefineGANGenerator
from rvc.lib.algorithm.commons import slice_segments, rand_slice_segments
from rvc.lib.algorithm.residuals import ResidualCouplingBlock
from rvc.lib.algorithm.encoders import TextEncoder, PosteriorEncoder


class Synthesizer(torch.nn.Module):
    """
    Base Synthesizer model.

    Combines a prior text encoder (``enc_p``), a posterior encoder (``enc_q``),
    a flow (``flow``), a speaker embedding table (``emb_g``) and a decoder
    (``dec``) whose class is selected from ``vocoder`` and ``use_f0``.

    Args:
        spec_channels (int): Number of channels in the spectrogram.
        segment_size (int): Size of the audio segment; passed to the slicing
            helpers when ``randomized`` training is used.
        inter_channels (int): Number of channels in the intermediate layers.
        hidden_channels (int): Number of channels in the hidden layers.
        filter_channels (int): Number of channels in the filter layers.
        n_heads (int): Number of attention heads.
        n_layers (int): Number of layers in the encoder.
        kernel_size (int): Size of the convolution kernel.
        p_dropout (float): Dropout probability.
        resblock (str): Type of residual block. Accepted for config
            compatibility; not used by this class.
        resblock_kernel_sizes (list): Kernel sizes for the residual blocks.
        resblock_dilation_sizes (list): Dilation sizes for the residual blocks.
        upsample_rates (list): Upsampling rates for the decoder.
        upsample_initial_channel (int): Number of channels in the initial
            upsampling layer.
        upsample_kernel_sizes (list): Kernel sizes for the upsampling layers.
        spk_embed_dim (int): Number of rows in the speaker embedding table
            (``emb_g``).
        gin_channels (int): Number of channels in the global conditioning vector.
        sr (int): Sampling rate of the audio.
        use_f0 (bool): Whether to use F0 information.
        text_enc_hidden_dim (int): Hidden dimension for the text encoder.
        vocoder (str): ``"HiFi-GAN"`` (default), ``"MRF HiFi-GAN"`` or
            ``"RefineGAN"``. Any other value falls back to HiFi-GAN.
            ``"MRF HiFi-GAN"`` and ``"RefineGAN"`` require ``use_f0=True``;
            without it a message is printed and ``dec`` is left as ``None``.
        randomized (bool): If True, ``forward`` decodes a random
            ``segment_size`` slice of the latent; otherwise it decodes the
            whole latent sequence.
        checkpointing (bool): Forwarded to the decoder's ``checkpointing`` flag.
        kwargs: Additional keyword arguments (ignored).
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: list,
        resblock_dilation_sizes: list,
        upsample_rates: list,
        upsample_initial_channel: int,
        upsample_kernel_sizes: list,
        spk_embed_dim: int,
        gin_channels: int,
        sr: int,
        use_f0: bool,
        text_enc_hidden_dim: int = 768,
        vocoder: str = "HiFi-GAN",
        randomized: bool = True,
        checkpointing: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.segment_size = segment_size
        self.use_f0 = use_f0
        self.randomized = randomized

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            text_enc_hidden_dim,
            f0=use_f0,
        )
        print(f"Using {vocoder} vocoder")
        if use_f0:
            if vocoder == "MRF HiFi-GAN":
                self.dec = HiFiGANMRFGenerator(
                    in_channel=inter_channels,
                    upsample_initial_channel=upsample_initial_channel,
                    upsample_rates=upsample_rates,
                    upsample_kernel_sizes=upsample_kernel_sizes,
                    resblock_kernel_sizes=resblock_kernel_sizes,
                    resblock_dilations=resblock_dilation_sizes,
                    gin_channels=gin_channels,
                    sample_rate=sr,
                    harmonic_num=8,
                    checkpointing=checkpointing,
                )
            elif vocoder == "RefineGAN":
                self.dec = RefineGANGenerator(
                    sample_rate=sr,
                    # Mirror the upsampling schedule for the downsampling path.
                    downsample_rates=upsample_rates[::-1],
                    upsample_rates=upsample_rates,
                    start_channels=16,
                    num_mels=inter_channels,
                    checkpointing=checkpointing,
                )
            else:
                self.dec = HiFiGANNSFGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                    sr=sr,
                    checkpointing=checkpointing,
                )
        else:
            if vocoder == "MRF HiFi-GAN":
                print("MRF HiFi-GAN does not support training without pitch guidance.")
                self.dec = None
            elif vocoder == "RefineGAN":
                print("RefineGAN does not support training without pitch guidance.")
                self.dec = None
            else:
                self.dec = HiFiGANGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                    checkpointing=checkpointing,
                )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels,
            hidden_channels,
            5,
            1,
            3,
            gin_channels=gin_channels,
        )
        self.emb_g = torch.nn.Embedding(spk_embed_dim, gin_channels)

    def _remove_weight_norm_from(self, module: Optional[torch.nn.Module]) -> None:
        """
        Remove weight normalization from ``module`` and all of its submodules.

        Handles both the hook-based ``torch.nn.utils.weight_norm`` and the
        parametrization-based ``torch.nn.utils.parametrizations.weight_norm``;
        the normalized weight is baked in as a plain parameter. ``None`` (e.g.
        a decoder that was not built) is ignored.
        """
        if module is None:
            return
        from torch.nn.utils import parametrize

        # Snapshot the module list: removing a parametrization edits the
        # module tree while we walk it.
        for submodule in list(module.modules()):
            # Snapshot the hooks as well: remove_weight_norm deletes the hook
            # from _forward_pre_hooks, and mutating a dict while iterating
            # over it raises RuntimeError.
            for hook in list(submodule._forward_pre_hooks.values()):
                if type(hook).__name__ == "WeightNorm":
                    torch.nn.utils.remove_weight_norm(
                        submodule, getattr(hook, "name", "weight")
                    )
            if parametrize.is_parametrized(submodule):
                for tensor_name in list(submodule.parametrizations.keys()):
                    plist = submodule.parametrizations[tensor_name]
                    # Only bake in tensors whose sole parametrizations are
                    # weight norm; leave anything else untouched.
                    if len(plist) > 0 and all(
                        type(p).__name__ == "_WeightNorm" for p in plist
                    ):
                        parametrize.remove_parametrizations(
                            submodule, tensor_name, leave_parametrized=True
                        )

    def remove_weight_norm(self):
        """Remove weight normalization from the decoder, flow and posterior encoder."""
        for module in [self.dec, self.flow, self.enc_q]:
            self._remove_weight_norm_from(module)

    def __prepare_scriptable__(self):
        """Hook called by ``torch.jit.script``: strip weight norm first, then script ``self``."""
        self.remove_weight_norm()
        return self

    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        y_lengths: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
    ):
        """
        Training forward pass.

        Args:
            phone (torch.Tensor): Phoneme features fed to ``enc_p``.
            phone_lengths (torch.Tensor): Lengths of the phoneme sequences.
            pitch (torch.Tensor, optional): Pitch sequence fed to ``enc_p``.
            pitchf (torch.Tensor, optional): Fine-grained pitch passed to the
                decoder when ``use_f0`` is set (sliced alongside the latent
                when ``randomized``).
            y (torch.Tensor, optional): Target spectrogram for ``enc_q``. If
                None, only the prior statistics are computed.
            y_lengths (torch.Tensor, optional): Lengths of ``y``.
            ds (torch.Tensor): Speaker indices looked up in ``emb_g``.

        Returns:
            tuple: ``(o, ids_slice, x_mask, y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q))``.
            - ``ids_slice`` is None when ``randomized`` is False.
            - When ``y`` is None, every entry except ``x_mask``, ``m_p`` and
              ``logs_p`` is None.
        """
        # Speaker conditioning with a trailing length-1 axis.
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_p = self.flow(z, y_mask, g=g)
            # regular old training method using random slices
            if self.randomized:
                z_slice, ids_slice = rand_slice_segments(
                    z, y_lengths, self.segment_size
                )
                if self.use_f0:
                    # Slice pitch at the same offsets as the latent.
                    pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
                    o = self.dec(z_slice, pitchf, g=g)
                else:
                    o = self.dec(z_slice, g=g)
                return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
            # future use for finetuning using the entire dataset each pass
            else:
                if self.use_f0:
                    o = self.dec(z, pitchf, g=g)
                else:
                    o = self.dec(z, g=g)
                return o, None, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
        else:
            return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        nsff0: Optional[torch.Tensor] = None,
        sid: torch.Tensor = None,
        rate: Optional[torch.Tensor] = None,
    ):
        """
        Inference of the model.

        Args:
            phone (torch.Tensor): Phoneme sequence.
            phone_lengths (torch.Tensor): Lengths of the phoneme sequences.
            pitch (torch.Tensor, optional): Pitch sequence.
            nsff0 (torch.Tensor, optional): Fine-grained pitch passed to the
                decoder when ``use_f0`` is set.
            sid (torch.Tensor): Speaker indices looked up in ``emb_g``. It is
                required in practice despite the ``None`` default.
            rate (torch.Tensor, optional): One-element tensor. When given,
                only the trailing ``rate`` fraction of frames is kept. The
                leading ``int(T * (1 - rate))`` frames are dropped from the
                latent, the mask and (with ``use_f0``) ``nsff0``.

        Returns:
            tuple: ``(o, x_mask, (z, z_p, m_p, logs_p))``, where ``o`` is the
            decoder output.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample from the prior with a fixed noise scale of 0.66666.
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        if rate is not None:
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p, x_mask = z_p[:, :, head:], x_mask[:, :, head:]
            if self.use_f0 and nsff0 is not None:
                nsff0 = nsff0[:, head:]

        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = (
            self.dec(z * x_mask, nsff0, g=g)
            if self.use_f0
            else self.dec(z * x_mask, g=g)
        )

        return o, x_mask, (z, z_p, m_p, logs_p)