Upload seamless_communication/models/generator/builder.py with huggingface_hub
seamless_communication/models/generator/builder.py
ADDED
@@ -0,0 +1,506 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple

from fairseq2.data import VocabularyInfo
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
from fairseq2.nn.projection import Linear
from fairseq2.nn.transformer import (
    MultiheadAttention,
    StandardMultiheadAttention,
    TransformerNormOrder,
    create_default_sdpa,
)
from fairseq2.typing import DataType, Device
from torch.nn import Conv1d

from seamless_communication.models.generator.ecapa_tdnn_builder import (
    EcapaTDNNBuilder,
    EcapaTDNNConfig,
    ecapa_tdnn_archs,
)
from seamless_communication.models.generator.vocoder import (
    PretsselDecoderFrontend,
    PretsselEncoderFrontend,
    PretsselVocoder,
)
from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
from seamless_communication.models.unity.fft_decoder_layer import (
    Conv1dBlock,
    FeedForwardTransformerLayer,
)
from seamless_communication.models.unity.length_regulator import (
    VarianceAdaptor,
    VariancePredictor,
)
from seamless_communication.models.unity.t2u_builder import VariancePredictorConfig


@dataclass
class PretsselEncoderFrontendConfig:
    prosody_encoder_config: EcapaTDNNConfig
    dropout: float
    lang_embed_dim: Optional[int] = None


@dataclass
class FFTLayerConfig:
    attention_heads: int
    hidden_dim: int
    kernel_size: int
    dropout: float
    conv1d_dropout: float
    film_cond_dim: int
    use_film: bool = False


@dataclass
class PretsselDecoderFrontendConfig:
    upsampling_type: Literal["gaussian", "hard"]
    variance_predictor_config: VariancePredictorConfig
    add_variance_parallel: bool


@dataclass
class VocoderConfig:
    """Holds the configuration of a Vocoder model."""

    encoder_frontend_config: PretsselEncoderFrontendConfig
    fft_layer_config: FFTLayerConfig
    decoder_frontend_config: PretsselDecoderFrontendConfig
    pn_conv_dim: int
    pn_layers: int
    pn_conv_kernel_size: int
    pn_dropout: float
    vocab_info: VocabularyInfo
    model_dim: int
    max_seq_len: int
    encoder_layers: int
    decoder_layers: int
    mel_dim: int
    langs: List  # type: ignore[type-arg]
    upsample_rates: List[int]
    upsample_kernel_sizes: List[int]
    upsample_initial_channel: int
    resblock_kernel_sizes: List[int]
    resblock_dilation_sizes: List[List[int]]
    channels: int
    dimension: int
    n_filters: int
    ratios: List[int]
    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"]
    norm_params: Dict[str, Any]
    kernel_size: int
    last_kernel_size: int
    residual_kernel_size: int
    causal: bool
    pad_mode: str
    true_skip: bool
    compress: int
    lstm: int
    disable_norm_outer_blocks: int
    trim_right_ratio: float
    gcmvn_stats: Dict[str, List]  # type: ignore[type-arg]


vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_pretssel")


vocoder_arch = vocoder_archs.decorator


def pretssel_config() -> (
    Tuple[PretsselEncoderFrontendConfig, FFTLayerConfig, PretsselDecoderFrontendConfig]
):
    prosody_encoder_config = ecapa_tdnn_archs.get_config("base")

    encoder_frontend_config = PretsselEncoderFrontendConfig(
        prosody_encoder_config=prosody_encoder_config,
        dropout=0.2,
        lang_embed_dim=64,
    )

    fft_layer_config = FFTLayerConfig(
        attention_heads=2,
        hidden_dim=1024,
        kernel_size=9,
        dropout=0.0,
        conv1d_dropout=0.2,
        use_film=True,
        film_cond_dim=576,
    )

    variance_predictor_config = VariancePredictorConfig(
        var_pred_hidden_dim=512,
        var_pred_kernel_size=5,
        var_pred_dropout=0.5,
        use_film=True,
        film_cond_dim=576,
    )

    decoder_frontend_config = PretsselDecoderFrontendConfig(
        upsampling_type="gaussian",
        variance_predictor_config=variance_predictor_config,
        add_variance_parallel=True,
    )
    return (
        encoder_frontend_config,
        fft_layer_config,
        decoder_frontend_config,
    )


@vocoder_arch("16khz")
def _16khz_vocoder() -> VocoderConfig:
    (
        encoder_frontend_config,
        fft_layer_config,
        decoder_frontend_config,
    ) = pretssel_config()

    return VocoderConfig(
        encoder_frontend_config=encoder_frontend_config,
        fft_layer_config=fft_layer_config,
        decoder_frontend_config=decoder_frontend_config,
        pn_conv_dim=512,
        pn_layers=5,
        pn_conv_kernel_size=5,
        pn_dropout=0.5,
        vocab_info=VocabularyInfo(
            size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
        ),
        model_dim=256,
        max_seq_len=10000,
        encoder_layers=4,
        decoder_layers=4,
        mel_dim=80,
        langs=[],
        upsample_rates=[5, 4, 4, 2],
        upsample_kernel_sizes=[10, 8, 8, 4],
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        channels=1,
        dimension=128,
        n_filters=32,
        ratios=[8, 5, 4, 2],
        norm="weight_norm",
        norm_params={},
        kernel_size=7,
        last_kernel_size=7,
        residual_kernel_size=3,
        causal=False,
        pad_mode="constant",
        true_skip=True,
        compress=2,
        lstm=2,
        disable_norm_outer_blocks=0,
        trim_right_ratio=1.0,
        gcmvn_stats={},
    )


@vocoder_arch("24khz")
def _24khz_vocoder() -> VocoderConfig:
    (
        encoder_frontend_config,
        fft_layer_config,
        decoder_frontend_config,
    ) = pretssel_config()

    return VocoderConfig(
        encoder_frontend_config=encoder_frontend_config,
        fft_layer_config=fft_layer_config,
        decoder_frontend_config=decoder_frontend_config,
        pn_conv_dim=512,
        pn_layers=5,
        pn_conv_kernel_size=5,
        pn_dropout=0.5,
        vocab_info=VocabularyInfo(
            size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
        ),
        model_dim=256,
        max_seq_len=10000,
        encoder_layers=4,
        decoder_layers=4,
        mel_dim=80,
        langs=[],
        upsample_rates=[5, 4, 4, 3],
        upsample_kernel_sizes=[10, 8, 8, 6],
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        channels=1,
        dimension=128,
        n_filters=32,
        ratios=[8, 5, 4, 2],
        norm="weight_norm",
        norm_params={},
        kernel_size=7,
        last_kernel_size=7,
        residual_kernel_size=3,
        causal=False,
        pad_mode="constant",
        true_skip=True,
        compress=2,
        lstm=2,
        disable_norm_outer_blocks=0,
        trim_right_ratio=1.0,
        gcmvn_stats={},
    )


class PretsselVocoderBuilder:
    config: VocoderConfig
    prosody_encoder_builder: EcapaTDNNBuilder
    device: Optional[Device] = None
    dtype: Optional[DataType] = None

    def __init__(
        self,
        config: VocoderConfig,
        prosody_encoder_builder: EcapaTDNNBuilder,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param config:
            The configuration to use.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.config = config
        self.prosody_encoder_builder = prosody_encoder_builder
        self.device, self.dtype = device, dtype

    def build_embed_tokens(self) -> StandardEmbedding:
        """Build a unit embedding table."""

        return StandardEmbedding(
            num_embeddings=self.config.vocab_info.size,
            embedding_dim=self.config.model_dim,
            init_fn=init_scaled_embedding,
            device=self.device,
            dtype=self.dtype,
        )

    def build_fft(self, num_layers: int) -> FeedForwardTransformer:
        """Build a feed-forward Transformer block."""

        layers = [self.build_fft_layer() for _ in range(num_layers)]

        return FeedForwardTransformer(
            layers,
            norm_order=TransformerNormOrder.POST,
            device=self.device,
            dtype=self.dtype,
        )

    def build_fft_layer(self) -> FeedForwardTransformerLayer:
        """Build a feed-forward Transformer layer."""

        self_attn = self.build_attention(self.config.fft_layer_config.attention_heads)

        conv1d = Conv1dBlock(
            self.config.model_dim,
            self.config.fft_layer_config.hidden_dim,
            self.config.fft_layer_config.kernel_size,
            bias=True,
            device=self.device,
            dtype=self.dtype,
        )

        return FeedForwardTransformerLayer(
            self_attn,
            conv1d,
            dropout_p=0.0,  # fairseq1 doesn't have this
            conv1d_dropout_p=self.config.fft_layer_config.conv1d_dropout,
            use_film=self.config.fft_layer_config.use_film,
            film_cond_dim=self.config.fft_layer_config.film_cond_dim,
            device=self.device,
            dtype=self.dtype,
        )

    def build_attention(self, num_heads: int) -> MultiheadAttention:
        """Build a Transformer multi-head attention layer."""

        sdpa = create_default_sdpa(attn_dropout_p=self.config.fft_layer_config.dropout)

        return StandardMultiheadAttention(
            self.config.model_dim,
            num_heads,
            sdpa=sdpa,
            device=self.device,
            dtype=self.dtype,
        )

    def build_variance_adaptor(
        self,
        decoder_frontend_config: PretsselDecoderFrontendConfig,
    ) -> VarianceAdaptor:
        """Build a variance adaptor module."""

        variance_predictor_config = decoder_frontend_config.variance_predictor_config

        pitch_predictor = VariancePredictor(
            self.config.model_dim,
            variance_predictor_config.var_pred_hidden_dim,
            variance_predictor_config.var_pred_kernel_size,
            variance_predictor_config.var_pred_dropout,
            use_film=variance_predictor_config.use_film,
            film_cond_dim=variance_predictor_config.film_cond_dim,
            device=self.device,
            dtype=self.dtype,
        )

        embed_pitch = Conv1d(1, self.config.model_dim, kernel_size=1)

        vuv_predictor = VariancePredictor(
            self.config.model_dim,
            variance_predictor_config.var_pred_hidden_dim,
            variance_predictor_config.var_pred_kernel_size,
            variance_predictor_config.var_pred_dropout,
            use_film=variance_predictor_config.use_film,
            film_cond_dim=variance_predictor_config.film_cond_dim,
            device=self.device,
            dtype=self.dtype,
        )

        energy_predictor = VariancePredictor(
            self.config.model_dim,
            variance_predictor_config.var_pred_hidden_dim,
            variance_predictor_config.var_pred_kernel_size,
            variance_predictor_config.var_pred_dropout,
            use_film=variance_predictor_config.use_film,
            film_cond_dim=variance_predictor_config.film_cond_dim,
            device=self.device,
            dtype=self.dtype,
        )

        embed_energy = Conv1d(1, self.config.model_dim, kernel_size=1)

        variance_adaptor = VarianceAdaptor(
            duration_predictor=None,
            pitch_predictor=pitch_predictor,
            embed_pitch=embed_pitch,
            vuv_predictor=vuv_predictor,
            energy_predictor=energy_predictor,
            embed_energy=embed_energy,
            add_variance_parallel=decoder_frontend_config.add_variance_parallel,
            upsampling_type=decoder_frontend_config.upsampling_type,
        )

        return variance_adaptor

    def build_model(self) -> PretsselVocoder:
        """Build the PRETSSEL vocoder."""
        prosody_encoder = self.prosody_encoder_builder.build_model()
        embed_tokens = self.build_embed_tokens()

        embed_positions = SinusoidalPositionEncoder(
            self.config.model_dim,
            self.config.max_seq_len,
            _legacy_pad_idx=self.config.vocab_info.pad_idx,
            device=self.device,
        )
        lang_to_index = {l: i for i, l in enumerate(self.config.langs)}
        encoder_frontend = PretsselEncoderFrontend(
            prosody_encoder,
            embed_tokens,
            embed_positions,
            lang_to_index,
            lang_embed_dim=self.config.encoder_frontend_config.lang_embed_dim,
            dropout_p=self.config.encoder_frontend_config.dropout,
            device=self.device,
            dtype=self.dtype,
        )

        encoder = self.build_fft(self.config.encoder_layers)

        variance_adaptor = self.build_variance_adaptor(
            self.config.decoder_frontend_config
        )

        decoder_frontend = PretsselDecoderFrontend(
            variance_adaptor,
            embed_positions,
            device=self.device,
            dtype=self.dtype,
        )

        decoder = self.build_fft(self.config.decoder_layers)

        final_proj = Linear(
            self.config.model_dim,
            self.config.mel_dim,
            bias=True,
            device=self.device,
            dtype=self.dtype,
        )

        gcmvn_mean = gcmvn_std = None
        if self.config.gcmvn_stats is not None:
            gcmvn_mean = self.config.gcmvn_stats["mean"]
            gcmvn_std = self.config.gcmvn_stats["std"]

        vocoder = PretsselVocoder(
            encoder_frontend=encoder_frontend,
            encoder=encoder,
            decoder_frontend=decoder_frontend,
            decoder=decoder,
            final_proj=final_proj,
            pn_n_channels=self.config.pn_conv_dim,
            pn_kernel_size=self.config.pn_conv_kernel_size,
            pn_layers=self.config.pn_layers,
            pn_dropout=self.config.pn_dropout,
            upsample_rates=self.config.upsample_rates,
            upsample_kernel_sizes=self.config.upsample_kernel_sizes,
            upsample_initial_channel=self.config.upsample_initial_channel,
            resblock_kernel_sizes=self.config.resblock_kernel_sizes,
            resblock_dilation_sizes=self.config.resblock_dilation_sizes,
            channels=self.config.channels,
            dimension=self.config.dimension,
            n_filters=self.config.n_filters,
            ratios=self.config.ratios,
            norm=self.config.norm,
            norm_params=self.config.norm_params,
            kernel_size=self.config.kernel_size,
            last_kernel_size=self.config.last_kernel_size,
            residual_kernel_size=self.config.residual_kernel_size,
            causal=self.config.causal,
            pad_mode=self.config.pad_mode,
            true_skip=self.config.true_skip,
            compress=self.config.compress,
            lstm=self.config.lstm,
            disable_norm_outer_blocks=self.config.disable_norm_outer_blocks,
            trim_right_ratio=self.config.trim_right_ratio,
            gcmvn_mean=gcmvn_mean,
            gcmvn_std=gcmvn_std,
        )
        vocoder.to(dtype=self.dtype, device=self.device)
        return vocoder


def create_vocoder_model(
    config: VocoderConfig,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> PretsselVocoder:
    prosody_encoder_builder = EcapaTDNNBuilder(
        config.encoder_frontend_config.prosody_encoder_config,
        device=device,
        dtype=dtype,
    )
    return PretsselVocoderBuilder(
        config, prosody_encoder_builder, device=device, dtype=dtype
    ).build_model()
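
Usage sketch (not part of the uploaded file): assuming the module above is importable from an installed seamless_communication package, a vocoder graph could be built from one of the registered architectures roughly as follows. This only constructs randomly initialized modules; pretrained checkpoints are normally loaded separately through fairseq2's model loading utilities.

    # Minimal sketch, not from the repository: builds an untrained 24 kHz PRETSSEL
    # vocoder graph using the registry and factory defined in builder.py above.
    import torch

    from seamless_communication.models.generator.builder import (
        create_vocoder_model,
        vocoder_archs,
    )

    config = vocoder_archs.get_config("24khz")  # registered via @vocoder_arch("24khz")
    vocoder = create_vocoder_model(config, device=torch.device("cpu"), dtype=torch.float32)
    vocoder.eval()  # pretrained weights would still need to be loaded separately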