victan committed
Commit 78232b8 · 1 Parent(s): 44dfa65

Upload seamless_communication/models/generator/vocoder.py with huggingface_hub

seamless_communication/models/generator/vocoder.py ADDED
@@ -0,0 +1,586 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ from typing import Any, Dict, List, Literal, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from fairseq2.nn.embedding import Embedding, StandardEmbedding
+ from fairseq2.nn.padding import PaddingMask
+ from fairseq2.nn.position_encoder import PositionEncoder
+ from fairseq2.nn.projection import Projection
+ from fairseq2.typing import DataType, Device
+ from torch.nn import (
+     ELU,
+     BatchNorm1d,
+     Conv1d,
+     ConvTranspose1d,
+     Dropout,
+     Module,
+     ModuleList,
+     Parameter,
+     Sequential,
+     Tanh,
+     init,
+ )
+ from torch.nn.utils.weight_norm import remove_weight_norm, weight_norm
+
+ from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
+ from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+ from seamless_communication.models.unity.length_regulator import VarianceAdaptor
+ from seamless_communication.models.vocoder.hifigan import (
+     LRELU_SLOPE,
+     ResBlock,
+     init_weights,
+ )
+
+ from .streamable import (
+     StreamableConv1d,
+     StreamableConvTranspose1d,
+     StreamableLSTM,
+     StreamableResnetBlock,
+ )
+
+ ELU_PARAMS: Dict[str, Any] = {"alpha": 1.0}
+
+
+ class PretsselEncoderFrontend(Module):
+     """
+     Represents the encoder frontend, including the prosody encoder and language embedding.
+     """
+
+     prosody_encoder: ECAPA_TDNN
+     embed_tokens: Embedding
+     embed_positions: PositionEncoder
+     pos_emb_alpha: Parameter
+     embed_lang: Embedding
+     dropout: Dropout
+
+     def __init__(
+         self,
+         prosody_encoder: ECAPA_TDNN,
+         embed_tokens: Embedding,
+         embed_positions: PositionEncoder,
+         lang_to_index: Dict[str, int],
+         lang_embed_dim: Optional[int],
+         dropout_p: float,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ):
+         super().__init__()
+
+         self.prosody_encoder = prosody_encoder
+
+         self.embed_tokens = embed_tokens
+
+         self.embed_positions = embed_positions
+         self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+         self.lang_to_index = lang_to_index
+
+         if lang_embed_dim is not None:
+             self.embed_lang = StandardEmbedding(
+                 len(lang_to_index), lang_embed_dim, device=device, dtype=dtype
+             )
+         else:
+             self.register_module("embed_lang", None)
+
+         self.dropout = Dropout(dropout_p)
+
+         self.device = device
+         self.dtype = dtype
+
+     def forward(
+         self,
+         seqs: torch.Tensor,
+         padding_mask: Optional[PaddingMask],
+         prosody_input_seqs: torch.Tensor,
+         prosody_padding_mask: Optional[PaddingMask],
+         tgt_lang: str,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         prosody_embs = self.prosody_encoder(
+             prosody_input_seqs,
+             prosody_padding_mask,
+         ).unsqueeze(1)
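+         # prosody_embs: (N, 1, prosody_dim) -- one utterance-level prosody
+         # vector per batch element from the ECAPA-TDNN reference encoder.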
+
+         if self.embed_lang is not None:
+             lang_index = self.lang_to_index[tgt_lang]
+             lang_index_tensor = (
+                 torch.Tensor([lang_index]).to(seqs).repeat(seqs.size(0), 1)
+             )
+             lang_embeds = self.embed_lang(lang_index_tensor)
+             prosody_embs = torch.cat([prosody_embs, lang_embeds], dim=-1)
+
+         seqs = self.embed_tokens(seqs)
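+         # embed_positions returns seqs with positional encodings added, so the
+         # line below blends in the positional term scaled by the learned alpha.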
+         seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+         seqs = self.dropout(seqs)
+
+         return seqs, prosody_embs
+
+
+ class PretsselDecoderFrontend(Module):
+     """Represents the decoder frontend, including the variance adaptor and positional embedding."""
+
+     variance_adaptor: VarianceAdaptor
+     embed_positions: PositionEncoder
+     pos_emb_alpha: Parameter
+
+     def __init__(
+         self,
+         variance_adaptor: VarianceAdaptor,
+         embed_positions: PositionEncoder,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ):
+         super().__init__()
+
+         self.variance_adaptor = variance_adaptor
+         self.embed_positions = embed_positions
+         self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+         self.device = device
+         self.dtype = dtype
+
+     def forward(
+         self,
+         seqs: torch.Tensor,
+         padding_mask: PaddingMask,
+         durations: Optional[torch.Tensor] = None,
+         duration_factor: float = 1.0,
+         min_duration: int = 0,
+         film_cond_emb: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, PaddingMask]:
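+         # The variance adaptor length-regulates the sequence: each position is
+         # repeated according to its (given or predicted) duration, optionally
+         # conditioned on the prosody embedding via FiLM.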
+         seqs, padding_mask, _ = self.variance_adaptor(
+             seqs, padding_mask, durations, duration_factor, min_duration, film_cond_emb
+         )
+
+         seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+
+         return seqs, padding_mask
+
+
+ class PretsselVocoder(Module):
+     """The expressivity-preserving vocoder"""
+
+     encoder_frontend: PretsselEncoderFrontend
+     encoder: FeedForwardTransformer
+     decoder_frontend: PretsselDecoderFrontend
+     decoder: FeedForwardTransformer
+     final_proj: Projection
+
+     def __init__(  # type: ignore[no-untyped-def]
+         self,
+         encoder_frontend: PretsselEncoderFrontend,
+         encoder: FeedForwardTransformer,
+         decoder_frontend: PretsselDecoderFrontend,
+         decoder: FeedForwardTransformer,
+         final_proj: Projection,
+         pn_n_channels: int,
+         pn_kernel_size: int,
+         pn_layers: int,
+         pn_dropout: float,
+         upsample_rates: List[int],
+         upsample_kernel_sizes: List[int],
+         upsample_initial_channel: int,
+         resblock_kernel_sizes: List[int],
+         resblock_dilation_sizes: List[List[int]],
+         mel_dim: int = 80,
+         add_ups_out_pad: bool = True,
+         channels: int = 1,
+         dimension: int = 128,
+         n_filters: int = 32,
+         ratios: List[int] = [8, 5, 4, 2],
+         norm: Literal[
+             "none", "weight_norm", "spectral_norm", "time_group_norm"
+         ] = "none",
+         norm_params: Dict[str, Any] = {},
+         kernel_size: int = 7,
+         last_kernel_size: int = 7,
+         residual_kernel_size: int = 3,
+         causal: bool = False,
+         pad_mode: str = "constant",
+         true_skip: bool = True,
+         compress: int = 2,
+         lstm: int = 0,
+         disable_norm_outer_blocks: int = 0,
+         trim_right_ratio: float = 1.0,
+         gcmvn_mean: Optional[List[float]] = None,
+         gcmvn_std: Optional[List[float]] = None,
+         device: Optional[Device] = None,
+         dtype: Optional[DataType] = None,
+     ):
+         super().__init__()
+         self.encoder_frontend = encoder_frontend
+         self.encoder = encoder
+         self.decoder_frontend = decoder_frontend
+         self.decoder = decoder
+         self.final_proj = final_proj
+         mult = 1
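+         # Build a SEANet-style streamable encoder/decoder stack first; it is
+         # later spliced into self.layers in four equal chunks and applied as a
+         # refinement branch over the HiFi-GAN output in forward().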
+         stream_layers: List[Module] = [
+             StreamableConv1d(
+                 channels,
+                 mult * n_filters,
+                 kernel_size,
+                 norm="none" if disable_norm_outer_blocks >= 1 else norm,
+                 norm_kwargs=norm_params,
+                 causal=causal,
+                 pad_mode=pad_mode,
+                 activation=Tanh(),
+                 device=device,
+                 dtype=dtype,
+             )
+         ]
+         # Downsample from the raw audio scale
+         for i, ratio in enumerate(list(reversed(ratios))):
+             block_norm = "none" if disable_norm_outer_blocks >= i + 2 else norm
+             stream_layers.append(
+                 StreamableResnetBlock(
+                     mult * n_filters,
+                     kernel_sizes=[residual_kernel_size, 1],
+                     dilations=[1, 1],
+                     norm=block_norm,
+                     norm_params=norm_params,
+                     causal=causal,
+                     pad_mode=pad_mode,
+                     compress=compress,
+                     true_skip=true_skip,
+                     device=device,
+                     dtype=dtype,
+                 )
+             )
+             stream_layers.append(ELU(**ELU_PARAMS))
+             stream_layers.append(
+                 StreamableConv1d(
+                     mult * n_filters,
+                     mult * n_filters * 2,
+                     kernel_size=ratio * 2,
+                     stride=ratio,
+                     norm=block_norm,
+                     norm_kwargs=norm_params,
+                     causal=causal,
+                     pad_mode=pad_mode,
+                     device=device,
+                     dtype=dtype,
+                 )
+             )
+             mult *= 2
+
+         stream_layers.append(StreamableLSTM(mult * n_filters, num_layers=lstm))
+         stream_layers.append(ELU(**ELU_PARAMS))
+         n_blocks = len(ratios) + 2
+         stream_layers.append(
+             StreamableConv1d(
+                 mult * n_filters,
+                 dimension,
+                 last_kernel_size,
+                 norm="none" if disable_norm_outer_blocks == n_blocks else norm,
+                 norm_kwargs=norm_params,
+                 causal=causal,
+                 pad_mode=pad_mode,
+                 device=device,
+                 dtype=dtype,
+             )
+         )
+         stream_layers.append(
+             StreamableConv1d(
+                 dimension,
+                 mult * n_filters,
+                 kernel_size,
+                 norm="none" if disable_norm_outer_blocks == n_blocks else norm,
+                 norm_kwargs=norm_params,
+                 causal=causal,
+                 pad_mode=pad_mode,
+                 device=device,
+                 dtype=dtype,
+             )
+         )
+         stream_layers.append(
+             StreamableLSTM(
+                 mult * n_filters, num_layers=lstm, device=device, dtype=dtype
+             )
+         )
+
+         # Resample back to the raw audio scale
+         for i, ratio in enumerate(ratios):
+             block_norm = (
+                 "none" if disable_norm_outer_blocks >= n_blocks - (i + 1) else norm
+             )
+             stream_layers.append(ELU(**ELU_PARAMS))
+             stream_layers.append(
+                 StreamableConvTranspose1d(
+                     mult * n_filters,
+                     mult * n_filters // 2,
+                     kernel_size=ratio * 2,
+                     stride=ratio,
+                     norm=block_norm,
+                     norm_kwargs=norm_params,
+                     causal=causal,
+                     trim_right_ratio=trim_right_ratio,
+                     device=device,
+                     dtype=dtype,
+                 )
+             )
+             stream_layers.append(
+                 StreamableResnetBlock(
+                     mult * n_filters // 2,
+                     kernel_sizes=[residual_kernel_size, 1],
+                     dilations=[1, 1],
+                     norm=block_norm,
+                     norm_params=norm_params,
+                     activation_params=ELU_PARAMS,
+                     causal=causal,
+                     pad_mode=pad_mode,
+                     compress=compress,
+                     true_skip=true_skip,
+                     device=device,
+                     dtype=dtype,
+                 )
+             )
+             mult //= 2
+
+         stream_layers.append(ELU(**ELU_PARAMS))
+         stream_layers.append(
+             StreamableConv1d(
+                 n_filters,
+                 channels,
+                 last_kernel_size,
+                 norm="none" if disable_norm_outer_blocks >= 1 else norm,
+                 norm_kwargs=norm_params,
+                 causal=causal,
+                 pad_mode=pad_mode,
+                 device=device,
+                 dtype=dtype,
+             )
+         )
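+         # self.layers is assembled as: [PostNet blocks][stream chunk 1]
+         # [mel pre-conv][stream chunk 2][upsamplers][stream chunk 3]
+         # [ResBlocks][stream chunk 4][conv_post]; forward() indexes into
+         # this flat layout explicitly.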
+         self.n_streams = len(stream_layers)
+         chunk_size = self.n_streams // 4
+         stream_idx = 0
+
+         self.pn_layers = pn_layers
+         self.layers = ModuleList()
+         assert pn_kernel_size % 2 == 1
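+         # PostNet (Tacotron2-style): pn_layers of Conv1d + BatchNorm that
+         # predict a residual refinement of the projected mel spectrogram.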
+         for i in range(pn_layers):
+             cur_layers = (
+                 [
+                     Conv1d(
+                         mel_dim if i == 0 else pn_n_channels,
+                         pn_n_channels if i < pn_layers - 1 else mel_dim,
+                         kernel_size=pn_kernel_size,
+                         padding="same",
+                         device=device,
+                         dtype=dtype,
+                     ),
+                     BatchNorm1d(
+                         pn_n_channels if i < pn_layers - 1 else mel_dim,
+                         device=device,
+                         dtype=dtype,
+                     ),
+                 ]
+                 + ([Tanh()] if i < pn_layers - 1 else [])
+                 + [Dropout(pn_dropout)]
+             )
+             self.layers.append(Sequential(*cur_layers))
+         self.reset_parameters()
+         self.layers.extend(stream_layers[:chunk_size])
+         stream_idx += chunk_size
+         self.layers.append(
+             weight_norm(
+                 Conv1d(
+                     mel_dim if mel_dim is not None else 80,
+                     upsample_initial_channel,
+                     7,
+                     1,
+                     padding="same",
+                     device=device,
+                     dtype=dtype,
+                 )
+             )
+         )
+         self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])  # noqa
+         stream_idx += chunk_size
+
+         self.num_kernels = len(resblock_kernel_sizes)
+         self.num_upsamples = len(upsample_rates)
+         ups = ModuleList()
+         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+             out_pad = u % 2 if add_ups_out_pad else 0
+             ups.append(
+                 weight_norm(
+                     ConvTranspose1d(
+                         upsample_initial_channel // (2**i),
+                         upsample_initial_channel // (2 ** (i + 1)),
+                         k,
+                         u,
+                         padding=(k - u) // 2 + out_pad,
+                         output_padding=out_pad,
+                         device=device,
+                         dtype=dtype,
+                     )
+                 )
+             )
+         ups.apply(init_weights)
+         self.layers.extend(ups)
+         self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])  # noqa
+         stream_idx += chunk_size
+
+         for i in range(self.num_upsamples):
+             ch = upsample_initial_channel // (2 ** (i + 1))
+             for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+                 self.layers.append(
+                     ResBlock(
+                         ch,
+                         k,
+                         d,
+                     ).to(device, dtype=dtype)
+                 )
+         self.layers.extend(stream_layers[stream_idx:])
+
+         conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+         conv_post.apply(init_weights)
+         self.layers.append(conv_post)
+         for u, k in zip(upsample_rates, upsample_kernel_sizes):
+             assert k == 2 * u, (k, u)
+
+         mean = torch.zeros((mel_dim,), dtype=torch.float)
+         scale = torch.zeros((mel_dim,), dtype=torch.float)
+         self.register_buffer("mean", mean)
+         self.register_buffer("scale", scale)
+
+         self.gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)
+         self.gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)
+
+     def reset_parameters(self) -> None:
+         for i in range(self.pn_layers):
+             init.xavier_uniform_(
+                 self.layers[i][0].weight,
+                 init.calculate_gain("tanh" if i < self.pn_layers - 1 else "linear"),
+             )
+
+     def gcmvn_denormalize(self, x: torch.Tensor) -> torch.Tensor:
+         if self.gcmvn_mean is None or self.gcmvn_std is None:
+             raise ValueError("gcmvn_mean / gcmvn_std are not set")
+
+         assert (
+             x.ndim == 3
+             and x.shape[2] == self.gcmvn_mean.shape[0]
+             and x.shape[2] == self.gcmvn_std.shape[0]
+         )
+         gcmvn_mean = self.gcmvn_mean.to(x)
+         gcmvn_std = self.gcmvn_std.to(x)
+         x = x * gcmvn_std.view(1, 1, -1).expand_as(x)  # type: ignore[attr-defined]
+         return x + gcmvn_mean.view(1, 1, -1).expand_as(x)  # type: ignore[attr-defined,no-any-return]
+
+     def forward(
+         self,
+         seqs: torch.Tensor,
+         tgt_lang: str,
+         prosody_input_seqs: torch.Tensor,
+         padding_mask: Optional[PaddingMask] = None,
+         prosody_padding_mask: Optional[PaddingMask] = None,
+         durations: Optional[torch.Tensor] = None,
+         duration_factor: float = 1.0,
+         min_duration: int = 0,
+         normalize_before: bool = True,
+     ) -> List[torch.Tensor]:
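+         # Stages: acoustic model (encoder -> variance adaptor -> decoder ->
+         # mel projection), PostNet residual, GCMVN denormalization, then a
+         # per-utterance HiFi-GAN stack plus the streamable refinement branch.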
+         # Add a batch dimension for PRETSSEL if the inputs are unbatched
+         if seqs.ndim < 2:
+             seqs = seqs.unsqueeze(0)
+         if prosody_input_seqs.ndim < 3:
+             prosody_input_seqs = prosody_input_seqs.unsqueeze(0)
+         seqs, cond_embs = self.encoder_frontend(
+             seqs,
+             padding_mask,
+             prosody_input_seqs,
+             prosody_padding_mask,
+             tgt_lang,
+         )
+         seqs, padding_mask = self.encoder(seqs, padding_mask, cond_embs)
+         seqs, padding_mask = self.decoder_frontend(
+             seqs, padding_mask, durations, duration_factor, min_duration, cond_embs
+         )
+         seqs, padding_mask = self.decoder(seqs, padding_mask, cond_embs)
+         seqs = self.final_proj(seqs)
+
+         pn = seqs.transpose(1, 2)  # B x T x C -> B x C x T
+         for i in range(self.pn_layers):
+             pn = self.layers[i](pn)
+         pn = pn.transpose(1, 2)
+
+         x = seqs + pn
+         x = self.gcmvn_denormalize(x)
+
+         wavs = []
+         for idx, _x in enumerate(x):
+             _x = _x[: durations[idx].sum()]  # type: ignore[index]
+             if normalize_before:
+                 _x = (_x - self.mean) / self.scale
+
+             _x = _x.transpose(1, 0).unsqueeze(0)
+             chunk_size = self.n_streams // 4
+             _x = self.layers[self.pn_layers + chunk_size](_x)
+             for i in range(self.num_upsamples):
+                 _x = F.leaky_relu(_x, LRELU_SLOPE)
+                 _x = self.layers[i + self.pn_layers + 1 + 2 * chunk_size](_x)
+                 xs = None
+                 for j in range(self.num_kernels):
+                     if xs is None:
+                         xs = self.layers[
+                             i * self.num_kernels
+                             + j
+                             + self.pn_layers
+                             + 3 * chunk_size
+                             + self.num_upsamples
+                             + 1
+                         ](_x)
+                     else:
+                         xs += self.layers[
+                             i * self.num_kernels
+                             + j
+                             + self.pn_layers
+                             + 3 * chunk_size
+                             + self.num_upsamples
+                             + 1
+                         ](_x)
+                 _x = xs / self.num_kernels  # type: ignore
+             _x = F.leaky_relu(_x)
+             _x = self.layers[
+                 self.pn_layers
+                 + self.n_streams
+                 + self.num_upsamples * (1 + self.num_kernels)
+                 + 1
+             ](_x)
+             skip_output = _x
+             h = skip_output
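+             # Run the four streamable chunks spliced into self.layers over the
+             # HiFi-GAN waveform, stepping the index past the non-streamable
+             # modules (mel pre-conv, upsamplers, ResBlocks) between chunks.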
+
+             for i1 in range(self.pn_layers, self.pn_layers + chunk_size):
+                 h = self.layers[i1](h)
+             i1 += 2
+             for i2 in range(i1, i1 + chunk_size):
+                 h = self.layers[i2](h)
+             i2 = i2 + self.num_upsamples + 1
+
+             for i3 in range(i2, i2 + chunk_size):
+                 h = self.layers[i3](h)
+             i3 = i3 + (self.num_upsamples * self.num_kernels) + 1
+             for i4 in range(i3, i3 + chunk_size):
+                 h = self.layers[i4](h)
+             h = h[:, :, : _x.size(-1)]
+
+             wavs.append(0.8 * h + torch.tanh(skip_output).squeeze(0))
+         return wavs
+
+     def remove_weight_norm(self) -> None:
+         i = self.pn_layers + 1
+         for j in range(self.num_upsamples):
+             remove_weight_norm(self.layers[i + j])
+         for k in range(self.num_upsamples * self.num_kernels):
+             self.layers[i + j + k + 1].remove_weight_norm()
+         remove_weight_norm(self.layers[self.pn_layers])
+         remove_weight_norm(
+             self.layers[
+                 self.pn_layers + 1 + self.num_upsamples * (1 + self.num_kernels)
+             ]
+         )
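
For context, here is a minimal, hypothetical sketch of how this class is typically driven. It is not part of the commit: it assumes a `vocoder: PretsselVocoder` already built or loaded through seamless_communication's model loaders, and uses made-up unit IDs, durations, and prosody features purely as stand-ins.

import torch

vocoder.eval()

units = torch.randint(0, 10000, (1, 50))   # stand-in discrete acoustic units (seqs)
prosody_feats = torch.randn(1, 250, 80)    # stand-in reference-utterance features
durations = torch.full((1, 50), 2)         # stand-in per-unit durations; forward()
                                           # trims each output to durations[idx].sum()

with torch.inference_mode():
    wavs = vocoder(
        units,           # unit IDs, embedded by the encoder frontend
        "eng",           # tgt_lang; must be a key of lang_to_index
        prosody_feats,   # input to the ECAPA-TDNN prosody encoder
        durations=durations,
    )
# wavs: a list with one waveform tensor per batch element.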