victan commited on
Commit
26464f1
·
1 Parent(s): 40b9c9d

Upload seamless_communication/models/generator/builder.py with huggingface_hub

Browse files
seamless_communication/models/generator/builder.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, List, Literal, Optional, Tuple
9
+
10
+ from fairseq2.data import VocabularyInfo
11
+ from fairseq2.models.utils.arch_registry import ArchitectureRegistry
12
+ from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
13
+ from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
14
+ from fairseq2.nn.projection import Linear
15
+ from fairseq2.nn.transformer import (
16
+ MultiheadAttention,
17
+ StandardMultiheadAttention,
18
+ TransformerNormOrder,
19
+ create_default_sdpa,
20
+ )
21
+ from fairseq2.typing import DataType, Device
22
+ from torch.nn import Conv1d
23
+
24
+ from seamless_communication.models.generator.ecapa_tdnn_builder import (
25
+ EcapaTDNNBuilder,
26
+ EcapaTDNNConfig,
27
+ ecapa_tdnn_archs,
28
+ )
29
+ from seamless_communication.models.generator.vocoder import (
30
+ PretsselDecoderFrontend,
31
+ PretsselEncoderFrontend,
32
+ PretsselVocoder,
33
+ )
34
+ from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
35
+ from seamless_communication.models.unity.fft_decoder_layer import (
36
+ Conv1dBlock,
37
+ FeedForwardTransformerLayer,
38
+ )
39
+ from seamless_communication.models.unity.length_regulator import (
40
+ VarianceAdaptor,
41
+ VariancePredictor,
42
+ )
43
+ from seamless_communication.models.unity.t2u_builder import VariancePredictorConfig
44
+
45
+
46
@dataclass
class PretsselEncoderFrontendConfig:
    """Holds the configuration of the PRETSSEL encoder front-end."""

    # Configuration of the ECAPA-TDNN prosody encoder built for the front-end.
    prosody_encoder_config: EcapaTDNNConfig
    # Dropout probability applied in the encoder front-end.
    dropout: float
    # Dimensionality of the language embedding; ``None`` disables it.
    lang_embed_dim: Optional[int] = None
51
+
52
+
53
@dataclass
class FFTLayerConfig:
    """Holds the configuration of a feed-forward Transformer (FFT) layer."""

    # Number of self-attention heads.
    attention_heads: int
    # Inner dimensionality of the Conv1d block.
    hidden_dim: int
    # Kernel size of the Conv1d block.
    kernel_size: int
    # Attention dropout (passed to the SDPA as ``attn_dropout_p``).
    dropout: float
    # Dropout applied after the Conv1d block.
    conv1d_dropout: float
    # Dimensionality of the FiLM conditioning input.
    film_cond_dim: int
    # Whether the layer applies FiLM conditioning.
    use_film: bool = False
62
+
63
+
64
@dataclass
class PretsselDecoderFrontendConfig:
    """Holds the configuration of the PRETSSEL decoder front-end."""

    # How predicted durations are turned into upsampled frames.
    upsampling_type: Literal["gaussian", "hard"]
    # Configuration shared by the pitch/vuv/energy variance predictors.
    variance_predictor_config: VariancePredictorConfig
    # Whether variance embeddings are added in parallel (vs. sequentially).
    add_variance_parallel: bool
69
+
70
+
71
@dataclass
class VocoderConfig:
    """Holds the configuration of a Vocoder model."""

    # Sub-module configurations.
    encoder_frontend_config: PretsselEncoderFrontendConfig
    fft_layer_config: FFTLayerConfig
    decoder_frontend_config: PretsselDecoderFrontendConfig
    # Post-net: channel dim, depth, kernel size, and dropout.
    pn_conv_dim: int
    pn_layers: int
    pn_conv_kernel_size: int
    pn_dropout: float
    # Vocabulary of the input units (size, pad/bos/eos/unk indices).
    vocab_info: VocabularyInfo
    # Dimensionality of the model.
    model_dim: int
    # Maximum sequence length for the sinusoidal position encoder.
    max_seq_len: int
    # Number of FFT layers in the encoder and decoder stacks.
    encoder_layers: int
    decoder_layers: int
    # Dimensionality of the output mel spectrogram.
    mel_dim: int
    # Supported languages; mapped to indices for the language embedding.
    langs: List  # type: ignore[type-arg]
    # Waveform upsampler settings.
    upsample_rates: List[int]
    upsample_kernel_sizes: List[int]
    upsample_initial_channel: int
    # Residual block kernel sizes and dilations.
    resblock_kernel_sizes: List[int]
    resblock_dilation_sizes: List[List[int]]
    channels: int
    dimension: int
    n_filters: int
    ratios: List[int]
    # Normalization variant and its keyword parameters.
    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"]
    norm_params: Dict[str, Any]
    kernel_size: int
    last_kernel_size: int
    residual_kernel_size: int
    causal: bool
    pad_mode: str
    true_skip: bool
    compress: int
    lstm: int
    disable_norm_outer_blocks: int
    trim_right_ratio: float
    # Global CMVN statistics keyed by "mean"/"std"; may be empty.
    gcmvn_stats: Dict[str, List]  # type: ignore[type-arg]
111
+
112
+
113
# Registry of named architectures for the "vocoder_pretssel" model family.
vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_pretssel")


# Decorator used below to register a config factory under an architecture name.
vocoder_arch = vocoder_archs.decorator
117
+
118
+
119
def pretssel_config() -> (
    Tuple[PretsselEncoderFrontendConfig, FFTLayerConfig, PretsselDecoderFrontendConfig]
):
    """Build the sub-configurations shared by the PRETSSEL vocoder archs.

    :returns:
        A tuple of (encoder front-end config, FFT layer config,
        decoder front-end config).
    """
    frontend_cfg = PretsselEncoderFrontendConfig(
        prosody_encoder_config=ecapa_tdnn_archs.get_config("base"),
        dropout=0.2,
        lang_embed_dim=64,
    )

    fft_cfg = FFTLayerConfig(
        attention_heads=2,
        hidden_dim=1024,
        kernel_size=9,
        dropout=0.0,
        conv1d_dropout=0.2,
        use_film=True,
        film_cond_dim=576,
    )

    # One predictor config is shared by the pitch/vuv/energy predictors.
    var_pred_cfg = VariancePredictorConfig(
        var_pred_hidden_dim=512,
        var_pred_kernel_size=5,
        var_pred_dropout=0.5,
        use_film=True,
        film_cond_dim=576,
    )

    decoder_cfg = PretsselDecoderFrontendConfig(
        upsampling_type="gaussian",
        variance_predictor_config=var_pred_cfg,
        add_variance_parallel=True,
    )

    return frontend_cfg, fft_cfg, decoder_cfg
158
+
159
+
160
@vocoder_arch("16khz")
def _16khz_vocoder() -> VocoderConfig:
    """Configuration factory for the 16 kHz PRETSSEL vocoder architecture."""
    encoder_frontend_config, fft_layer_config, decoder_frontend_config = (
        pretssel_config()
    )

    unit_vocab_info = VocabularyInfo(
        size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
    )

    return VocoderConfig(
        encoder_frontend_config=encoder_frontend_config,
        fft_layer_config=fft_layer_config,
        decoder_frontend_config=decoder_frontend_config,
        # Post-net settings.
        pn_conv_dim=512,
        pn_layers=5,
        pn_conv_kernel_size=5,
        pn_dropout=0.5,
        vocab_info=unit_vocab_info,
        model_dim=256,
        max_seq_len=10000,
        encoder_layers=4,
        decoder_layers=4,
        mel_dim=80,
        langs=[],
        # Waveform upsampler settings.
        upsample_rates=[5, 4, 4, 2],
        upsample_kernel_sizes=[10, 8, 8, 4],
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        channels=1,
        dimension=128,
        n_filters=32,
        ratios=[8, 5, 4, 2],
        norm="weight_norm",
        norm_params={},
        kernel_size=7,
        last_kernel_size=7,
        residual_kernel_size=3,
        causal=False,
        pad_mode="constant",
        true_skip=True,
        compress=2,
        lstm=2,
        disable_norm_outer_blocks=0,
        trim_right_ratio=1.0,
        gcmvn_stats={},
    )
208
+
209
+
210
@vocoder_arch("24khz")
def _24khz_vocoder() -> VocoderConfig:
    """Configuration factory for the 24 kHz PRETSSEL vocoder architecture."""
    encoder_frontend_config, fft_layer_config, decoder_frontend_config = (
        pretssel_config()
    )

    unit_vocab_info = VocabularyInfo(
        size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
    )

    return VocoderConfig(
        encoder_frontend_config=encoder_frontend_config,
        fft_layer_config=fft_layer_config,
        decoder_frontend_config=decoder_frontend_config,
        # Post-net settings.
        pn_conv_dim=512,
        pn_layers=5,
        pn_conv_kernel_size=5,
        pn_dropout=0.5,
        vocab_info=unit_vocab_info,
        model_dim=256,
        max_seq_len=10000,
        encoder_layers=4,
        decoder_layers=4,
        mel_dim=80,
        langs=[],
        # Waveform upsampler settings (differ from "16khz" in the last stage).
        upsample_rates=[5, 4, 4, 3],
        upsample_kernel_sizes=[10, 8, 8, 6],
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        channels=1,
        dimension=128,
        n_filters=32,
        ratios=[8, 5, 4, 2],
        norm="weight_norm",
        norm_params={},
        kernel_size=7,
        last_kernel_size=7,
        residual_kernel_size=3,
        causal=False,
        pad_mode="constant",
        true_skip=True,
        compress=2,
        lstm=2,
        disable_norm_outer_blocks=0,
        trim_right_ratio=1.0,
        gcmvn_stats={},
    )
258
+
259
+
260
class PretsselVocoderBuilder:
    """Builds the modules of a PRETSSEL vocoder from a :class:`VocoderConfig`.

    Individual ``build_*`` methods can be overridden by subclasses to
    customize sub-modules; :meth:`build_model` assembles the full vocoder.
    """

    config: VocoderConfig
    prosody_encoder_builder: EcapaTDNNBuilder
    device: Optional[Device] = None
    dtype: Optional[DataType] = None

    def __init__(
        self,
        config: VocoderConfig,
        prosody_encoder_builder: EcapaTDNNBuilder,
        *,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ) -> None:
        """
        :param config:
            The configuration to use.
        :param prosody_encoder_builder:
            Builder for the ECAPA-TDNN prosody encoder used by the encoder
            front-end.
        :param device:
            The device on which to initialize modules.
        :param dtype:
            The data type of module parameters and buffers.
        """
        self.config = config
        self.prosody_encoder_builder = prosody_encoder_builder
        self.device, self.dtype = device, dtype

    def build_embed_tokens(self) -> StandardEmbedding:
        """Build a unit embedding table."""

        return StandardEmbedding(
            num_embeddings=self.config.vocab_info.size,
            embedding_dim=self.config.model_dim,
            init_fn=init_scaled_embedding,
            device=self.device,
            dtype=self.dtype,
        )

    def build_fft(self, num_layers: int) -> FeedForwardTransformer:
        """Build a feed-forward Transformer stack with ``num_layers`` layers."""

        layers = [self.build_fft_layer() for _ in range(num_layers)]

        return FeedForwardTransformer(
            layers,
            norm_order=TransformerNormOrder.POST,
            device=self.device,
            dtype=self.dtype,
        )

    def build_fft_layer(self) -> FeedForwardTransformerLayer:
        """Build a single feed-forward Transformer layer."""

        self_attn = self.build_attention(self.config.fft_layer_config.attention_heads)

        conv1d = Conv1dBlock(
            self.config.model_dim,
            self.config.fft_layer_config.hidden_dim,
            self.config.fft_layer_config.kernel_size,
            bias=True,
            device=self.device,
            dtype=self.dtype,
        )

        return FeedForwardTransformerLayer(
            self_attn,
            conv1d,
            dropout_p=0.0,  # fairseq1 doesn't have this
            conv1d_dropout_p=self.config.fft_layer_config.conv1d_dropout,
            use_film=self.config.fft_layer_config.use_film,
            film_cond_dim=self.config.fft_layer_config.film_cond_dim,
            device=self.device,
            dtype=self.dtype,
        )

    def build_attention(self, num_heads: int) -> MultiheadAttention:
        """Build a Transformer multi-head attention layer.

        :param num_heads:
            The number of attention heads.
        """

        sdpa = create_default_sdpa(attn_dropout_p=self.config.fft_layer_config.dropout)

        return StandardMultiheadAttention(
            self.config.model_dim,
            num_heads,
            sdpa=sdpa,
            device=self.device,
            dtype=self.dtype,
        )

    def _build_variance_predictor(
        self, variance_predictor_config: VariancePredictorConfig
    ) -> VariancePredictor:
        """Build one variance predictor (identical shape for pitch/vuv/energy)."""

        return VariancePredictor(
            self.config.model_dim,
            variance_predictor_config.var_pred_hidden_dim,
            variance_predictor_config.var_pred_kernel_size,
            variance_predictor_config.var_pred_dropout,
            use_film=variance_predictor_config.use_film,
            film_cond_dim=variance_predictor_config.film_cond_dim,
            device=self.device,
            dtype=self.dtype,
        )

    def build_variance_adaptor(
        self,
        decoder_frontend_config: PretsselDecoderFrontendConfig,
    ) -> VarianceAdaptor:
        """Build a variance adaptor module with pitch/vuv/energy predictors."""

        variance_predictor_config = decoder_frontend_config.variance_predictor_config

        # NOTE: module creation order is kept as in the original code so that
        # RNG-dependent parameter initialization is unchanged.
        pitch_predictor = self._build_variance_predictor(variance_predictor_config)

        # 1x1 convolution projecting the scalar pitch track to model_dim.
        # Created without device/dtype; `vocoder.to(...)` in `build_model`
        # moves it along with the rest of the model.
        embed_pitch = Conv1d(1, self.config.model_dim, kernel_size=1)

        vuv_predictor = self._build_variance_predictor(variance_predictor_config)

        energy_predictor = self._build_variance_predictor(variance_predictor_config)

        # 1x1 convolution projecting the scalar energy track to model_dim.
        embed_energy = Conv1d(1, self.config.model_dim, kernel_size=1)

        variance_adaptor = VarianceAdaptor(
            duration_predictor=None,
            pitch_predictor=pitch_predictor,
            embed_pitch=embed_pitch,
            vuv_predictor=vuv_predictor,
            energy_predictor=energy_predictor,
            embed_energy=embed_energy,
            add_variance_parallel=decoder_frontend_config.add_variance_parallel,
            upsampling_type=decoder_frontend_config.upsampling_type,
        )

        return variance_adaptor

    def build_model(self) -> PretsselVocoder:
        """Build the PRETSSEL vocoder."""
        prosody_encoder = self.prosody_encoder_builder.build_model()
        embed_tokens = self.build_embed_tokens()

        embed_positions = SinusoidalPositionEncoder(
            self.config.model_dim,
            self.config.max_seq_len,
            _legacy_pad_idx=self.config.vocab_info.pad_idx,
            device=self.device,
        )
        lang_to_index = {l: i for i, l in enumerate(self.config.langs)}
        encoder_frontend = PretsselEncoderFrontend(
            prosody_encoder,
            embed_tokens,
            embed_positions,
            lang_to_index,
            lang_embed_dim=self.config.encoder_frontend_config.lang_embed_dim,
            dropout_p=self.config.encoder_frontend_config.dropout,
            device=self.device,
            dtype=self.dtype,
        )

        encoder = self.build_fft(self.config.encoder_layers)

        variance_adaptor = self.build_variance_adaptor(
            self.config.decoder_frontend_config
        )

        # The position encoder instance is shared with the encoder front-end.
        decoder_frontend = PretsselDecoderFrontend(
            variance_adaptor,
            embed_positions,
            device=self.device,
            dtype=self.dtype,
        )

        decoder = self.build_fft(self.config.decoder_layers)

        final_proj = Linear(
            self.config.model_dim,
            self.config.mel_dim,
            bias=True,
            device=self.device,
            dtype=self.dtype,
        )

        gcmvn_mean = gcmvn_std = None
        # FIX: the original `is not None` check was always true (the field is a
        # non-Optional Dict), so the `gcmvn_stats={}` set by the registered
        # archs would raise KeyError on the "mean" lookup. A truthiness check
        # treats an empty dict as "no GCMVN stats".
        if self.config.gcmvn_stats:
            gcmvn_mean = self.config.gcmvn_stats["mean"]
            gcmvn_std = self.config.gcmvn_stats["std"]

        vocoder = PretsselVocoder(
            encoder_frontend=encoder_frontend,
            encoder=encoder,
            decoder_frontend=decoder_frontend,
            decoder=decoder,
            final_proj=final_proj,
            pn_n_channels=self.config.pn_conv_dim,
            pn_kernel_size=self.config.pn_conv_kernel_size,
            pn_layers=self.config.pn_layers,
            pn_dropout=self.config.pn_dropout,
            upsample_rates=self.config.upsample_rates,
            upsample_kernel_sizes=self.config.upsample_kernel_sizes,
            upsample_initial_channel=self.config.upsample_initial_channel,
            resblock_kernel_sizes=self.config.resblock_kernel_sizes,
            resblock_dilation_sizes=self.config.resblock_dilation_sizes,
            channels=self.config.channels,
            dimension=self.config.dimension,
            n_filters=self.config.n_filters,
            ratios=self.config.ratios,
            norm=self.config.norm,
            norm_params=self.config.norm_params,
            kernel_size=self.config.kernel_size,
            last_kernel_size=self.config.last_kernel_size,
            residual_kernel_size=self.config.residual_kernel_size,
            causal=self.config.causal,
            pad_mode=self.config.pad_mode,
            true_skip=self.config.true_skip,
            compress=self.config.compress,
            lstm=self.config.lstm,
            disable_norm_outer_blocks=self.config.disable_norm_outer_blocks,
            trim_right_ratio=self.config.trim_right_ratio,
            gcmvn_mean=gcmvn_mean,
            gcmvn_std=gcmvn_std,
        )
        # Move everything (including the plain Conv1d pitch/energy embeddings)
        # to the target dtype/device in one pass.
        vocoder.to(dtype=self.dtype, device=self.device)
        return vocoder
492
+
493
+
494
def create_vocoder_model(
    config: VocoderConfig,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> PretsselVocoder:
    """Create a PRETSSEL vocoder model.

    :param config:
        The vocoder configuration to use.
    :param device:
        The device on which to initialize modules.
    :param dtype:
        The data type of module parameters and buffers.
    """
    prosody_builder = EcapaTDNNBuilder(
        config.encoder_frontend_config.prosody_encoder_config,
        device=device,
        dtype=dtype,
    )

    vocoder_builder = PretsselVocoderBuilder(
        config, prosody_builder, device=device, dtype=dtype
    )

    return vocoder_builder.build_model()