m4singer committed on
Commit
d2fa653
·
1 Parent(s): feddfa7
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. checkpoints/m4singer_diff_e2e/config.yaml +348 -0
  2. checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt +3 -0
  3. checkpoints/m4singer_fs2_e2e/config.yaml +347 -0
  4. checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt +3 -0
  5. checkpoints/m4singer_hifigan/config.yaml +246 -0
  6. checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt +3 -0
  7. checkpoints/m4singer_pe/config.yaml +172 -0
  8. checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt +3 -0
  9. configs/config_base.yaml +42 -0
  10. configs/singing/base.yaml +42 -0
  11. configs/singing/fs2.yaml +3 -0
  12. configs/tts/base.yaml +95 -0
  13. configs/tts/base_zh.yaml +3 -0
  14. configs/tts/fs2.yaml +80 -0
  15. configs/tts/hifigan.yaml +21 -0
  16. configs/tts/lj/base_mel2wav.yaml +3 -0
  17. configs/tts/lj/base_text2mel.yaml +13 -0
  18. configs/tts/lj/fs2.yaml +3 -0
  19. configs/tts/lj/hifigan.yaml +3 -0
  20. configs/tts/lj/pwg.yaml +3 -0
  21. configs/tts/pwg.yaml +110 -0
  22. data_gen/singing/binarize.py +393 -0
  23. data_gen/tts/base_binarizer.py +224 -0
  24. data_gen/tts/bin/binarize.py +20 -0
  25. data_gen/tts/binarizer_zh.py +59 -0
  26. data_gen/tts/data_gen_utils.py +347 -0
  27. data_gen/tts/txt_processors/base_text_processor.py +8 -0
  28. data_gen/tts/txt_processors/en.py +78 -0
  29. data_gen/tts/txt_processors/zh.py +41 -0
  30. data_gen/tts/txt_processors/zh_g2pM.py +71 -0
  31. inference/m4singer/base_svs_infer.py +242 -0
  32. inference/m4singer/ds_e2e.py +67 -0
  33. inference/m4singer/gradio/gradio_settings.yaml +31 -0
  34. inference/m4singer/gradio/infer.py +104 -0
  35. inference/m4singer/m4singer/m4singer_pinyin2ph.txt +413 -0
  36. inference/m4singer/m4singer/map.py +7 -0
  37. modules/__init__.py +0 -0
  38. modules/commons/common_layers.py +668 -0
  39. modules/commons/espnet_positional_embedding.py +113 -0
  40. modules/commons/ssim.py +391 -0
  41. modules/diffsinger_midi/fs2.py +118 -0
  42. modules/fastspeech/fs2.py +255 -0
  43. modules/fastspeech/pe.py +149 -0
  44. modules/fastspeech/tts_modules.py +357 -0
  45. modules/hifigan/hifigan.py +370 -0
  46. modules/hifigan/mel_utils.py +81 -0
  47. modules/parallel_wavegan/__init__.py +0 -0
  48. modules/parallel_wavegan/layers/__init__.py +5 -0
  49. modules/parallel_wavegan/layers/causal_conv.py +56 -0
  50. modules/parallel_wavegan/layers/pqmf.py +129 -0
checkpoints/m4singer_diff_e2e/config.yaml ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ K_step: 1000
2
+ accumulate_grad_batches: 1
3
+ audio_num_mel_bins: 80
4
+ audio_sample_rate: 24000
5
+ base_config:
6
+ - usr/configs/m4singer/base.yaml
7
+ binarization_args:
8
+ shuffle: false
9
+ with_align: true
10
+ with_f0: true
11
+ with_f0cwt: true
12
+ with_spk_embed: true
13
+ with_txt: true
14
+ with_wav: false
15
+ binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer
16
+ binary_data_dir: data/binary/m4singer
17
+ check_val_every_n_epoch: 10
18
+ clip_grad_norm: 1
19
+ content_cond_steps: []
20
+ cwt_add_f0_loss: false
21
+ cwt_hidden_size: 128
22
+ cwt_layers: 2
23
+ cwt_loss: l1
24
+ cwt_std_scale: 0.8
25
+ datasets:
26
+ - m4singer
27
+ debug: false
28
+ dec_ffn_kernel_size: 9
29
+ dec_layers: 4
30
+ decay_steps: 100000
31
+ decoder_type: fft
32
+ dict_dir: ''
33
+ diff_decoder_type: wavenet
34
+ diff_loss_type: l1
35
+ dilation_cycle_length: 4
36
+ dropout: 0.1
37
+ ds_workers: 4
38
+ dur_enc_hidden_stride_kernel:
39
+ - 0,2,3
40
+ - 0,2,3
41
+ - 0,1,3
42
+ dur_loss: mse
43
+ dur_predictor_kernel: 3
44
+ dur_predictor_layers: 5
45
+ enc_ffn_kernel_size: 9
46
+ enc_layers: 4
47
+ encoder_K: 8
48
+ encoder_type: fft
49
+ endless_ds: true
50
+ ffn_act: gelu
51
+ ffn_padding: SAME
52
+ fft_size: 512
53
+ fmax: 12000
54
+ fmin: 30
55
+ fs2_ckpt: checkpoints/m4singer_fs2_e2e
56
+ gaussian_start: true
57
+ gen_dir_name: ''
58
+ gen_tgt_spk_id: -1
59
+ hidden_size: 256
60
+ hop_size: 128
61
+ infer: false
62
+ keep_bins: 80
63
+ lambda_commit: 0.25
64
+ lambda_energy: 0.0
65
+ lambda_f0: 0.0
66
+ lambda_ph_dur: 1.0
67
+ lambda_sent_dur: 1.0
68
+ lambda_uv: 0.0
69
+ lambda_word_dur: 1.0
70
+ load_ckpt: ''
71
+ log_interval: 100
72
+ loud_norm: false
73
+ lr: 0.001
74
+ max_beta: 0.02
75
+ max_epochs: 1000
76
+ max_eval_sentences: 1
77
+ max_eval_tokens: 60000
78
+ max_frames: 5000
79
+ max_input_tokens: 1550
80
+ max_sentences: 28
81
+ max_tokens: 36000
82
+ max_updates: 900000
83
+ mel_loss: ssim:0.5|l1:0.5
84
+ mel_vmax: 1.5
85
+ mel_vmin: -6.0
86
+ min_level_db: -120
87
+ norm_type: gn
88
+ num_ckpt_keep: 3
89
+ num_heads: 2
90
+ num_sanity_val_steps: 1
91
+ num_spk: 20
92
+ num_test_samples: 0
93
+ num_valid_plots: 10
94
+ optimizer_adam_beta1: 0.9
95
+ optimizer_adam_beta2: 0.98
96
+ out_wav_norm: false
97
+ pe_ckpt: checkpoints/m4singer_pe
98
+ pe_enable: true
99
+ pitch_ar: false
100
+ pitch_enc_hidden_stride_kernel:
101
+ - 0,2,5
102
+ - 0,2,5
103
+ - 0,2,5
104
+ pitch_extractor: parselmouth
105
+ pitch_loss: l1
106
+ pitch_norm: log
107
+ pitch_type: frame
108
+ pndm_speedup: 20
109
+ pre_align_args:
110
+ allow_no_txt: false
111
+ denoise: false
112
+ forced_align: mfa
113
+ txt_processor: zh_g2pM
114
+ use_sox: true
115
+ use_tone: false
116
+ pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
117
+ predictor_dropout: 0.5
118
+ predictor_grad: 0.1
119
+ predictor_hidden: -1
120
+ predictor_kernel: 5
121
+ predictor_layers: 5
122
+ prenet_dropout: 0.5
123
+ prenet_hidden_size: 256
124
+ pretrain_fs_ckpt: ''
125
+ processed_data_dir: xxx
126
+ profile_infer: false
127
+ raw_data_dir: data/raw/m4singer
128
+ ref_norm_layer: bn
129
+ rel_pos: true
130
+ reset_phone_dict: true
131
+ residual_channels: 256
132
+ residual_layers: 20
133
+ save_best: false
134
+ save_ckpt: true
135
+ save_codes:
136
+ - configs
137
+ - modules
138
+ - tasks
139
+ - utils
140
+ - usr
141
+ save_f0: true
142
+ save_gt: true
143
+ schedule_type: linear
144
+ seed: 1234
145
+ sort_by_len: true
146
+ spec_max:
147
+ - -0.3894500136375427
148
+ - -0.3796464204788208
149
+ - -0.2914905250072479
150
+ - -0.15550297498703003
151
+ - -0.08502643555402756
152
+ - 0.10698417574167252
153
+ - -0.0739326998591423
154
+ - -0.0541548952460289
155
+ - 0.15501998364925385
156
+ - 0.06483431905508041
157
+ - 0.03054228238761425
158
+ - -0.013737732544541359
159
+ - -0.004876468330621719
160
+ - 0.04368264228105545
161
+ - 0.13329921662807465
162
+ - 0.16471388936042786
163
+ - 0.04605761915445328
164
+ - -0.05680707097053528
165
+ - 0.0542571023106575
166
+ - -0.0076539707370102406
167
+ - -0.00953489076346159
168
+ - -0.04434828832745552
169
+ - 0.001293870504014194
170
+ - -0.12238839268684387
171
+ - 0.06418416649103165
172
+ - 0.02843189612030983
173
+ - 0.08505241572856903
174
+ - 0.07062800228595734
175
+ - 0.00120724702719599
176
+ - -0.07675088942050934
177
+ - 0.03785804659128189
178
+ - 0.04890783503651619
179
+ - -0.06888376921415329
180
+ - -0.0839693546295166
181
+ - -0.17545585334300995
182
+ - -0.2911079525947571
183
+ - -0.4238220453262329
184
+ - -0.262084037065506
185
+ - -0.3002263605594635
186
+ - -0.3845032751560211
187
+ - -0.3906497061252594
188
+ - -0.6550108790397644
189
+ - -0.7810799479484558
190
+ - -0.7503029704093933
191
+ - -0.7995198965072632
192
+ - -0.8092347383499146
193
+ - -0.6196113228797913
194
+ - -0.6684317588806152
195
+ - -0.7735874056816101
196
+ - -0.8324533104896545
197
+ - -0.9601566791534424
198
+ - -0.955253541469574
199
+ - -0.748817503452301
200
+ - -0.9106167554855347
201
+ - -0.9707801342010498
202
+ - -1.053107500076294
203
+ - -1.0448424816131592
204
+ - -1.1082794666290283
205
+ - -1.1296544075012207
206
+ - -1.071642279624939
207
+ - -1.1003081798553467
208
+ - -1.166810154914856
209
+ - -1.1408926248550415
210
+ - -1.1330615282058716
211
+ - -1.1167492866516113
212
+ - -1.0716774463653564
213
+ - -1.035891056060791
214
+ - -1.0092483758926392
215
+ - -0.9675999879837036
216
+ - -0.938962996006012
217
+ - -1.0120564699172974
218
+ - -0.9777995347976685
219
+ - -1.029313564300537
220
+ - -0.9459163546562195
221
+ - -0.8519706130027771
222
+ - -0.7751091122627258
223
+ - -0.7933766841888428
224
+ - -0.9019735455513
225
+ - -0.9983296990394592
226
+ - -1.505873441696167
227
+ spec_min:
228
+ - -6.0
229
+ - -6.0
230
+ - -6.0
231
+ - -6.0
232
+ - -6.0
233
+ - -6.0
234
+ - -6.0
235
+ - -6.0
236
+ - -6.0
237
+ - -6.0
238
+ - -6.0
239
+ - -6.0
240
+ - -6.0
241
+ - -6.0
242
+ - -6.0
243
+ - -6.0
244
+ - -6.0
245
+ - -6.0
246
+ - -6.0
247
+ - -6.0
248
+ - -6.0
249
+ - -6.0
250
+ - -6.0
251
+ - -6.0
252
+ - -6.0
253
+ - -6.0
254
+ - -6.0
255
+ - -6.0
256
+ - -6.0
257
+ - -6.0
258
+ - -6.0
259
+ - -6.0
260
+ - -6.0
261
+ - -6.0
262
+ - -6.0
263
+ - -6.0
264
+ - -6.0
265
+ - -6.0
266
+ - -6.0
267
+ - -6.0
268
+ - -6.0
269
+ - -6.0
270
+ - -6.0
271
+ - -6.0
272
+ - -6.0
273
+ - -6.0
274
+ - -6.0
275
+ - -6.0
276
+ - -6.0
277
+ - -6.0
278
+ - -6.0
279
+ - -6.0
280
+ - -6.0
281
+ - -6.0
282
+ - -6.0
283
+ - -6.0
284
+ - -6.0
285
+ - -6.0
286
+ - -6.0
287
+ - -6.0
288
+ - -6.0
289
+ - -6.0
290
+ - -6.0
291
+ - -6.0
292
+ - -6.0
293
+ - -6.0
294
+ - -6.0
295
+ - -6.0
296
+ - -6.0
297
+ - -6.0
298
+ - -6.0
299
+ - -6.0
300
+ - -6.0
301
+ - -6.0
302
+ - -6.0
303
+ - -6.0
304
+ - -6.0
305
+ - -6.0
306
+ - -6.0
307
+ - -6.0
308
+ spk_cond_steps: []
309
+ stop_token_weight: 5.0
310
+ task_cls: usr.diffsinger_task.DiffSingerMIDITask
311
+ test_ids: []
312
+ test_input_dir: ''
313
+ test_num: 0
314
+ test_prefixes:
315
+ - "Alto-2#\u5C81\u6708\u795E\u5077"
316
+ - "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C"
317
+ - "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E"
318
+ - "Tenor-1#\u7AE5\u8BDD"
319
+ - "Tenor-2#\u6D88\u6101"
320
+ - "Tenor-2#\u4E00\u8364\u4E00\u7D20"
321
+ - "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4"
322
+ - "Soprano-1#\u95EE\u6625"
323
+ test_set_name: test
324
+ timesteps: 1000
325
+ train_set_name: train
326
+ use_denoise: false
327
+ use_energy_embed: false
328
+ use_gt_dur: false
329
+ use_gt_f0: false
330
+ use_midi: true
331
+ use_nsf: true
332
+ use_pitch_embed: false
333
+ use_pos_embed: true
334
+ use_spk_embed: false
335
+ use_spk_id: true
336
+ use_split_spk_id: false
337
+ use_uv: true
338
+ use_var_enc: false
339
+ val_check_interval: 2000
340
+ valid_num: 0
341
+ valid_set_name: valid
342
+ vocoder: vocoders.hifigan.HifiGAN
343
+ vocoder_ckpt: checkpoints/m4singer_hifigan
344
+ warmup_updates: 2000
345
+ wav2spec_eps: 1e-6
346
+ weight_decay: 0
347
+ win_size: 512
348
+ work_dir: checkpoints/m4singer_diff_e2e
checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbea4e8b9712d2cca54cc07915859472a17f2f3b97a86f33a6c9974192bb5b47
3
+ size 392239086
checkpoints/m4singer_fs2_e2e/config.yaml ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ K_step: 51
2
+ accumulate_grad_batches: 1
3
+ audio_num_mel_bins: 80
4
+ audio_sample_rate: 24000
5
+ base_config:
6
+ - configs/singing/fs2.yaml
7
+ - usr/configs/m4singer/base.yaml
8
+ binarization_args:
9
+ shuffle: false
10
+ with_align: true
11
+ with_f0: true
12
+ with_f0cwt: true
13
+ with_spk_embed: true
14
+ with_txt: true
15
+ with_wav: false
16
+ binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer
17
+ binary_data_dir: data/binary/m4singer
18
+ check_val_every_n_epoch: 10
19
+ clip_grad_norm: 1
20
+ content_cond_steps: []
21
+ cwt_add_f0_loss: false
22
+ cwt_hidden_size: 128
23
+ cwt_layers: 2
24
+ cwt_loss: l1
25
+ cwt_std_scale: 0.8
26
+ datasets:
27
+ - m4singer
28
+ debug: false
29
+ dec_ffn_kernel_size: 9
30
+ dec_layers: 4
31
+ decay_steps: 50000
32
+ decoder_type: fft
33
+ dict_dir: ''
34
+ diff_decoder_type: wavenet
35
+ diff_loss_type: l1
36
+ dilation_cycle_length: 1
37
+ dropout: 0.1
38
+ ds_workers: 4
39
+ dur_enc_hidden_stride_kernel:
40
+ - 0,2,3
41
+ - 0,2,3
42
+ - 0,1,3
43
+ dur_loss: mse
44
+ dur_predictor_kernel: 3
45
+ dur_predictor_layers: 5
46
+ enc_ffn_kernel_size: 9
47
+ enc_layers: 4
48
+ encoder_K: 8
49
+ encoder_type: fft
50
+ endless_ds: true
51
+ ffn_act: gelu
52
+ ffn_padding: SAME
53
+ fft_size: 512
54
+ fmax: 12000
55
+ fmin: 30
56
+ fs2_ckpt: ''
57
+ gen_dir_name: ''
58
+ gen_tgt_spk_id: -1
59
+ hidden_size: 256
60
+ hop_size: 128
61
+ infer: false
62
+ keep_bins: 80
63
+ lambda_commit: 0.25
64
+ lambda_energy: 0.0
65
+ lambda_f0: 1.0
66
+ lambda_ph_dur: 1.0
67
+ lambda_sent_dur: 1.0
68
+ lambda_uv: 1.0
69
+ lambda_word_dur: 1.0
70
+ load_ckpt: ''
71
+ log_interval: 100
72
+ loud_norm: false
73
+ lr: 1
74
+ max_beta: 0.06
75
+ max_epochs: 1000
76
+ max_eval_sentences: 1
77
+ max_eval_tokens: 60000
78
+ max_frames: 5000
79
+ max_input_tokens: 1550
80
+ max_sentences: 12
81
+ max_tokens: 40000
82
+ max_updates: 320000
83
+ mel_loss: ssim:0.5|l1:0.5
84
+ mel_vmax: 1.5
85
+ mel_vmin: -6.0
86
+ min_level_db: -120
87
+ norm_type: gn
88
+ num_ckpt_keep: 3
89
+ num_heads: 2
90
+ num_sanity_val_steps: 1
91
+ num_spk: 20
92
+ num_test_samples: 0
93
+ num_valid_plots: 10
94
+ optimizer_adam_beta1: 0.9
95
+ optimizer_adam_beta2: 0.98
96
+ out_wav_norm: false
97
+ pe_ckpt: checkpoints/m4singer_pe
98
+ pe_enable: true
99
+ pitch_ar: false
100
+ pitch_enc_hidden_stride_kernel:
101
+ - 0,2,5
102
+ - 0,2,5
103
+ - 0,2,5
104
+ pitch_extractor: parselmouth
105
+ pitch_loss: l1
106
+ pitch_norm: log
107
+ pitch_type: frame
108
+ pre_align_args:
109
+ allow_no_txt: false
110
+ denoise: false
111
+ forced_align: mfa
112
+ txt_processor: zh_g2pM
113
+ use_sox: true
114
+ use_tone: false
115
+ pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
116
+ predictor_dropout: 0.5
117
+ predictor_grad: 0.1
118
+ predictor_hidden: -1
119
+ predictor_kernel: 5
120
+ predictor_layers: 5
121
+ prenet_dropout: 0.5
122
+ prenet_hidden_size: 256
123
+ pretrain_fs_ckpt: ''
124
+ processed_data_dir: xxx
125
+ profile_infer: false
126
+ raw_data_dir: data/raw/m4singer
127
+ ref_norm_layer: bn
128
+ rel_pos: true
129
+ reset_phone_dict: true
130
+ residual_channels: 256
131
+ residual_layers: 20
132
+ save_best: false
133
+ save_ckpt: true
134
+ save_codes:
135
+ - configs
136
+ - modules
137
+ - tasks
138
+ - utils
139
+ - usr
140
+ save_f0: true
141
+ save_gt: true
142
+ schedule_type: linear
143
+ seed: 1234
144
+ sort_by_len: true
145
+ spec_max:
146
+ - -0.3894500136375427
147
+ - -0.3796464204788208
148
+ - -0.2914905250072479
149
+ - -0.15550297498703003
150
+ - -0.08502643555402756
151
+ - 0.10698417574167252
152
+ - -0.0739326998591423
153
+ - -0.0541548952460289
154
+ - 0.15501998364925385
155
+ - 0.06483431905508041
156
+ - 0.03054228238761425
157
+ - -0.013737732544541359
158
+ - -0.004876468330621719
159
+ - 0.04368264228105545
160
+ - 0.13329921662807465
161
+ - 0.16471388936042786
162
+ - 0.04605761915445328
163
+ - -0.05680707097053528
164
+ - 0.0542571023106575
165
+ - -0.0076539707370102406
166
+ - -0.00953489076346159
167
+ - -0.04434828832745552
168
+ - 0.001293870504014194
169
+ - -0.12238839268684387
170
+ - 0.06418416649103165
171
+ - 0.02843189612030983
172
+ - 0.08505241572856903
173
+ - 0.07062800228595734
174
+ - 0.00120724702719599
175
+ - -0.07675088942050934
176
+ - 0.03785804659128189
177
+ - 0.04890783503651619
178
+ - -0.06888376921415329
179
+ - -0.0839693546295166
180
+ - -0.17545585334300995
181
+ - -0.2911079525947571
182
+ - -0.4238220453262329
183
+ - -0.262084037065506
184
+ - -0.3002263605594635
185
+ - -0.3845032751560211
186
+ - -0.3906497061252594
187
+ - -0.6550108790397644
188
+ - -0.7810799479484558
189
+ - -0.7503029704093933
190
+ - -0.7995198965072632
191
+ - -0.8092347383499146
192
+ - -0.6196113228797913
193
+ - -0.6684317588806152
194
+ - -0.7735874056816101
195
+ - -0.8324533104896545
196
+ - -0.9601566791534424
197
+ - -0.955253541469574
198
+ - -0.748817503452301
199
+ - -0.9106167554855347
200
+ - -0.9707801342010498
201
+ - -1.053107500076294
202
+ - -1.0448424816131592
203
+ - -1.1082794666290283
204
+ - -1.1296544075012207
205
+ - -1.071642279624939
206
+ - -1.1003081798553467
207
+ - -1.166810154914856
208
+ - -1.1408926248550415
209
+ - -1.1330615282058716
210
+ - -1.1167492866516113
211
+ - -1.0716774463653564
212
+ - -1.035891056060791
213
+ - -1.0092483758926392
214
+ - -0.9675999879837036
215
+ - -0.938962996006012
216
+ - -1.0120564699172974
217
+ - -0.9777995347976685
218
+ - -1.029313564300537
219
+ - -0.9459163546562195
220
+ - -0.8519706130027771
221
+ - -0.7751091122627258
222
+ - -0.7933766841888428
223
+ - -0.9019735455513
224
+ - -0.9983296990394592
225
+ - -1.505873441696167
226
+ spec_min:
227
+ - -6.0
228
+ - -6.0
229
+ - -6.0
230
+ - -6.0
231
+ - -6.0
232
+ - -6.0
233
+ - -6.0
234
+ - -6.0
235
+ - -6.0
236
+ - -6.0
237
+ - -6.0
238
+ - -6.0
239
+ - -6.0
240
+ - -6.0
241
+ - -6.0
242
+ - -6.0
243
+ - -6.0
244
+ - -6.0
245
+ - -6.0
246
+ - -6.0
247
+ - -6.0
248
+ - -6.0
249
+ - -6.0
250
+ - -6.0
251
+ - -6.0
252
+ - -6.0
253
+ - -6.0
254
+ - -6.0
255
+ - -6.0
256
+ - -6.0
257
+ - -6.0
258
+ - -6.0
259
+ - -6.0
260
+ - -6.0
261
+ - -6.0
262
+ - -6.0
263
+ - -6.0
264
+ - -6.0
265
+ - -6.0
266
+ - -6.0
267
+ - -6.0
268
+ - -6.0
269
+ - -6.0
270
+ - -6.0
271
+ - -6.0
272
+ - -6.0
273
+ - -6.0
274
+ - -6.0
275
+ - -6.0
276
+ - -6.0
277
+ - -6.0
278
+ - -6.0
279
+ - -6.0
280
+ - -6.0
281
+ - -6.0
282
+ - -6.0
283
+ - -6.0
284
+ - -6.0
285
+ - -6.0
286
+ - -6.0
287
+ - -6.0
288
+ - -6.0
289
+ - -6.0
290
+ - -6.0
291
+ - -6.0
292
+ - -6.0
293
+ - -6.0
294
+ - -6.0
295
+ - -6.0
296
+ - -6.0
297
+ - -6.0
298
+ - -6.0
299
+ - -6.0
300
+ - -6.0
301
+ - -6.0
302
+ - -6.0
303
+ - -6.0
304
+ - -6.0
305
+ - -6.0
306
+ - -6.0
307
+ spk_cond_steps: []
308
+ stop_token_weight: 5.0
309
+ task_cls: usr.diffsinger_task.AuxDecoderMIDITask
310
+ test_ids: []
311
+ test_input_dir: ''
312
+ test_num: 0
313
+ test_prefixes:
314
+ - "Alto-2#\u5C81\u6708\u795E\u5077"
315
+ - "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C"
316
+ - "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E"
317
+ - "Tenor-1#\u7AE5\u8BDD"
318
+ - "Tenor-2#\u6D88\u6101"
319
+ - "Tenor-2#\u4E00\u8364\u4E00\u7D20"
320
+ - "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4"
321
+ - "Soprano-1#\u95EE\u6625"
322
+ test_set_name: test
323
+ timesteps: 100
324
+ train_set_name: train
325
+ use_denoise: false
326
+ use_energy_embed: false
327
+ use_gt_dur: false
328
+ use_gt_f0: false
329
+ use_midi: true
330
+ use_nsf: true
331
+ use_pitch_embed: false
332
+ use_pos_embed: true
333
+ use_spk_embed: false
334
+ use_spk_id: true
335
+ use_split_spk_id: false
336
+ use_uv: true
337
+ use_var_enc: false
338
+ val_check_interval: 2000
339
+ valid_num: 0
340
+ valid_set_name: valid
341
+ vocoder: vocoders.hifigan.HifiGAN
342
+ vocoder_ckpt: checkpoints/m4singer_hifigan
343
+ warmup_updates: 2000
344
+ wav2spec_eps: 1e-6
345
+ weight_decay: 0
346
+ win_size: 512
347
+ work_dir: checkpoints/m4singer_fs2_e2e
checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:993d7063a1773bd29d2810591f98152218a4cf8440e2b10c4761516a28f9d566
3
+ size 290456153
checkpoints/m4singer_hifigan/config.yaml ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ max_eval_tokens: 60000
2
+ max_eval_sentences: 1
3
+ save_ckpt: true
4
+ log_interval: 100
5
+ accumulate_grad_batches: 1
6
+ adam_b1: 0.8
7
+ adam_b2: 0.99
8
+ amp: false
9
+ audio_num_mel_bins: 80
10
+ audio_sample_rate: 24000
11
+ aux_context_window: 0
12
+ #base_config:
13
+ #- egs/egs_bases/singing/pwg.yaml
14
+ #- egs/egs_bases/tts/vocoder/hifigan.yaml
15
+ binarization_args:
16
+ reset_phone_dict: true
17
+ reset_word_dict: true
18
+ shuffle: false
19
+ trim_eos_bos: false
20
+ trim_sil: false
21
+ with_align: false
22
+ with_f0: true
23
+ with_f0cwt: false
24
+ with_linear: false
25
+ with_spk_embed: false
26
+ with_spk_id: true
27
+ with_txt: false
28
+ with_wav: true
29
+ with_word: false
30
+ binarizer_cls: data_gen.tts.singing.binarize.SingingBinarizer
31
+ binary_data_dir: data/binary/m4singer_vocoder
32
+ check_val_every_n_epoch: 10
33
+ clip_grad_norm: 1
34
+ clip_grad_value: 0
35
+ datasets: []
36
+ debug: false
37
+ dec_ffn_kernel_size: 9
38
+ dec_layers: 4
39
+ dict_dir: ''
40
+ disc_start_steps: 40000
41
+ discriminator_grad_norm: 1
42
+ discriminator_optimizer_params:
43
+ eps: 1.0e-06
44
+ lr: 0.0002
45
+ weight_decay: 0.0
46
+ discriminator_params:
47
+ bias: true
48
+ conv_channels: 64
49
+ in_channels: 1
50
+ kernel_size: 3
51
+ layers: 10
52
+ nonlinear_activation: LeakyReLU
53
+ nonlinear_activation_params:
54
+ negative_slope: 0.2
55
+ out_channels: 1
56
+ use_weight_norm: true
57
+ discriminator_scheduler_params:
58
+ gamma: 0.999
59
+ step_size: 600
60
+ dropout: 0.1
61
+ ds_workers: 1
62
+ enc_ffn_kernel_size: 9
63
+ enc_layers: 4
64
+ endless_ds: true
65
+ ffn_act: gelu
66
+ ffn_padding: SAME
67
+ fft_size: 512
68
+ fmax: 12000
69
+ fmin: 30
70
+ frames_multiple: 1
71
+ gen_dir_name: ''
72
+ generator_grad_norm: 10
73
+ generator_optimizer_params:
74
+ eps: 1.0e-06
75
+ lr: 0.0002
76
+ weight_decay: 0.0
77
+ generator_params:
78
+ aux_context_window: 0
79
+ aux_channels: 80
80
+ dropout: 0.0
81
+ gate_channels: 128
82
+ in_channels: 1
83
+ kernel_size: 3
84
+ layers: 30
85
+ out_channels: 1
86
+ residual_channels: 64
87
+ skip_channels: 64
88
+ stacks: 3
89
+ upsample_net: ConvInUpsampleNetwork
90
+ upsample_params:
91
+ upsample_scales:
92
+ - 2
93
+ - 4
94
+ - 4
95
+ - 4
96
+ use_nsf: false
97
+ use_pitch_embed: true
98
+ use_weight_norm: true
99
+ generator_scheduler_params:
100
+ gamma: 0.999
101
+ step_size: 600
102
+ griffin_lim_iters: 60
103
+ hidden_size: 256
104
+ hop_size: 128
105
+ infer: false
106
+ lambda_adv: 1.0
107
+ lambda_cdisc: 4.0
108
+ lambda_energy: 0.0
109
+ lambda_f0: 0.0
110
+ lambda_mel: 5.0
111
+ lambda_mel_adv: 1.0
112
+ lambda_ph_dur: 0.0
113
+ lambda_sent_dur: 0.0
114
+ lambda_uv: 0.0
115
+ lambda_word_dur: 0.0
116
+ load_ckpt: 'checkpoints/m4singer_hifigan'
117
+ loud_norm: false
118
+ lr: 2.0
119
+ max_epochs: 1000
120
+ max_frames: 2400
121
+ max_input_tokens: 1550
122
+ max_samples: 8192
123
+ max_sentences: 20
124
+ max_tokens: 24000
125
+ max_updates: 3000000
126
+ max_valid_sentences: 1
127
+ max_valid_tokens: 60000
128
+ mel_loss: ssim:0.5|l1:0.5
129
+ mel_vmax: 1.5
130
+ mel_vmin: -6
131
+ min_frames: 0
132
+ min_level_db: -120
133
+ num_ckpt_keep: 3
134
+ num_heads: 2
135
+ num_mels: 80
136
+ num_sanity_val_steps: 5
137
+ num_spk: 100
138
+ num_test_samples: 0
139
+ num_valid_plots: 10
140
+ optimizer_adam_beta1: 0.9
141
+ optimizer_adam_beta2: 0.98
142
+ out_wav_norm: false
143
+ pitch_extractor: parselmouth
144
+ pitch_type: frame
145
+ pre_align_args:
146
+ allow_no_txt: false
147
+ denoise: false
148
+ sox_resample: true
149
+ sox_to_wav: false
150
+ trim_sil: false
151
+ txt_processor: zh
152
+ use_tone: false
153
+ pre_align_cls: data_gen.tts.singing.pre_align.SingingPreAlign
154
+ predictor_grad: 0.0
155
+ print_nan_grads: false
156
+ processed_data_dir: ''
157
+ profile_infer: false
158
+ raw_data_dir: ''
159
+ ref_level_db: 20
160
+ rename_tmux: true
161
+ rerun_gen: true
162
+ resblock: '1'
163
+ resblock_dilation_sizes:
164
+ - - 1
165
+ - 3
166
+ - 5
167
+ - - 1
168
+ - 3
169
+ - 5
170
+ - - 1
171
+ - 3
172
+ - 5
173
+ resblock_kernel_sizes:
174
+ - 3
175
+ - 7
176
+ - 11
177
+ resume_from_checkpoint: 0
178
+ save_best: true
179
+ save_codes: []
180
+ save_f0: true
181
+ save_gt: true
182
+ scheduler: rsqrt
183
+ seed: 1234
184
+ sort_by_len: true
185
+ stft_loss_params:
186
+ fft_sizes:
187
+ - 1024
188
+ - 2048
189
+ - 512
190
+ hop_sizes:
191
+ - 120
192
+ - 240
193
+ - 50
194
+ win_lengths:
195
+ - 600
196
+ - 1200
197
+ - 240
198
+ window: hann_window
199
+ task_cls: tasks.vocoder.hifigan.HifiGanTask
200
+ tb_log_interval: 100
201
+ test_ids: []
202
+ test_input_dir: ''
203
+ test_num: 50
204
+ test_prefixes: []
205
+ test_set_name: test
206
+ train_set_name: train
207
+ train_sets: ''
208
+ upsample_initial_channel: 512
209
+ upsample_kernel_sizes:
210
+ - 16
211
+ - 16
212
+ - 4
213
+ - 4
214
+ upsample_rates:
215
+ - 8
216
+ - 4
217
+ - 2
218
+ - 2
219
+ use_cdisc: false
220
+ use_cond_disc: false
221
+ use_fm_loss: false
222
+ use_gt_dur: true
223
+ use_gt_f0: true
224
+ use_mel_loss: true
225
+ use_ms_stft: false
226
+ use_pitch_embed: true
227
+ use_ref_enc: true
228
+ use_spec_disc: false
229
+ use_spk_embed: false
230
+ use_spk_id: false
231
+ use_split_spk_id: false
232
+ val_check_interval: 2000
233
+ valid_infer_interval: 10000
234
+ valid_monitor_key: val_loss
235
+ valid_monitor_mode: min
236
+ valid_set_name: valid
237
+ vocoder: pwg
238
+ vocoder_ckpt: ''
239
+ vocoder_denoise_c: 0.0
240
+ warmup_updates: 8000
241
+ weight_decay: 0
242
+ win_length: null
243
+ win_size: 512
244
+ window: hann
245
+ word_size: 3000
246
+ work_dir: checkpoints/m4singer_hifigan
checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3e859bd2b1e125fe661aedfd6fa3e97e10e06f3ec3d03b7735a041984402f89
3
+ size 1016324099
checkpoints/m4singer_pe/config.yaml ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accumulate_grad_batches: 1
2
+ audio_num_mel_bins: 80
3
+ audio_sample_rate: 24000
4
+ base_config:
5
+ - configs/tts/lj/fs2.yaml
6
+ binarization_args:
7
+ shuffle: false
8
+ with_align: true
9
+ with_f0: true
10
+ with_f0cwt: true
11
+ with_spk_embed: true
12
+ with_txt: true
13
+ with_wav: false
14
+ binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
15
+ binary_data_dir: data/binary/m4singer
16
+ check_val_every_n_epoch: 10
17
+ clip_grad_norm: 1
18
+ cwt_add_f0_loss: false
19
+ cwt_hidden_size: 128
20
+ cwt_layers: 2
21
+ cwt_loss: l1
22
+ cwt_std_scale: 0.8
23
+ debug: false
24
+ dec_ffn_kernel_size: 9
25
+ dec_layers: 4
26
+ decoder_type: fft
27
+ dict_dir: ''
28
+ dropout: 0.1
29
+ ds_workers: 4
30
+ dur_enc_hidden_stride_kernel:
31
+ - 0,2,3
32
+ - 0,2,3
33
+ - 0,1,3
34
+ dur_loss: mse
35
+ dur_predictor_kernel: 3
36
+ dur_predictor_layers: 2
37
+ enc_ffn_kernel_size: 9
38
+ enc_layers: 4
39
+ encoder_K: 8
40
+ encoder_type: fft
41
+ endless_ds: true
42
+ ffn_act: gelu
43
+ ffn_padding: SAME
44
+ fft_size: 512
45
+ fmax: 12000
46
+ fmin: 30
47
+ gen_dir_name: ''
48
+ hidden_size: 256
49
+ hop_size: 128
50
+ infer: false
51
+ lambda_commit: 0.25
52
+ lambda_energy: 0.1
53
+ lambda_f0: 1.0
54
+ lambda_ph_dur: 1.0
55
+ lambda_sent_dur: 1.0
56
+ lambda_uv: 1.0
57
+ lambda_word_dur: 1.0
58
+ load_ckpt: ''
59
+ log_interval: 100
60
+ loud_norm: false
61
+ lr: 0.1
62
+ max_epochs: 1000
63
+ max_eval_sentences: 1
64
+ max_eval_tokens: 60000
65
+ max_frames: 5000
66
+ max_input_tokens: 1550
67
+ max_sentences: 100000
68
+ max_tokens: 20000
69
+ max_updates: 280000
70
+ mel_loss: l1
71
+ mel_vmax: 1.5
72
+ mel_vmin: -6
73
+ min_level_db: -120
74
+ norm_type: gn
75
+ num_ckpt_keep: 3
76
+ num_heads: 2
77
+ num_sanity_val_steps: 5
78
+ num_spk: 1
79
+ num_test_samples: 20
80
+ num_valid_plots: 10
81
+ optimizer_adam_beta1: 0.9
82
+ optimizer_adam_beta2: 0.98
83
+ out_wav_norm: false
84
+ pitch_ar: false
85
+ pitch_enc_hidden_stride_kernel:
86
+ - 0,2,5
87
+ - 0,2,5
88
+ - 0,2,5
89
+ pitch_extractor_conv_layers: 2
90
+ pitch_loss: l1
91
+ pitch_norm: log
92
+ pitch_type: frame
93
+ pre_align_args:
94
+ allow_no_txt: false
95
+ denoise: false
96
+ forced_align: mfa
97
+ txt_processor: en
98
+ use_sox: false
99
+ use_tone: true
100
+ pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
101
+ predictor_dropout: 0.5
102
+ predictor_grad: 0.1
103
+ predictor_hidden: -1
104
+ predictor_kernel: 5
105
+ predictor_layers: 2
106
+ prenet_dropout: 0.5
107
+ prenet_hidden_size: 256
108
+ pretrain_fs_ckpt: ''
109
+ processed_data_dir: data/processed/ljspeech
110
+ profile_infer: false
111
+ raw_data_dir: data/raw/LJSpeech-1.1
112
+ ref_norm_layer: bn
113
+ reset_phone_dict: true
114
+ save_best: false
115
+ save_ckpt: true
116
+ save_codes:
117
+ - configs
118
+ - modules
119
+ - tasks
120
+ - utils
121
+ - usr
122
+ save_f0: false
123
+ save_gt: false
124
+ seed: 1234
125
+ sort_by_len: true
126
+ stop_token_weight: 5.0
127
+ task_cls: tasks.tts.pe.PitchExtractionTask
128
+ test_ids:
129
+ - 68
130
+ - 70
131
+ - 74
132
+ - 87
133
+ - 110
134
+ - 172
135
+ - 190
136
+ - 215
137
+ - 231
138
+ - 294
139
+ - 316
140
+ - 324
141
+ - 402
142
+ - 422
143
+ - 485
144
+ - 500
145
+ - 505
146
+ - 508
147
+ - 509
148
+ - 519
149
+ test_input_dir: ''
150
+ test_num: 523
151
+ test_set_name: test
152
+ train_set_name: train
153
+ use_denoise: false
154
+ use_energy_embed: false
155
+ use_gt_dur: false
156
+ use_gt_f0: false
157
+ use_pitch_embed: true
158
+ use_pos_embed: true
159
+ use_spk_embed: false
160
+ use_spk_id: false
161
+ use_split_spk_id: false
162
+ use_uv: true
163
+ use_var_enc: false
164
+ val_check_interval: 2000
165
+ valid_num: 348
166
+ valid_set_name: valid
167
+ vocoder: pwg
168
+ vocoder_ckpt: ''
169
+ warmup_updates: 2000
170
+ weight_decay: 0
171
+ win_size: 512
172
+ work_dir: checkpoints/m4singer_pe
checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10cbf382bf82ecf335fbf68ba226f93c9c715b0476f6604351cbad9783f529fe
3
+ size 39146292
configs/config_base.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # task
2
+ binary_data_dir: ''
3
+ work_dir: '' # experiment directory.
4
+ infer: false # infer
5
+ seed: 1234
6
+ debug: false
7
+ save_codes:
8
+ - configs
9
+ - modules
10
+ - tasks
11
+ - utils
12
+ - usr
13
+
14
+ #############
15
+ # dataset
16
+ #############
17
+ ds_workers: 1
18
+ test_num: 100
19
+ valid_num: 100
20
+ endless_ds: false
21
+ sort_by_len: true
22
+
23
+ #########
24
+ # train and eval
25
+ #########
26
+ load_ckpt: ''
27
+ save_ckpt: true
28
+ save_best: false
29
+ num_ckpt_keep: 3
30
+ clip_grad_norm: 0
31
+ accumulate_grad_batches: 1
32
+ log_interval: 100
33
+ num_sanity_val_steps: 5 # steps of validation at the beginning
34
+ check_val_every_n_epoch: 10
35
+ val_check_interval: 2000
36
+ max_epochs: 1000
37
+ max_updates: 160000
38
+ max_tokens: 31250
39
+ max_sentences: 100000
40
+ max_eval_tokens: -1
41
+ max_eval_sentences: -1
42
+ test_input_dir: ''
configs/singing/base.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config:
2
+ - configs/tts/base.yaml
3
+ - configs/tts/base_zh.yaml
4
+
5
+
6
+ datasets: []
7
+ test_prefixes: []
8
+ test_num: 0
9
+ valid_num: 0
10
+
11
+ pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
12
+ binarizer_cls: data_gen.singing.binarize.SingingBinarizer
13
+ pre_align_args:
14
+ use_tone: false # for ZH
15
+ forced_align: mfa
16
+ use_sox: true
17
+ hop_size: 128 # Hop size.
18
+ fft_size: 512 # FFT size.
19
+ win_size: 512 # Window size.
20
+ max_frames: 8000
21
+ fmin: 50 # Minimum freq in mel basis calculation.
22
+ fmax: 11025 # Maximum frequency in mel basis calculation.
23
+ pitch_type: frame
24
+
25
+ hidden_size: 256
26
+ mel_loss: "ssim:0.5|l1:0.5"
27
+ lambda_f0: 0.0
28
+ lambda_uv: 0.0
29
+ lambda_energy: 0.0
30
+ lambda_ph_dur: 0.0
31
+ lambda_sent_dur: 0.0
32
+ lambda_word_dur: 0.0
33
+ predictor_grad: 0.0
34
+ use_spk_embed: true
35
+ use_spk_id: false
36
+
37
+ max_tokens: 20000
38
+ max_updates: 400000
39
+ num_spk: 100
40
+ save_f0: true
41
+ use_gt_dur: true
42
+ use_gt_f0: true
configs/singing/fs2.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ base_config:
2
+ - configs/tts/fs2.yaml
3
+ - configs/singing/base.yaml
configs/tts/base.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # task
2
+ base_config: configs/config_base.yaml
3
+ task_cls: ''
4
+ #############
5
+ # dataset
6
+ #############
7
+ raw_data_dir: ''
8
+ processed_data_dir: ''
9
+ binary_data_dir: ''
10
+ dict_dir: ''
11
+ pre_align_cls: ''
12
+ binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
13
+ pre_align_args:
14
+ use_tone: true # for ZH
15
+ forced_align: mfa
16
+ use_sox: false
17
+ txt_processor: en
18
+ allow_no_txt: false
19
+ denoise: false
20
+ binarization_args:
21
+ shuffle: false
22
+ with_txt: true
23
+ with_wav: false
24
+ with_align: true
25
+ with_spk_embed: true
26
+ with_f0: true
27
+ with_f0cwt: true
28
+
29
+ loud_norm: false
30
+ endless_ds: true
31
+ reset_phone_dict: true
32
+
33
+ test_num: 100
34
+ valid_num: 100
35
+ max_frames: 1550
36
+ max_input_tokens: 1550
37
+ audio_num_mel_bins: 80
38
+ audio_sample_rate: 22050
39
+ hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
40
+ win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
41
+ fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
42
+ fmax: 7600 # To be increased/reduced depending on data.
43
+ fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter
44
+ min_level_db: -100
45
+ num_spk: 1
46
+ mel_vmin: -6
47
+ mel_vmax: 1.5
48
+ ds_workers: 4
49
+
50
+ #########
51
+ # model
52
+ #########
53
+ dropout: 0.1
54
+ enc_layers: 4
55
+ dec_layers: 4
56
+ hidden_size: 384
57
+ num_heads: 2
58
+ prenet_dropout: 0.5
59
+ prenet_hidden_size: 256
60
+ stop_token_weight: 5.0
61
+ enc_ffn_kernel_size: 9
62
+ dec_ffn_kernel_size: 9
63
+ ffn_act: gelu
64
+ ffn_padding: 'SAME'
65
+
66
+
67
+ ###########
68
+ # optimization
69
+ ###########
70
+ lr: 2.0
71
+ warmup_updates: 8000
72
+ optimizer_adam_beta1: 0.9
73
+ optimizer_adam_beta2: 0.98
74
+ weight_decay: 0
75
+ clip_grad_norm: 1
76
+
77
+
78
+ ###########
79
+ # train and eval
80
+ ###########
81
+ max_tokens: 30000
82
+ max_sentences: 100000
83
+ max_eval_sentences: 1
84
+ max_eval_tokens: 60000
85
+ train_set_name: 'train'
86
+ valid_set_name: 'valid'
87
+ test_set_name: 'test'
88
+ vocoder: pwg
89
+ vocoder_ckpt: ''
90
+ profile_infer: false
91
+ out_wav_norm: false
92
+ save_gt: false
93
+ save_f0: false
94
+ gen_dir_name: ''
95
+ use_denoise: false
configs/tts/base_zh.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pre_align_args:
2
+ txt_processor: zh_g2pM
3
+ binarizer_cls: data_gen.tts.binarizer_zh.ZhBinarizer
configs/tts/fs2.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config: configs/tts/base.yaml
2
+ task_cls: tasks.tts.fs2.FastSpeech2Task
3
+
4
+ # model
5
+ hidden_size: 256
6
+ dropout: 0.1
7
+ encoder_type: fft # fft|tacotron|tacotron2|conformer
8
+ encoder_K: 8 # for tacotron encoder
9
+ decoder_type: fft # fft|rnn|conv|conformer
10
+ use_pos_embed: true
11
+
12
+ # duration
13
+ predictor_hidden: -1
14
+ predictor_kernel: 5
15
+ predictor_layers: 2
16
+ dur_predictor_kernel: 3
17
+ dur_predictor_layers: 2
18
+ predictor_dropout: 0.5
19
+
20
+ # pitch and energy
21
+ use_pitch_embed: true
22
+ pitch_type: ph # frame|ph|cwt
23
+ use_uv: true
24
+ cwt_hidden_size: 128
25
+ cwt_layers: 2
26
+ cwt_loss: l1
27
+ cwt_add_f0_loss: false
28
+ cwt_std_scale: 0.8
29
+
30
+ pitch_ar: false
31
+ #pitch_embed_type: 0q
32
+ pitch_loss: 'l1' # l1|l2|ssim
33
+ pitch_norm: log
34
+ use_energy_embed: false
35
+
36
+ # reference encoder and speaker embedding
37
+ use_spk_id: false
38
+ use_split_spk_id: false
39
+ use_spk_embed: false
40
+ use_var_enc: false
41
+ lambda_commit: 0.25
42
+ ref_norm_layer: bn
43
+ pitch_enc_hidden_stride_kernel:
44
+ - 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
45
+ - 0,2,5
46
+ - 0,2,5
47
+ dur_enc_hidden_stride_kernel:
48
+ - 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
49
+ - 0,2,3
50
+ - 0,1,3
51
+
52
+
53
+ # mel
54
+ mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5
55
+
56
+ # loss lambda
57
+ lambda_f0: 1.0
58
+ lambda_uv: 1.0
59
+ lambda_energy: 0.1
60
+ lambda_ph_dur: 1.0
61
+ lambda_sent_dur: 1.0
62
+ lambda_word_dur: 1.0
63
+ predictor_grad: 0.1
64
+
65
+ # train and eval
66
+ pretrain_fs_ckpt: ''
67
+ warmup_updates: 2000
68
+ max_tokens: 32000
69
+ max_sentences: 100000
70
+ max_eval_sentences: 1
71
+ max_updates: 120000
72
+ num_valid_plots: 5
73
+ num_test_samples: 0
74
+ test_ids: []
75
+ use_gt_dur: false
76
+ use_gt_f0: false
77
+
78
+ # exp
79
+ dur_loss: mse # huber|mol
80
+ norm_type: gn
configs/tts/hifigan.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config: configs/tts/pwg.yaml
2
+ task_cls: tasks.vocoder.hifigan.HifiGanTask
3
+ resblock: "1"
4
+ adam_b1: 0.8
5
+ adam_b2: 0.99
6
+ upsample_rates: [ 8,8,2,2 ]
7
+ upsample_kernel_sizes: [ 16,16,4,4 ]
8
+ upsample_initial_channel: 128
9
+ resblock_kernel_sizes: [ 3,7,11 ]
10
+ resblock_dilation_sizes: [ [ 1,3,5 ], [ 1,3,5 ], [ 1,3,5 ] ]
11
+
12
+ lambda_mel: 45.0
13
+
14
+ max_samples: 8192
15
+ max_sentences: 16
16
+
17
+ generator_params:
18
+ lr: 0.0002 # Generator's learning rate.
19
+ aux_context_window: 0 # Context window size for auxiliary feature.
20
+ discriminator_optimizer_params:
21
+ lr: 0.0002 # Discriminator's learning rate.
configs/tts/lj/base_mel2wav.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ raw_data_dir: 'data/raw/LJSpeech-1.1'
2
+ processed_data_dir: 'data/processed/ljspeech'
3
+ binary_data_dir: 'data/binary/ljspeech_wav'
configs/tts/lj/base_text2mel.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ raw_data_dir: 'data/raw/LJSpeech-1.1'
2
+ processed_data_dir: 'data/processed/ljspeech'
3
+ binary_data_dir: 'data/binary/ljspeech'
4
+ pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
5
+
6
+ pitch_type: cwt
7
+ mel_loss: l1
8
+ num_test_samples: 20
9
+ test_ids: [ 68, 70, 74, 87, 110, 172, 190, 215, 231, 294,
10
+ 316, 324, 402, 422, 485, 500, 505, 508, 509, 519 ]
11
+ use_energy_embed: false
12
+ test_num: 523
13
+ valid_num: 348
configs/tts/lj/fs2.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ base_config:
2
+ - configs/tts/fs2.yaml
3
+ - configs/tts/lj/base_text2mel.yaml
configs/tts/lj/hifigan.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ base_config:
2
+ - configs/tts/hifigan.yaml
3
+ - configs/tts/lj/base_mel2wav.yaml
configs/tts/lj/pwg.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ base_config:
2
+ - configs/tts/pwg.yaml
3
+ - configs/tts/lj/base_mel2wav.yaml
configs/tts/pwg.yaml ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_config: configs/tts/base.yaml
2
+ task_cls: tasks.vocoder.pwg.PwgTask
3
+
4
+ binarization_args:
5
+ with_wav: true
6
+ with_spk_embed: false
7
+ with_align: false
8
+ test_input_dir: ''
9
+
10
+ ###########
11
+ # train and eval
12
+ ###########
13
+ max_samples: 25600
14
+ max_sentences: 5
15
+ max_eval_sentences: 1
16
+ max_updates: 1000000
17
+ val_check_interval: 2000
18
+
19
+
20
+ ###########################################################
21
+ # FEATURE EXTRACTION SETTING #
22
+ ###########################################################
23
+ sampling_rate: 22050 # Sampling rate.
24
+ fft_size: 1024 # FFT size.
25
+ hop_size: 256 # Hop size.
26
+ win_length: null # Window length.
27
+ # If set to null, it will be the same as fft_size.
28
+ window: "hann" # Window function.
29
+ num_mels: 80 # Number of mel basis.
30
+ fmin: 80 # Minimum freq in mel basis calculation.
31
+ fmax: 7600 # Maximum frequency in mel basis calculation.
32
+ format: "hdf5" # Feature file format. "npy" or "hdf5" is supported.
33
+
34
+ ###########################################################
35
+ # GENERATOR NETWORK ARCHITECTURE SETTING #
36
+ ###########################################################
37
+ generator_params:
38
+ in_channels: 1 # Number of input channels.
39
+ out_channels: 1 # Number of output channels.
40
+ kernel_size: 3 # Kernel size of dilated convolution.
41
+ layers: 30 # Number of residual block layers.
42
+ stacks: 3 # Number of stacks i.e., dilation cycles.
43
+ residual_channels: 64 # Number of channels in residual conv.
44
+ gate_channels: 128 # Number of channels in gated conv.
45
+ skip_channels: 64 # Number of channels in skip conv.
46
+ aux_channels: 80 # Number of channels for auxiliary feature conv.
47
+ # Must be the same as num_mels.
48
+ aux_context_window: 2 # Context window size for auxiliary feature.
49
+ # If set to 2, previous 2 and future 2 frames will be considered.
50
+ dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
51
+ use_weight_norm: true # Whether to use weight norm.
52
+ # If set to true, it will be applied to all of the conv layers.
53
+ upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
54
+ upsample_params: # Upsampling network parameters.
55
+ upsample_scales: [4, 4, 4, 4] # Upsampling scales. Product of these must be the same as hop size.
56
+ use_pitch_embed: false
57
+
58
+ ###########################################################
59
+ # DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
60
+ ###########################################################
61
+ discriminator_params:
62
+ in_channels: 1 # Number of input channels.
63
+ out_channels: 1 # Number of output channels.
64
+ kernel_size: 3 # Kernel size of conv layers.
65
+ layers: 10 # Number of conv layers.
66
+ conv_channels: 64 # Number of channels in conv layers.
67
+ bias: true # Whether to use bias parameter in conv.
68
+ use_weight_norm: true # Whether to use weight norm.
69
+ # If set to true, it will be applied to all of the conv layers.
70
+ nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
71
+ nonlinear_activation_params: # Nonlinear function parameters
72
+ negative_slope: 0.2 # Alpha in LeakyReLU.
73
+
74
+ ###########################################################
75
+ # STFT LOSS SETTING #
76
+ ###########################################################
77
+ stft_loss_params:
78
+ fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
79
+ hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
80
+ win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
81
+ window: "hann_window" # Window function for STFT-based loss
82
+ use_mel_loss: false
83
+
84
+ ###########################################################
85
+ # ADVERSARIAL LOSS SETTING #
86
+ ###########################################################
87
+ lambda_adv: 4.0 # Loss balancing coefficient.
88
+
89
+ ###########################################################
90
+ # OPTIMIZER & SCHEDULER SETTING #
91
+ ###########################################################
92
+ generator_optimizer_params:
93
+ lr: 0.0001 # Generator's learning rate.
94
+ eps: 1.0e-6 # Generator's epsilon.
95
+ weight_decay: 0.0 # Generator's weight decay coefficient.
96
+ generator_scheduler_params:
97
+ step_size: 200000 # Generator's scheduler step size.
98
+ gamma: 0.5 # Generator's scheduler gamma.
99
+ # At each step size, lr will be multiplied by this parameter.
100
+ generator_grad_norm: 10 # Generator's gradient norm.
101
+ discriminator_optimizer_params:
102
+ lr: 0.00005 # Discriminator's learning rate.
103
+ eps: 1.0e-6 # Discriminator's epsilon.
104
+ weight_decay: 0.0 # Discriminator's weight decay coefficient.
105
+ discriminator_scheduler_params:
106
+ step_size: 200000 # Discriminator's scheduler step size.
107
+ gamma: 0.5 # Discriminator's scheduler gamma.
108
+ # At each step size, lr will be multiplied by this parameter.
109
+ discriminator_grad_norm: 1 # Discriminator's gradient norm.
110
+ disc_start_steps: 40000 # Number of steps to start to train discriminator.
data_gen/singing/binarize.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from copy import deepcopy
4
+ import pandas as pd
5
+ import logging
6
+ from tqdm import tqdm
7
+ import json
8
+ import glob
9
+ import re
10
+ from resemblyzer import VoiceEncoder
11
+ import traceback
12
+ import numpy as np
13
+ import pretty_midi
14
+ import librosa
15
+ from scipy.interpolate import interp1d
16
+ import torch
17
+ from textgrid import TextGrid
18
+
19
+ from utils.hparams import hparams
20
+ from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch
21
+ from utils.pitch_utils import f0_to_coarse
22
+ from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
23
+ from data_gen.tts.binarizer_zh import ZhBinarizer
24
+ from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU
25
+ from vocoders.base_vocoder import VOCODERS
26
+
27
+
28
class SingingBinarizer(BaseBinarizer):
    """Binarizer for singing data laid out as ``<dir>/<group>/<piece>_wf0.wav``.

    Metadata is discovered by globbing the processed data directories for
    ``*_wf0.wav`` pieces and reading the sibling ``.txt`` / ``_ph.txt`` files
    (see :meth:`load_meta_data`). Items whose name contains any entry of
    ``hparams['test_prefixes']`` form the test split, which is also used as
    the validation split.
    """

    def __init__(self, processed_data_dir=None):
        """Set up empty metadata tables.

        Args:
            processed_data_dir: Comma-separated list of processed data
                directories. Falls back to ``hparams['processed_data_dir']``
                when ``None``.
        """
        # NOTE(review): BaseBinarizer.__init__ is not invoked here. The base
        # initializer reads ``metadata_phone.csv``; this class fills the same
        # tables in load_meta_data() instead - presumably intentional.
        if processed_data_dir is None:
            processed_data_dir = hparams['processed_data_dir']
        self.processed_data_dirs = processed_data_dir.split(",")
        self.binarization_args = hparams['binarization_args']
        self.pre_align_args = hparams['pre_align_args']
        self.item2txt = {}
        self.item2ph = {}
        self.item2wavfn = {}
        self.item2f0fn = {}
        self.item2tgfn = {}
        self.item2spk = {}

    @staticmethod
    def _read_first_line(path):
        """Return the first line of ``path``, closing the file afterwards."""
        with open(path) as f:
            return f.readline()

    def split_train_test_set(self, item_names):
        """Split item names into train and test lists.

        An item belongs to the test set when any entry of
        ``hparams['test_prefixes']`` occurs anywhere in its name (substring
        match, not only at the start).

        Returns:
            Tuple ``(train_item_names, test_item_names)``; both keep the
            order of ``item_names``.
        """
        item_names = deepcopy(item_names)
        test_prefixes = hparams['test_prefixes']
        test_item_names = [x for x in item_names if any(ts in x for ts in test_prefixes)]
        # Build the lookup set once instead of once per item.
        test_item_set = set(test_item_names)
        train_item_names = [x for x in item_names if x not in test_item_set]
        logging.info("train {}".format(len(train_item_names)))
        logging.info("test {}".format(len(test_item_names)))
        return train_item_names, test_item_names

    def load_meta_data(self):
        """Scan the processed data directories and populate the item tables.

        Fills ``item2txt``, ``item2ph``, ``item2wavfn``, ``item2spk`` and
        ``item2tgfn``, then sets ``item_names`` and the train/test splits.
        When more than one directory is configured, item names and speaker
        names are prefixed with ``ds<index>_`` to keep them distinct.
        """
        wav_suffix = '_wf0.wav'
        txt_suffix = '.txt'
        ph_suffix = '_ph.txt'
        tg_suffix = '.TextGrid'
        multi_ds = len(self.processed_data_dirs) > 1

        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
            all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')

            for piece_path in all_wav_pieces:
                # Path relative to the data dir, '/' flattened to '-', suffix dropped.
                item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)]
                if multi_ds:
                    item_name = f'ds{ds_id}_{item_name}'
                self.item2txt[item_name] = self._read_first_line(piece_path.replace(wav_suffix, txt_suffix))
                self.item2ph[item_name] = self._read_first_line(piece_path.replace(wav_suffix, ph_suffix))
                self.item2wavfn[item_name] = piece_path

                # Speaker = parent directory name up to the first '-' or '#'.
                spk = re.split('-|#', piece_path.split('/')[-2])[0]
                if multi_ds:
                    spk = f"ds{ds_id}_{spk}"
                self.item2spk[item_name] = spk
                self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix)
        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            # Fixed seed keeps the shuffled order reproducible across runs.
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @property
    def train_item_names(self):
        """Item names of the training split."""
        return self._train_item_names

    @property
    def valid_item_names(self):
        """Item names of the validation split (same list as the test split)."""
        return self._test_item_names

    @property
    def test_item_names(self):
        """Item names of the test split."""
        return self._test_item_names

    def process(self):
        """Run the full binarization: metadata, speaker map, phone set, splits."""
        self.load_meta_data()
        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
        # build_spk_map / process_data are not defined in this class; they are
        # expected to come from BaseBinarizer.
        self.spk_map = self.build_spk_map()
        print("| spk_map: ", self.spk_map)
        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
        with open(spk_map_fn, 'w') as f:
            json.dump(self.spk_map, f)

        self.phone_encoder = self._phone_encoder()
        self.process_data('valid')
        self.process_data('test')
        self.process_data('train')

    def _phone_encoder(self):
        """Build (or load) ``phone_set.json`` and return the phone encoder.

        The phone set is rebuilt from ``item2ph`` when
        ``hparams['reset_phone_dict']`` is true or the file does not exist.
        """
        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
            ph_set = []
            for ph_sent in self.item2ph.values():
                ph_set += ph_sent.split(' ')
            ph_set = sorted(set(ph_set))
            # Close (and thereby flush) the file before build_phone_encoder
            # is called on the same directory below.
            with open(ph_set_fn, 'w') as f:
                json.dump(ph_set, f)
            print("| Build phone set: ", ph_set)
        else:
            with open(ph_set_fn, 'r') as f:
                ph_set = json.load(f)
            print("| Load phone set: ", ph_set)
        return build_phone_encoder(hparams['binary_data_dir'])

    # NOTE: an earlier file-based pitch loader (reading a precomputed
    # ``*_f0.npy`` next to the wav) used to live here as commented-out code.
    # Pitch is now obtained through ``cls.get_pitch(wav, mel, res)``, which is
    # not defined in this class (presumably inherited from BaseBinarizer).

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract all features of one item.

        Args:
            item_name: Unique item identifier.
            ph: Space-separated phoneme string.
            txt: Raw text of the item.
            tg_fn: Path of the TextGrid alignment file.
            wav_fn: Path of the waveform file.
            spk_id: Speaker id stored in the result.
            encoder: Object whose ``encode(ph)`` yields the phone ids.
            binarization_args: Mapping with the ``with_f0`` / ``with_txt`` /
                ``with_align`` switches.

        Returns:
            A dict of features, or ``None`` when the item is skipped because
            a ``BinarizationError`` was raised.
        """
        # The vocoder may be configured by registry key or by dotted path;
        # for the latter only the last path component is used as the key.
        vocoder_key = hparams['vocoder']
        if vocoder_key not in VOCODERS:
            vocoder_key = vocoder_key.split('.')[-1]
        wav, mel = VOCODERS[vocoder_key].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav, mel, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    # Narrowed from a bare ``except`` so KeyboardInterrupt /
                    # SystemExit are no longer converted into a skipped item.
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
                if binarization_args['with_align']:
                    cls.get_align(tg_fn, ph, mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
181
+
182
+
183
class MidiSingingBinarizer(SingingBinarizer):
    """Singing binarizer whose metadata (including MIDI notes) comes from ``meta.json``.

    Each processed data directory must contain a ``meta.json`` holding a list
    of dicts with the keys ``item_name``, ``wav_fn``, ``txt``, ``phs``,
    ``ph_dur``, ``notes``, ``notes_dur`` and ``is_slur``.
    """

    # Class-level tables: they are read through the class name from the
    # static/class methods below, so they are shared by all instances.
    # Writes via ``self.item2midi[...]`` in load_meta_data mutate these
    # shared dicts (no instance attribute of the same name is created).
    item2midi = {}
    item2midi_dur = {}
    item2is_slur = {}
    item2ph_durs = {}
    item2wdb = {}

    def load_meta_data(self):
        """Read ``meta.json`` from every data directory and fill the item tables.

        When more than one directory is configured, item and speaker names are
        prefixed with ``ds<index>_``. Every item is assigned the speaker name
        ``'pop-cs'``.
        """
        # Phonemes flagged with 1 in the word-boundary list; built once
        # instead of concatenating the lists for every phoneme.
        boundary_phs = set(ALL_YUNMU + ['AP', 'SP', '<SIL>'])
        multi_ds = len(self.processed_data_dirs) > 1

        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
            with open(os.path.join(processed_data_dir, 'meta.json')) as f:
                meta_midi = json.load(f)  # [list of dict]

            for song_item in meta_midi:
                item_name = song_item['item_name']
                if multi_ds:
                    item_name = f'ds{ds_id}_{item_name}'
                self.item2wavfn[item_name] = song_item['wav_fn']
                self.item2txt[item_name] = song_item['txt']

                self.item2ph[item_name] = ' '.join(song_item['phs'])
                self.item2wdb[item_name] = [1 if x in boundary_phs else 0 for x in song_item['phs']]
                self.item2ph_durs[item_name] = song_item['ph_dur']

                self.item2midi[item_name] = song_item['notes']
                self.item2midi_dur[item_name] = song_item['notes_dur']
                self.item2is_slur[item_name] = song_item['is_slur']
                spk = 'pop-cs'
                if multi_ds:
                    spk = f"ds{ds_id}_{spk}"
                self.item2spk[item_name] = spk

        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            # Fixed seed keeps the shuffled order reproducible across runs.
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @staticmethod
    def get_pitch(wav_fn, wav, spec, ph, res):
        """Store MIDI note features and ground-truth f0 of one item in ``res``.

        Writes the keys ``pitch_midi``, ``midi_dur``, ``is_slur``,
        ``word_boundary``, ``f0`` and ``pitch``.

        Args:
            wav_fn: Waveform path; the lookup key is derived from its last two
                path components (extension and ``_wf0`` removed).
            wav: Waveform passed to the module-level ``get_pitch``.
            spec: Spectrogram passed to the module-level ``get_pitch``.
            ph: Unused.
            res: Result dict updated in place.

        Raises:
            BinarizationError: If the extracted f0 sums to zero.
        """
        # NOTE(review): the key derived here carries no ``ds<index>_`` prefix,
        # so it only matches the tables when a single data dir is configured
        # and item names follow the ``<dir>/<file>`` pattern - confirm.
        item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '')
        res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name])
        res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name])
        res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name])
        res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name])
        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
            res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)

        # gt f0. Inside a method body the bare name ``get_pitch`` resolves to
        # the module-level import, not to this static method.
        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
        if sum(gt_f0) == 0:
            raise BinarizationError("Empty **gt** f0")
        res['f0'] = gt_f0
        res['pitch'] = gt_pitch_coarse

    @staticmethod
    def get_align(ph_durs, mel, phone_encoded, res, hop_size=hparams['hop_size'], audio_sample_rate=hparams['audio_sample_rate']):
        """Build the frame-to-phoneme map ``res['mel2ph']`` from phoneme durations.

        Frame ``t`` receives the 1-based index of the phoneme covering it;
        frames not covered by any phoneme stay 0.

        Args:
            ph_durs: Per-phoneme durations (assumed to be in seconds, given
                the multiplication by the sample rate - TODO confirm).
            mel: Spectrogram; only ``mel.shape[0]`` (frame count) is used.
            phone_encoded: Unused.
            res: Result dict updated in place.
            hop_size: Hop size in samples.
            audio_sample_rate: Sample rate in Hz.
        """
        # NOTE: the two defaults above are read from ``hparams`` when this
        # class body is executed (i.e. at import), not at call time.
        mel2ph = np.zeros([mel.shape[0]], int)
        start_time = 0

        for i_ph, ph_dur in enumerate(ph_durs):
            # ``+ 0.5`` rounds to the nearest frame before truncation.
            start_frame = int(start_time * audio_sample_rate / hop_size + 0.5)
            end_frame = int((start_time + ph_dur) * audio_sample_rate / hop_size + 0.5)
            mel2ph[start_frame:end_frame] = i_ph + 1
            start_time = start_time + ph_dur

        res['mel2ph'] = mel2ph

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract all features of one item.

        Same contract as the parent implementation, except that pitch is
        looked up via ``wav_fn`` and alignment uses the phoneme durations
        from ``meta.json`` instead of a TextGrid (``tg_fn`` is unused).

        Returns:
            A dict of features, or ``None`` when the item is skipped because
            a ``BinarizationError`` was raised.
        """
        # The vocoder may be configured by registry key or by dotted path;
        # for the latter only the last path component is used as the key.
        vocoder_key = hparams['vocoder']
        if vocoder_key not in VOCODERS:
            vocoder_key = vocoder_key.split('.')[-1]
        wav, mel = VOCODERS[vocoder_key].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav_fn, wav, mel, ph, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    # Narrowed from a bare ``except`` so KeyboardInterrupt /
                    # SystemExit are no longer converted into a skipped item.
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
                if binarization_args['with_align']:
                    cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
282
+
283
+
284
class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer):
    """Combination of ZhBinarizer and SingingBinarizer.

    Adds no members of its own; behaviour is resolved through the method
    resolution order, with ZhBinarizer taking precedence.
    """
286
+
287
class M4SingerBinarizer(MidiSingingBinarizer):
    """MIDI singing binarizer reading ``meta.json`` from ``hparams['raw_data_dir']``.

    Item names have the form ``<singer>#<song_name>#<sent_id>`` and the wav
    path is derived from them as ``<raw_data_dir>/<singer>#<song_name>/<sent_id>.wav``.
    """

    # Redeclared so this class gets its own tables instead of sharing the
    # parent's class-level dicts. They are read through the class name from
    # the static/class methods below.
    item2midi = {}
    item2midi_dur = {}
    item2is_slur = {}
    item2ph_durs = {}
    item2wdb = {}

    def split_train_test_set(self, item_names):
        """Split item names into train and test lists.

        Unlike the parent implementation, an item is a test item only when
        its name *starts with* an entry of ``hparams['test_prefixes']``.

        Returns:
            Tuple ``(train_item_names, test_item_names)``; both keep the
            order of ``item_names``.
        """
        item_names = deepcopy(item_names)
        test_prefixes = hparams['test_prefixes']
        test_item_names = [x for x in item_names if any(x.startswith(ts) for ts in test_prefixes)]
        # Build the lookup set once instead of once per item.
        test_item_set = set(test_item_names)
        train_item_names = [x for x in item_names if x not in test_item_set]
        logging.info("train {}".format(len(train_item_names)))
        logging.info("test {}".format(len(test_item_names)))
        return train_item_names, test_item_names

    def load_meta_data(self):
        """Read ``<raw_data_dir>/meta.json`` and fill the item tables.

        The speaker of an item is the ``singer`` part of its name. The
        word-boundary list marks with 1 every interior phoneme that is in
        ``ALL_YUNMU`` or is ``'<SP>'`` / ``'<AP>'``, plus always the last
        phoneme; the first phoneme is 0 unless it is also the last.
        """
        raw_data_dir = hparams['raw_data_dir']
        with open(os.path.join(raw_data_dir, 'meta.json')) as f:
            song_items = json.load(f)  # [list of dict]
        # Built once instead of concatenating the lists for every phoneme.
        boundary_phs = set(ALL_YUNMU + ['<SP>', '<AP>'])

        for song_item in song_items:
            item_name = song_item['item_name']
            singer, song_name, sent_id = item_name.split("#")
            self.item2wavfn[item_name] = f'{raw_data_dir}/{singer}#{song_name}/{sent_id}.wav'
            self.item2txt[item_name] = song_item['txt']

            phs = song_item['phs']
            self.item2ph[item_name] = ' '.join(phs)
            self.item2ph_durs[item_name] = song_item['ph_dur']

            self.item2midi[item_name] = song_item['notes']
            self.item2midi_dur[item_name] = song_item['notes_dur']
            self.item2is_slur[item_name] = song_item['is_slur']
            last_idx = len(phs) - 1
            self.item2wdb[item_name] = [
                1 if (0 < i < last_idx and p in boundary_phs) or i == last_idx else 0
                for i, p in enumerate(phs)
            ]
            self.item2spk[item_name] = singer

        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            # Fixed seed keeps the shuffled order reproducible across runs.
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @staticmethod
    def get_pitch(item_name, wav, spec, ph, res):
        """Store MIDI note features and ground-truth f0 of one item in ``res``.

        Writes the keys ``pitch_midi``, ``midi_dur``, ``is_slur``,
        ``word_boundary``, ``f0`` and ``pitch``.

        Args:
            item_name: Key into the class-level metadata tables.
            wav: Waveform passed to the module-level ``get_pitch``.
            spec: Spectrogram passed to the module-level ``get_pitch``.
            ph: Unused.
            res: Result dict updated in place.

        Raises:
            BinarizationError: If the extracted f0 sums to zero.
        """
        res['pitch_midi'] = np.asarray(M4SingerBinarizer.item2midi[item_name])
        res['midi_dur'] = np.asarray(M4SingerBinarizer.item2midi_dur[item_name])
        res['is_slur'] = np.asarray(M4SingerBinarizer.item2is_slur[item_name])
        res['word_boundary'] = np.asarray(M4SingerBinarizer.item2wdb[item_name])
        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)

        # NOTE: an earlier variant that loaded a precomputed ``*_f0.npy`` file
        # and resampled it to the spectrogram length used to live here as
        # commented-out code; f0 is now extracted from the waveform.

        # gt f0. Inside a method body the bare name ``get_pitch`` resolves to
        # the module-level import, not to this static method.
        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
        if sum(gt_f0) == 0:
            raise BinarizationError("Empty **gt** f0")
        res['f0'] = gt_f0
        res['pitch'] = gt_pitch_coarse

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract all features of one item.

        Pitch metadata is looked up by ``item_name`` and alignment uses the
        phoneme durations from ``meta.json`` (``tg_fn`` is unused).

        Returns:
            A dict of features, or ``None`` when the item is skipped because
            a ``BinarizationError`` was raised.
        """
        # The vocoder may be configured by registry key or by dotted path;
        # for the latter only the last path component is used as the key.
        vocoder_key = hparams['vocoder']
        if vocoder_key not in VOCODERS:
            vocoder_key = vocoder_key.split('.')[-1]
        wav, mel = VOCODERS[vocoder_key].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(item_name, wav, mel, ph, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    # Narrowed from a bare ``except`` so KeyboardInterrupt /
                    # SystemExit are no longer converted into a skipped item.
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
                if binarization_args['with_align']:
                    cls.get_align(M4SingerBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
391
+
392
if __name__ == "__main__":
    # Script entry point: binarize with the default SingingBinarizer, whose
    # data directory is taken from hparams['processed_data_dir'].
    binarizer = SingingBinarizer()
    binarizer.process()
data_gen/tts/base_binarizer.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["OMP_NUM_THREADS"] = "1"
3
+
4
+ from utils.multiprocess_utils import chunked_multiprocess_run
5
+ import random
6
+ import traceback
7
+ import json
8
+ from resemblyzer import VoiceEncoder
9
+ from tqdm import tqdm
10
+ from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder
11
+ from utils.hparams import set_hparams, hparams
12
+ import numpy as np
13
+ from utils.indexed_datasets import IndexedDatasetBuilder
14
+ from vocoders.base_vocoder import VOCODERS
15
+ import pandas as pd
16
+
17
+
18
class BinarizationError(Exception):
    """Raised when a single item cannot be binarized and has to be skipped."""
20
+
21
+
22
class BaseBinarizer:
    """Convert pre-processed TTS metadata into indexed binary datasets.

    Construction reads ``metadata_phone.csv`` from each processed data
    directory; :meth:`process` then writes the speaker map, the phone set and
    the ``valid`` / ``test`` / ``train`` splits under
    ``hparams['binary_data_dir']``.
    """

    def __init__(self, processed_data_dir=None):
        """Index every item listed in the metadata file(s).

        :param processed_data_dir: comma-separated directory list; falls back
            to ``hparams['processed_data_dir']`` when ``None``.
        """
        if processed_data_dir is None:
            processed_data_dir = hparams['processed_data_dir']
        self.processed_data_dirs = processed_data_dir.split(",")
        self.binarization_args = hparams['binarization_args']
        self.pre_align_args = hparams['pre_align_args']
        self.forced_align = self.pre_align_args['forced_align']
        # Sub-directory that holds the TextGrid files of the configured aligner.
        tg_dir = None
        if self.forced_align == 'mfa':
            tg_dir = 'mfa_outputs'
        if self.forced_align == 'kaldi':
            tg_dir = 'kaldi_outputs'
        self.item2txt = {}
        self.item2ph = {}
        self.item2wavfn = {}
        self.item2tgfn = {}
        self.item2spk = {}
        multi_dataset = len(self.processed_data_dirs) > 1
        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
            self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str)
            for _, r in self.meta_df.iterrows():
                item_name = raw_item_name = r['item_name']
                if multi_dataset:
                    # Prefix with the dataset index so names stay unique across datasets.
                    item_name = f'ds{ds_id}_{item_name}'
                self.item2txt[item_name] = r['txt']
                self.item2ph[item_name] = r['ph']
                # Only the second '_'-separated field of the wav file name is kept.
                self.item2wavfn[item_name] = os.path.join(
                    hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1])
                self.item2spk[item_name] = r.get('spk', 'SPK1')
                if multi_dataset:
                    self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
                if tg_dir is not None:
                    self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid"
        self.item_names = sorted(self.item2txt.keys())
        if self.binarization_args['shuffle']:
            # Fixed seed keeps the split reproducible (note: reseeds the global RNG).
            random.seed(1234)
            random.shuffle(self.item_names)

    @property
    def train_item_names(self):
        """Item names after the first ``test_num + valid_num`` entries."""
        return self.item_names[hparams['test_num']+hparams['valid_num']:]

    @property
    def valid_item_names(self):
        """First ``test_num + valid_num`` item names (this includes the test items)."""
        return self.item_names[0: hparams['test_num']+hparams['valid_num']]  #

    @property
    def test_item_names(self):
        """First ``test_num`` item names."""
        return self.item_names[0: hparams['test_num']]  # Audios for MOS testing are in 'test_ids'

    def build_spk_map(self):
        """Map every speaker name to an index assigned in sorted-name order."""
        spk_map = set()
        for item_name in self.item_names:
            spk_name = self.item2spk[item_name]
            spk_map.add(spk_name)
        spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
        assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
        return spk_map

    def item_name2spk_id(self, item_name):
        """Return the speaker index of ``item_name`` (requires ``self.spk_map``)."""
        return self.spk_map[self.item2spk[item_name]]

    def _phone_encoder(self):
        """Create or load ``phone_set.json`` and build the phone encoder from it."""
        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
        ph_set = []
        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
            for processed_data_dir in self.processed_data_dirs:
                with open(f'{processed_data_dir}/dict.txt') as dict_file:
                    # The first space-separated field of each line is the phone.
                    ph_set += [x.split(' ')[0] for x in dict_file.readlines()]
            ph_set = sorted(set(ph_set))
            # Close the file explicitly: build_phone_encoder() below re-reads it,
            # so the content must be flushed to disk first.
            with open(ph_set_fn, 'w') as ph_set_file:
                json.dump(ph_set, ph_set_file)
        else:
            with open(ph_set_fn, 'r') as ph_set_file:
                ph_set = json.load(ph_set_file)
        print("| phone set: ", ph_set)
        return build_phone_encoder(hparams['binary_data_dir'])

    def meta_data(self, prefix):
        """Yield ``(item_name, ph, txt, tg_fn, wav_fn, spk_id)`` for one split.

        :param prefix: ``'valid'``, ``'test'``; anything else selects the train split.
        """
        if prefix == 'valid':
            item_names = self.valid_item_names
        elif prefix == 'test':
            item_names = self.test_item_names
        else:
            item_names = self.train_item_names
        for item_name in item_names:
            ph = self.item2ph[item_name]
            txt = self.item2txt[item_name]
            tg_fn = self.item2tgfn.get(item_name)
            wav_fn = self.item2wavfn[item_name]
            spk_id = self.item_name2spk_id(item_name)
            yield item_name, ph, txt, tg_fn, wav_fn, spk_id

    def process(self):
        """Write the speaker map and phone set, then binarize all three splits."""
        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
        self.spk_map = self.build_spk_map()
        print("| spk_map: ", self.spk_map)
        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
        with open(spk_map_fn, 'w') as spk_map_file:
            json.dump(self.spk_map, spk_map_file)

        self.phone_encoder = self._phone_encoder()
        self.process_data('valid')
        self.process_data('test')
        self.process_data('train')

    def process_data(self, prefix):
        """Binarize one split and write its dataset, lengths and f0 statistics.

        :param prefix: split name, used as the file name stem of every output.
        """
        data_dir = hparams['binary_data_dir']
        args = []
        builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
        lengths = []
        f0s = []
        total_sec = 0
        if self.binarization_args['with_spk_embed']:
            voice_encoder = VoiceEncoder().cuda()

        meta_data = list(self.meta_data(prefix))
        for m in meta_data:
            args.append(list(m) + [self.phone_encoder, self.binarization_args])
        # os.cpu_count() may return None, and ``// 3`` yields 0 on machines with
        # fewer than three cores; fall back to a single worker in both cases.
        # An explicit N_PROC is honoured as given.
        default_workers = max(1, (os.cpu_count() or 1) // 3)
        num_workers = int(os.getenv('N_PROC', default_workers))
        for f_id, (_, item) in enumerate(
                zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))):
            if item is None:
                # process_item() returns None for items it had to skip.
                continue
            item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
                if self.binarization_args['with_spk_embed'] else None
            if not self.binarization_args['with_wav'] and 'wav' in item:
                del item['wav']
            builder.add_item(item)
            lengths.append(item['len'])
            total_sec += item['sec']
            if item.get('f0') is not None:
                f0s.append(item['f0'])
        builder.finalize()
        np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
        if len(f0s) > 0:
            f0s = np.concatenate(f0s, 0)
            # Statistics are taken over non-zero f0 values only.
            f0s = f0s[f0s != 0]
            np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
        print(f"| {prefix} total duration: {total_sec:.3f}s")

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Build the binarized record for a single utterance.

        :return: the record dict, or ``None`` when a ``BinarizationError`` makes
            the item unusable (the reason is printed).
        """
        if hparams['vocoder'] in VOCODERS:
            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
        else:
            # Retry with the last dotted component of the configured name.
            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav, mel, res)
                if binarization_args['with_f0cwt']:
                    # The CWT features are derived from the f0 stored just above.
                    cls.get_f0cwt(res['f0'], res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    # ``Exception`` (not a bare ``except``) so KeyboardInterrupt and
                    # SystemExit are not converted into a skipped item.
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
                if binarization_args['with_align']:
                    cls.get_align(tg_fn, ph, mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res

    @staticmethod
    def get_align(tg_fn, ph, mel, phone_encoded, res):
        """Store ``mel2ph`` and ``dur`` from the TextGrid at ``tg_fn`` into ``res``.

        :raises BinarizationError: if the TextGrid is missing or the alignment
            refers to more phonemes than ``phone_encoded`` contains.
        """
        if tg_fn is not None and os.path.exists(tg_fn):
            mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
        else:
            raise BinarizationError(f"Align not found")
        if mel2ph.max() - 1 >= len(phone_encoded):
            raise BinarizationError(
                f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
        res['mel2ph'] = mel2ph
        res['dur'] = dur

    @staticmethod
    def get_pitch(wav, mel, res):
        """Store ``f0`` and coarse ``pitch`` into ``res``.

        :raises BinarizationError: if the extracted f0 sums to zero.
        """
        f0, pitch_coarse = get_pitch(wav, mel, hparams)
        if sum(f0) == 0:
            raise BinarizationError("Empty f0")
        res['f0'] = f0
        res['pitch'] = pitch_coarse

    @staticmethod
    def get_f0cwt(f0, res):
        """Store the wavelet transform of the normalised log-f0 contour into ``res``.

        :raises BinarizationError: if the transform contains NaN.
        """
        from utils.cwt import get_cont_lf0, get_lf0_cwt
        uv, cont_lf0_lpf = get_cont_lf0(f0)
        logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
        cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
        Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
        if np.any(np.isnan(Wavelet_lf0)):
            raise BinarizationError("NaN CWT")
        res['cwt_spec'] = Wavelet_lf0
        res['cwt_scales'] = scales
        res['f0_mean'] = logf0s_mean_org
        res['f0_std'] = logf0s_std_org
220
+
221
+
222
# Script entry point: load hparams, then binarize with the base binarizer.
if __name__ == "__main__":
    set_hparams()
    binarizer = BaseBinarizer()
    binarizer.process()
data_gen/tts/bin/binarize.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["OMP_NUM_THREADS"] = "1"
4
+
5
+ import importlib
6
+ from utils.hparams import set_hparams, hparams
7
+
8
+
9
def binarize():
    """Instantiate the binarizer class named by ``hparams['binarizer_cls']`` and run it.

    The value is a dotted ``package.module.ClassName`` path; the default is
    ``data_gen.tts.base_binarizer.BaseBinarizer``.
    """
    dotted_path = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
    # Everything before the last dot is the module, the remainder the class name.
    pkg, _, cls_name = dotted_path.rpartition(".")
    binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
    print("| Binarizer: ", binarizer_cls)
    binarizer_cls().process()
16
+
17
+
18
# Script entry point: hparams must be loaded before the binarizer class is looked up.
if __name__ == '__main__':
    set_hparams()
    binarize()
data_gen/tts/binarizer_zh.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["OMP_NUM_THREADS"] = "1"
4
+
5
+ from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU
6
+ from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
7
+ from data_gen.tts.data_gen_utils import get_mel2ph
8
+ from utils.hparams import set_hparams, hparams
9
+ import numpy as np
10
+
11
+
12
class ZhBinarizer(BaseBinarizer):
    """Binarizer whose alignment step re-distributes durations between
    separators, initials (shengmu) and finals (yunmu)."""

    @staticmethod
    def get_align(tg_fn, ph, mel, phone_encoded, res):
        """Store adjusted ``mel2ph`` and ``dur`` into ``res``.

        Reads ``res['f0']``, so pitch must already have been extracted.

        :raises BinarizationError: if the TextGrid is missing or the alignment
            refers to more phonemes than ``phone_encoded`` contains.
        """
        if tg_fn is not None and os.path.exists(tg_fn):
            _, dur = get_mel2ph(tg_fn, ph, mel, hparams)
        else:
            raise BinarizationError(f"Align not found")
        ph_list = ph.split(" ")
        assert len(dur) == len(ph_list)
        mel2ph = []
        # Give the duration of separators to the neighbouring final (yunmu).
        # dur_cumsum is computed once from the unmodified durations and is not
        # refreshed while ``dur`` is edited below.
        dur_cumsum = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0)
        for i in range(len(dur)):
            p = ph_list[i]
            # Separator = first character is neither '<' nor a letter.
            if p[0] != '<' and not p[0].isalpha():
                uv_ = res['f0'][dur_cumsum[i]:dur_cumsum[i + 1]] == 0
                # j = number of leading frames of this segment with non-zero f0.
                j = 0
                while j < len(uv_) and not uv_[j]:
                    j += 1
                # NOTE(review): for i == 0, dur[i - 1] addresses the LAST entry
                # (negative indexing) — confirm this is intended.
                dur[i - 1] += j
                dur[i] -= j
                if dur[i] < 100:
                    dur[i - 1] += dur[i]
                    dur[i] = 0
        # Make each initial (shengmu) and its following final (yunmu) equally long.
        for i in range(len(dur)):
            p = ph_list[i]
            if p in ALL_SHENMU:
                p_next = ph_list[i + 1]
                if not (dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU):
                    print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, "
                          f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.")
                    continue
                total = dur[i + 1] + dur[i]
                dur[i] = total // 2
                dur[i + 1] = total - dur[i]
        # Expand durations into a per-frame, 1-based phoneme index.
        for i in range(len(dur)):
            mel2ph += [i + 1] * dur[i]
        mel2ph = np.array(mel2ph)
        if mel2ph.max() - 1 >= len(phone_encoded):
            raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone_encoded))}")
        res['mel2ph'] = mel2ph
        res['dur'] = dur
55
+
56
+
57
# Script entry point: load hparams, then binarize with the Chinese binarizer.
if __name__ == "__main__":
    set_hparams()
    binarizer = ZhBinarizer()
    binarizer.process()
data_gen/tts/data_gen_utils.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ warnings.filterwarnings("ignore")
4
+
5
+ import parselmouth
6
+ import os
7
+ import torch
8
+ from skimage.transform import resize
9
+ from utils.text_encoder import TokenTextEncoder
10
+ from utils.pitch_utils import f0_to_coarse
11
+ import struct
12
+ import webrtcvad
13
+ from scipy.ndimage.morphology import binary_dilation
14
+ import librosa
15
+ import numpy as np
16
+ from utils import audio
17
+ import pyloudnorm as pyln
18
+ import re
19
+ import json
20
+ from collections import OrderedDict
21
+
22
# Punctuation characters kept by the text processors.
PUNCS = '!,.?;:'

# Largest value of a signed 16-bit sample.
int16_max = 2 ** 15 - 1
25
+
26
+
27
def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
    """
    Load an audio file and remove (or just mask) long stretches without voice,
    as judged by WebRTC VAD on a 16 kHz copy of the signal.

    :param path: audio file to load with librosa
    :param sr: sample rate forwarded to ``librosa.core.load``
    :param return_raw_wav: if True, return the untrimmed waveform together with the mask
    :param norm: if True, loudness-normalise to -20 with a pyloudnorm meter and
        rescale when the peak exceeds 1.0
    :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have.
    :return: ``(wav, audio_mask, sr)``; ``wav`` is ``wav_raw[audio_mask]`` unless
        ``return_raw_wav`` is set, in which case it is the full waveform
    """

    ## Voice Activation Detection
    # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
    # This sets the granularity of the VAD. Should not need to be changed.
    sampling_rate = 16000
    wav_raw, sr = librosa.core.load(path, sr=sr)

    if norm:
        meter = pyln.Meter(sr)  # create BS.1770 meter
        loudness = meter.integrated_loudness(wav_raw)
        wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
        if np.abs(wav_raw).max() > 1.0:
            wav_raw = wav_raw / np.abs(wav_raw).max()

    # Keyword arguments: newer librosa releases no longer accept the sample
    # rates positionally, and the keywords are valid in older releases too.
    wav = librosa.resample(wav_raw, orig_sr=sr, target_sr=sampling_rate, res_type='kaiser_best')

    vad_window_length = 30  # In milliseconds
    # Number of frames to average together when performing the moving average smoothing.
    # The larger this value, the larger the VAD variations must be to not get smoothed out.
    vad_moving_average_width = 8

    # Compute the voice detection window size
    samples_per_window = (vad_window_length * sampling_rate) // 1000

    # Trim the end of the audio to have a multiple of the window size
    wav = wav[:len(wav) - (len(wav) % samples_per_window)]

    # Convert the float waveform to 16-bit mono PCM
    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))

    # Perform voice activation detection
    voice_flags = []
    vad = webrtcvad.Vad(mode=3)
    for window_start in range(0, len(wav), samples_per_window):
        window_end = window_start + samples_per_window
        # Two bytes per 16-bit sample, hence the ``* 2`` on both slice bounds.
        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
                                         sample_rate=sampling_rate))
    voice_flags = np.array(voice_flags)

    # Smooth the voice detection with a moving average
    def moving_average(array, width):
        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
        ret = np.cumsum(array_padded, dtype=float)
        ret[width:] = ret[width:] - ret[:-width]
        return ret[width - 1:] / width

    audio_mask = moving_average(voice_flags, vad_moving_average_width)
    # Builtin ``bool``: the ``np.bool`` alias no longer exists in current NumPy.
    audio_mask = np.round(audio_mask).astype(bool)

    # Dilate the voiced regions
    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
    audio_mask = np.repeat(audio_mask, samples_per_window)
    # Stretch the 16 kHz mask to the length of the waveform at its loaded rate.
    audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
    if return_raw_wav:
        return wav_raw, audio_mask, sr
    return wav_raw[audio_mask], audio_mask, sr
91
+
92
+
93
def process_utterance(wav_path,
                      fft_size=1024,
                      hop_size=256,
                      win_length=1024,
                      window="hann",
                      num_mels=80,
                      fmin=80,
                      fmax=7600,
                      eps=1e-6,
                      sample_rate=22050,
                      loud_norm=False,
                      min_level_db=-100,
                      return_linear=False,
                      trim_long_sil=False, vocoder='pwg'):
    """Compute a log-mel spectrogram (and optionally a linear one) for one utterance.

    :param wav_path: path of an audio file, or an already loaded waveform array
    :param fft_size: FFT size of the STFT
    :param hop_size: hop length of the STFT
    :param win_length: window length of the STFT
    :param window: window name passed to ``librosa.stft``
    :param num_mels: number of mel bins
    :param fmin: lowest mel frequency; ``-1`` means 0
    :param fmax: highest mel frequency; ``-1`` means ``sample_rate / 2``
    :param eps: floor applied before the logarithm
    :param sample_rate: sample rate used for loading and for the mel basis
    :param loud_norm: if True, loudness-normalise to -22 and rescale on clipping
    :param min_level_db: passed to ``audio.normalize`` for the linear spectrogram
    :param return_linear: if True, also return the linear spectrogram
    :param trim_long_sil: if True (and ``wav_path`` is a path), trim long silences first
    :param vocoder: only ``'pwg'`` is supported
    :return: ``(wav, mel)`` or ``(wav, mel, spc)``; ``mel`` is ``(n_mel_bins, T)`` and
        ``wav`` is padded/cut to ``T * hop_size`` samples
    :raises AssertionError: if ``vocoder`` is not ``'pwg'``
    """
    if isinstance(wav_path, str):
        if trim_long_sil:
            wav, _, _ = trim_long_silences(wav_path, sample_rate)
        else:
            wav, _ = librosa.core.load(wav_path, sr=sample_rate)
    else:
        wav = wav_path

    if loud_norm:
        meter = pyln.Meter(sample_rate)  # create BS.1770 meter
        loudness = meter.integrated_loudness(wav)
        wav = pyln.normalize.loudness(wav, loudness, -22.0)
        if np.abs(wav).max() > 1:
            wav = wav / np.abs(wav).max()

    # get amplitude spectrogram
    x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
                          win_length=win_length, window=window, pad_mode="constant")
    spc = np.abs(x_stft)  # (n_bins, T)

    # get mel basis
    fmin = 0 if fmin == -1 else fmin
    fmax = sample_rate / 2 if fmax == -1 else fmax
    # Keyword arguments: newer librosa releases reject the positional form,
    # and the keywords are valid in older releases too.
    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
    mel = mel_basis @ spc

    if vocoder == 'pwg':
        mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
    else:
        # Raised explicitly (same exception type as the former ``assert False``)
        # so the check also runs under ``python -O``.
        raise AssertionError(f'"{vocoder}" is not in ["pwg"].')

    l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
    # Keep exactly hop_size samples per spectrogram frame.
    wav = wav[:mel.shape[1] * hop_size]

    if not return_linear:
        return wav, mel
    else:
        spc = audio.amp_to_db(spc)
        spc = audio.normalize(spc, {'min_level_db': min_level_db})
        return wav, mel, spc
148
+
149
+
150
def get_pitch(wav_data, mel, hparams):
    """
    Extract f0 with parselmouth and force it to the length of ``mel``.

    :param wav_data: [T]
    :param mel: [T, 80]
    :param hparams: needs ``hop_size`` (128 or 256) and ``audio_sample_rate``
    :return: ``(f0, pitch_coarse)``, both with ``len(mel)`` entries
    :raises AssertionError: for an unsupported hop size, or when f0 and mel
        lengths differ by more than 8 frames
    """
    time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000  # in milliseconds
    f0_min = 80
    f0_max = 750

    if hparams['hop_size'] == 128:
        pad_size = 4
    elif hparams['hop_size'] == 256:
        pad_size = 2
    else:
        # Raised explicitly (same exception type as the former ``assert False``)
        # so the check also runs under ``python -O``.
        raise AssertionError(f"Unsupported hop_size: {hparams['hop_size']} (expected 128 or 256)")

    f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac(
        time_step=time_step / 1000, voicing_threshold=0.6,
        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
    lpad = pad_size * 2
    # mel and f0 are extracted by 2 different libraries. we should force them to have the same length.
    # Some library versions return more f0 frames than expected, which would make
    # the right padding negative; np.pad rejects negative widths, so clamp to 0
    # and let the truncation below remove the surplus frames.
    rpad = max(0, len(mel) - len(f0) - lpad)
    f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
    delta_l = len(mel) - len(f0)
    assert np.abs(delta_l) <= 8
    if delta_l > 0:
        f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
    f0 = f0[:len(mel)]
    pitch_coarse = f0_to_coarse(f0)
    return f0, pitch_coarse
185
+
186
+
187
def remove_empty_lines(text):
    """Strip every line and drop all lines that are empty afterwards.

    :param text: non-empty list of strings
    :return: new list of stripped, non-empty lines (order preserved)
    """
    assert (len(text) > 0)
    assert (isinstance(text, list))
    # ``list.remove("")`` would only delete the first empty entry; filter them all.
    return [line for line in (t.strip() for t in text) if line]
195
+
196
+
197
class TextGrid(object):
    """Line-by-line parser for TextGrid text; only IntervalTier tiers are supported."""

    def __init__(self, text):
        """Parse ``text`` (a list of lines) into header fields and ``tier_list``."""
        self.text = remove_empty_lines(text)
        self.line_count = 0  # index of the next line to be parsed
        self._get_type()
        self._get_time_intval()
        self._get_size()
        self.tier_list = []
        self._get_item_list()

    def _extract_pattern(self, pattern, inc):
        """
        Parameters
        ----------
        pattern : regex whose first group is matched against the current line
        inc : how many lines to advance after a successful match
        Returns
        -------
        group : text captured by the first group
        """
        matched = re.match(pattern, self.text[self.line_count])
        if matched is None:
            raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
        group = matched.group(1)
        self.line_count += inc
        return group

    def _get_type(self):
        # Advance by 2: the line following the file type is skipped.
        self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)

    def _get_time_intval(self):
        self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
        # Advance by 2: the line following xmax is skipped.
        self.xmax = self._extract_pattern(r"xmax = (.*)", 2)

    def _get_size(self):
        # Advance by 2: the line following the tier count is skipped.
        self.size = int(self._extract_pattern(r"size = (.*)", 2))

    def _read_interval(self):
        """Parse one interval entry (index, bounds, text) at the current line."""
        return OrderedDict([
            ("idx", self._extract_pattern(r"intervals \[(.*)\]", 1)),
            ("xmin", self._extract_pattern(r"xmin = (.*)", 1)),
            ("xmax", self._extract_pattern(r"xmax = (.*)", 1)),
            ("text", self._extract_pattern(r"text = \"(.*)\"", 1)),
        ])

    def _get_item_list(self):
        """Only supports IntervalTier currently"""
        for _ in range(self.size):
            tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
            tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
            if tier_class != "IntervalTier":
                raise NotImplementedError("Only IntervalTier class is supported currently")
            tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
            tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
            tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
            tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
            intervals = [self._read_interval() for _ in range(int(tier_size))]
            self.tier_list.append(OrderedDict([
                ("idx", tier_idx),
                ("class", tier_class),
                ("name", tier_name),
                ("xmin", tier_xmin),
                ("xmax", tier_xmax),
                ("size", tier_size),
                ("items", intervals),
            ]))

    def toJson(self):
        """Serialise the parsed header and tiers as an indented JSON string."""
        payload = OrderedDict([
            ("file_type", self.file_type),
            ("xmin", self.xmin),
            ("xmax", self.xmax),
            ("size", self.size),
            ("tiers", self.tier_list),
        ])
        return json.dumps(payload, ensure_ascii=False, indent=2)
272
+
273
+
274
def get_mel2ph(tg_fn, ph, mel, hparams):
    """Derive a frame-to-phoneme map and phoneme durations from a TextGrid.

    :param tg_fn: path of the TextGrid file; its last tier is used
    :param ph: space-separated phoneme string
    :param mel: spectrogram whose first dimension is the number of frames
    :param hparams: needs ``audio_sample_rate`` and ``hop_size``
    :return: ``(mel2ph, dur)``; ``mel2ph[t]`` is the 1-based phoneme index of
        frame ``t`` and ``dur[k]`` the number of frames given to phoneme ``k``
    :raises AssertionError: if the TextGrid intervals and ``ph`` cannot be matched
    """
    ph_list = ph.split(" ")
    with open(tg_fn, "r") as f:
        tg = f.readlines()
    tg = remove_empty_lines(tg)
    tg = TextGrid(tg)
    tg = json.loads(tg.toJson())
    # split[k] = start time (seconds) of phoneme k; -1 marks "not set yet".
    # Builtin ``float``: the ``np.float`` alias no longer exists in current NumPy.
    split = np.ones(len(ph_list) + 1, float) * -1
    tg_idx = 0
    ph_idx = 0
    tg_align = [x for x in tg['tiers'][-1]['items']]
    tg_align_ = []
    for x in tg_align:
        x['xmin'] = float(x['xmin'])
        x['xmax'] = float(x['xmax'])
        if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
            x['text'] = ''
            # Merge consecutive silence intervals into one.
            if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
                tg_align_[-1]['xmax'] = x['xmax']
                continue
        tg_align_.append(x)
    tg_align = tg_align_
    tg_len = len([x for x in tg_align if x['text'] != ''])
    ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
    assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn)
    # Walk both sequences in parallel, tolerating silences present in only one of them.
    while tg_idx < len(tg_align) or ph_idx < len(ph_list):
        if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
            # Trailing silence phoneme without an interval: starts "at infinity".
            split[ph_idx] = 1e8
            ph_idx += 1
            continue
        x = tg_align[tg_idx]
        if x['text'] == '' and ph_idx == len(ph_list):
            tg_idx += 1
            continue
        assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn)
        ph = ph_list[ph_idx]
        if x['text'] == '' and not is_sil_phoneme(ph):
            assert False, (ph_list, tg_align)
        if x['text'] != '' and is_sil_phoneme(ph):
            # Silence phoneme with no silence interval: its start is filled in below.
            ph_idx += 1
        else:
            assert (x['text'] == '' and is_sil_phoneme(ph)) \
                   or x['text'].lower() == ph.lower() \
                   or x['text'].lower() == 'sil', (x['text'], ph)
            split[ph_idx] = x['xmin']
            if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]):
                # Give the skipped silence phoneme a zero-length slot at this start time.
                split[ph_idx - 1] = split[ph_idx]
            ph_idx += 1
            tg_idx += 1
    assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
    assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn)
    # Builtin ``int``: the ``np.int`` alias no longer exists in current NumPy.
    mel2ph = np.zeros([mel.shape[0]], int)
    split[0] = 0
    split[-1] = 1e8
    for i in range(len(split) - 1):
        assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],)
    # Seconds -> frame index, rounded to nearest.
    split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split]
    for ph_idx in range(len(ph_list)):
        mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1
    mel2ph_torch = torch.from_numpy(mel2ph)
    T_t = len(ph_list)
    # Count frames per phoneme index; slot 0 (unassigned frames) is dropped.
    dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch))
    dur = dur[1:].numpy()
    return mel2ph, dur
338
+
339
+
340
def build_phone_encoder(data_dir):
    """Build a ``TokenTextEncoder`` from ``<data_dir>/phone_set.json``.

    :param data_dir: directory containing ``phone_set.json``
    :return: encoder created with the loaded phone list as vocabulary
    """
    phone_list_file = os.path.join(data_dir, 'phone_set.json')
    # Context manager so the file handle is closed deterministically.
    with open(phone_list_file) as f:
        phone_list = json.load(f)
    return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
344
+
345
+
346
def is_sil_phoneme(p):
    """Return True when the first character of ``p`` is not a letter."""
    first_char = p[0]
    if first_char.isalpha():
        return False
    return True
data_gen/tts/txt_processors/base_text_processor.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
class BaseTxtProcessor:
    """Interface shared by the language-specific text processors."""

    @staticmethod
    def sp_phonemes():
        """Return the list of special (non-phone) symbols."""
        special = ['|']
        return special

    @classmethod
    def process(cls, txt, pre_align_args):
        """Convert ``txt`` to phonemes; subclasses must override this."""
        raise NotImplementedError
data_gen/tts/txt_processors/en.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from data_gen.tts.data_gen_utils import PUNCS
3
+ from g2p_en import G2p
4
+ import unicodedata
5
+ from g2p_en.expand import normalize_numbers
6
+ from nltk import pos_tag
7
+ from nltk.tokenize import TweetTokenizer
8
+
9
+ from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
10
+
11
+
12
class EnG2p(G2p):
    """``G2p`` variant that tokenizes with ``TweetTokenizer``."""

    word_tokenize = TweetTokenizer().tokenize

    def __call__(self, text):
        """Return the pronunciation symbols of ``text`` with a ``" "`` between words."""
        # preprocessing: tokenize, then tag -> tuples of (word, tag)
        tagged_words = pos_tag(EnG2p.word_tokenize(text))

        prons = []
        for word, pos in tagged_words:
            if re.search("[a-z]", word) is None:
                # No lowercase letter at all: keep the token as it is.
                pron = [word]
            elif word in self.homograph2features:  # Check homograph
                pron1, pron2, pos1 = self.homograph2features[word]
                pron = pron1 if pos.startswith(pos1) else pron2
            elif word in self.cmu:  # lookup CMU dict
                pron = self.cmu[word][0]
            else:  # predict for oov
                pron = self.predict(word)

            prons.extend(pron)
            prons.append(" ")

        # Drop the separator appended after the final word.
        return prons[:-1]
41
+
42
+
43
class TxtProcessor(BaseTxtProcessor):
    """English text processor: normalises text and converts it to phonemes."""

    g2p = EnG2p()

    @staticmethod
    def preprocess_text(text):
        """Normalise ``text`` to lowercase letters, spaces and ``PUNCS`` characters."""
        text = normalize_numbers(text)
        # Strip accents: decompose, then drop the combining marks.
        decomposed = unicodedata.normalize('NFD', text)
        text = ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn')
        text = text.lower()
        text = re.sub("[\'\"()]+", "", text)
        text = re.sub("[-]+", " ", text)
        text = re.sub(f"[^ a-z{PUNCS}]", "", text)
        text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text)  # drop spaces around punctuation
        text = re.sub(f"([{PUNCS}])+", r"\1", text)  # !! -> !
        # The original applied this replacement twice; one pass gives the same result.
        text = text.replace("i.e.", "that is")
        text = text.replace("etc.", "etc")
        text = re.sub(f"([{PUNCS}])", r" \1 ", text)
        text = re.sub(rf"\s+", r" ", text)
        return text

    @classmethod
    def process(cls, txt, pre_align_args):
        """Return ``(phonemes, normalised_text)``; words are separated by ``'|'``."""
        txt = cls.preprocess_text(txt).strip()
        phs = []
        n_word_sep = 0
        for symbol in cls.g2p(txt):
            if symbol.strip() == '':
                # A blank symbol marks a word boundary.
                phs.append('|')
                n_word_sep += 1
            else:
                phs.extend(symbol.split(" "))
        assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"")
        return phs, txt
data_gen/tts/txt_processors/zh.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from pypinyin import pinyin, Style
3
+ from data_gen.tts.data_gen_utils import PUNCS
4
+ from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
5
+ from utils.text_norm import NSWNormalizer
6
+
7
+
8
class TxtProcessor(BaseTxtProcessor):
    """Chinese text processor based on pypinyin initials and finals."""

    # Maps full-width punctuation and digits to their ASCII counterparts.
    table = str.maketrans(
        ':,。!?【】()%#@&1234567890',
        ':,.!?[]()%#@&1234567890')

    @staticmethod
    def preprocess_text(text):
        """Normalise ``text`` and remove all whitespace."""
        text = text.translate(TxtProcessor.table)
        text = NSWNormalizer(text).normalize(remove_punc=False)
        # Applied in order; the patterns and replacements are unchanged.
        substitutions = (
            ("[\'\"()]+", ""),
            ("[-]+", " "),
            (f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", ""),
            (f"([{PUNCS}])+", r"\1"),  # !! -> !
            (f"([{PUNCS}])", r" \1 "),
            (rf"\s+", r""),
        )
        for pattern, replacement in substitutions:
            text = re.sub(pattern, replacement, text)
        return text

    @classmethod
    def process(cls, txt, pre_align_args):
        """Return ``(phonemes, normalised_text)``; syllables are separated by ``'|'``."""
        txt = cls.preprocess_text(txt)
        shengmu = pinyin(txt, style=Style.INITIALS)  # https://blog.csdn.net/zhoulei124/article/details/89055403
        yunmu_finals = pinyin(txt, style=Style.FINALS)
        yunmu_tone3 = pinyin(txt, style=Style.FINALS_TONE3)
        if pre_align_args['use_tone']:
            # When the toned form equals the plain final (no tone digit), append '5'.
            yunmu = []
            for plain, toned in zip(yunmu_finals, yunmu_tone3):
                yunmu.append([toned[0] + '5'] if toned[0] == plain[0] else toned)
        else:
            yunmu = yunmu_finals

        assert len(shengmu) == len(yunmu)
        phs = ["|"]
        for initial, final, plain_final in zip(shengmu, yunmu, yunmu_finals):
            if initial[0] == plain_final[0]:
                # Initial and final coincide: emit the symbol only once.
                phs.extend([initial[0], "|"])
            else:
                phs.extend([initial[0], final[0], "|"])
        return phs, txt
data_gen/tts/txt_processors/zh_g2pM.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import jieba
3
+ from pypinyin import pinyin, Style
4
+ from data_gen.tts.data_gen_utils import PUNCS
5
+ from data_gen.tts.txt_processors import zh
6
+ from g2pM import G2pM
7
+
8
# Pinyin initials (shengmu).
ALL_SHENMU = 'b c ch d f g h j k l m n p q r s sh t x z zh'.split(' ')
# Pinyin finals (yunmu).
ALL_YUNMU = (
    'a ai an ang ao e ei en eng er i ia ian iang iao '
    'ie in ing iong iou o ong ou u ua uai uan uang uei '
    'uen uo v van ve vn'
).split(' ')
12
+
13
+
14
class TxtProcessor(zh.TxtProcessor):
    """Chinese text processor using g2pM for pinyin and jieba for word boundaries."""

    model = G2pM()

    @staticmethod
    def sp_phonemes():
        """Return the special symbols: syllable separator and word boundary."""
        return ['|', '#']

    @classmethod
    def process(cls, txt, pre_align_args):
        """Return ``(phonemes, normalised_text)``.

        Syllables are split into initial + final; ``'#'`` marks a word boundary
        and ``'|'`` a syllable boundary inside a word.

        :param txt: raw text
        :param pre_align_args: dict providing the ``use_tone`` switch
        """
        txt = cls.preprocess_text(txt)
        ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True)
        seg_list = '#'.join(jieba.cut(txt))
        assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list)

        # Try the longest initials first. In list order 'c', 's' and 'z' come
        # before 'ch', 'sh' and 'zh' and would win the prefix test, splitting
        # e.g. "zhe4" into 'z' + 'he4'. sorted() is stable, so initials of equal
        # length keep their original order.
        initials = sorted(ALL_SHENMU, key=len, reverse=True)

        # Insert the word boundary marker '#'.
        ph_list_ = []
        seg_idx = 0
        for p in ph_list:
            p = p.replace("u:", "v")
            if seg_list[seg_idx] == '#':
                ph_list_.append('#')
                seg_idx += 1
            else:
                ph_list_.append("|")
            seg_idx += 1
            if re.findall('[\u4e00-\u9fff]', p):
                # Still a Chinese character: convert it with pypinyin.
                if pre_align_args['use_tone']:
                    p = pinyin(p, style=Style.TONE3, strict=True)[0][0]
                    if p[-1] not in ['1', '2', '3', '4', '5']:
                        p = p + '5'
                else:
                    p = pinyin(p, style=Style.NORMAL, strict=True)[0][0]

            finished = False
            if len(p) > 1:
                for shenmu in initials:
                    if not p.startswith(shenmu):
                        continue
                    # Slice off exactly the prefix; str.lstrip() removes a
                    # character set, not a prefix.
                    final = p[len(shenmu):]
                    # Do not split when nothing, or only a tone digit, remains.
                    if final and not final.isnumeric():
                        ph_list_ += [shenmu, final]
                        finished = True
                        break
            if not finished:
                ph_list_.append(p)

        ph_list = ph_list_

        # Remove word-boundary markers next to silence symbols: [..., '#', ',', '#', ...]
        sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes()
        ph_list_ = []
        for i in range(0, len(ph_list), 1):
            if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes):
                ph_list_.append(ph_list[i])
        ph_list = ph_list_
        return ph_list, txt
67
+
68
+
69
# Script entry point: print the phonemes of a sample sentence.
if __name__ == '__main__':
    demo_phs, demo_txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True})
    print(demo_phs)
inference/m4singer/base_svs_infer.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import numpy as np
5
+ from modules.hifigan.hifigan import HifiGanGenerator
6
+ from vocoders.hifigan import HifiGAN
7
+ from inference.m4singer.m4singer.map import m4singer_pinyin2ph_func
8
+
9
+ from utils import load_ckpt
10
+ from utils.hparams import set_hparams, hparams
11
+ from utils.text_encoder import TokenTextEncoder
12
+ from pypinyin import pinyin, lazy_pinyin, Style
13
+ import librosa
14
+ import glob
15
+ import re
16
+
17
+
18
class BaseSVSInfer:
    """Base class for singing-voice-synthesis inference.

    Subclasses provide the acoustic model through ``build_model`` and
    ``forward_model``; this class parses the user input, builds the model batch
    and owns the HifiGAN vocoder.
    """

    def __init__(self, hparams, device=None):
        """Create the phoneme encoder, speaker map, acoustic model and vocoder.

        :param hparams: hyper-parameter mapping, stored as ``self.hparams``
        :param device: torch device; defaults to ``'cuda'`` when available, else ``'cpu'``
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device

        phone_list = ["<AP>", "<SP>", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g", "h",
                      "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iou", "j", "k", "l", "m", "n", "o", "ong", "ou",
                      "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "uei", "uen", "uo", "v", "van", "ve", "vn",
                      "x", "z", "zh"]
        self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
        self.pinyin2phs = m4singer_pinyin2ph_func()
        self.spk_map = {"Alto-1": 0, "Alto-2": 1, "Alto-3": 2, "Alto-4": 3, "Alto-5": 4, "Alto-6": 5, "Alto-7": 6, "Bass-1": 7,
                        "Bass-2": 8, "Bass-3": 9, "Soprano-1": 10, "Soprano-2": 11, "Soprano-3": 12, "Tenor-1": 13, "Tenor-2": 14,
                        "Tenor-3": 15, "Tenor-4": 16, "Tenor-5": 17, "Tenor-6": 18, "Tenor-7": 19}

        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        self.vocoder = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)

    def build_model(self):
        """Return the acoustic model; must be implemented by subclasses."""
        raise NotImplementedError

    def forward_model(self, inp):
        """Run the acoustic model on a preprocessed item; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def _latest_ckpt(base_dir):
        """Return the ``model_ckpt_steps_<N>.ckpt`` file in ``base_dir`` with the largest N."""
        # Match on the basename only, so characters in base_dir are never
        # interpreted as regex syntax.
        step_pattern = re.compile(r'model_ckpt_steps_(\d+)\.ckpt$')
        candidates = []
        for path in glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'):
            match = step_pattern.search(os.path.basename(path))
            if match:
                candidates.append((int(match.group(1)), path))
        if not candidates:
            raise FileNotFoundError(f'No model_ckpt_steps_*.ckpt checkpoint found in {base_dir!r}')
        return max(candidates, key=lambda item: item[0])[1]

    def build_vocoder(self):
        """Load the newest HifiGAN generator checkpoint from ``hparams['vocoder_ckpt']``.

        :return: the generator with weight norm removed, in eval mode on ``self.device``
        :raises FileNotFoundError: if the directory holds no matching checkpoint
        """
        base_dir = hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        ckpt = self._latest_ckpt(base_dir)
        print('| load HifiGAN: ', ckpt)
        ckpt_dict = torch.load(ckpt, map_location="cpu")
        config = set_hparams(config_path, global_hparams=False)
        state = ckpt_dict["state_dict"]["model_gen"]
        vocoder = HifiGanGenerator(config)
        vocoder.load_state_dict(state, strict=True)
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(self.device)
        return vocoder

    def run_vocoder(self, c, **kwargs):
        """Convert a mel batch to a waveform.

        :param c: mel tensor; dimensions 1 and 2 are swapped before vocoding
        :param kwargs: optional ``f0``; it is passed to the vocoder only when
            ``hparams['use_nsf']`` is truthy
        :return: the flattened waveform with a leading dimension of size 1
        """
        c = c.transpose(2, 1)  # [B, 80, T]
        f0 = kwargs.get('f0')  # [B, T]
        if f0 is not None and hparams.get('use_nsf'):
            y = self.vocoder(c, f0).view(-1)
        else:
            y = self.vocoder(c).view(-1)
        # [T]
        return y[None]

    def preprocess_word_level_input(self, inp):
        """Expand word-level lyrics and notes into aligned phoneme-level lists.

        Reads ``inp['text']``, ``inp['notes']`` and ``inp['notes_duration']``; the
        latter two use ``|`` to separate the window belonging to each word.

        :return: ``(ph_seq, note_lst, midi_dur_lst, is_slur)`` or ``None`` when the
            word, note and duration counts do not line up
        """
        # Pypinyin can't solve polyphonic words
        text_raw = inp['text']

        # lyric: syllables missing from the pinyin table are dropped here
        syllables = lazy_pinyin(text_raw, strict=False)
        ph_per_word_lst = [self.pinyin2phs[syl.strip()] for syl in syllables if syl.strip() in self.pinyin2phs]

        # Note
        note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
        mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']

        if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words does\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
            print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
            return None

        note_lst = []
        ph_lst = []
        midi_dur_lst = []
        is_slur = []
        for word_idx, ph_per_word in enumerate(ph_per_word_lst):
            # for phs in one word:
            # single ph like ['ai'] or multiple phs like ['n', 'i']
            ph_in_this_word = ph_per_word.split()

            # for notes in one word:
            # single note like ['D4'] or multiple notes like ['D4', 'E4'] which means a 'slur' here.
            note_in_this_word = note_per_word_lst[word_idx].split()
            midi_dur_in_this_word = mididur_per_word_lst[word_idx].split()
            # Every note needs its own duration; without this check the slur
            # branch below would fail with an IndexError.
            if len(note_in_this_word) != len(midi_dur_in_this_word):
                print(f'Word {word_idx}: {len(note_in_this_word)} note(s) but '
                      f'{len(midi_dur_in_this_word)} duration(s).')
                return None

            # Step 1.
            # Deal with note of 'not slur' case or the first note of 'slur' case
            # j        ie
            # F#4/Gb4  F#4/Gb4
            # 0        0
            for ph in ph_in_this_word:
                ph_lst.append(ph)
                note_lst.append(note_in_this_word[0])
                midi_dur_lst.append(midi_dur_in_this_word[0])
                is_slur.append(0)
            # Step 2.
            # Deal with the 2nd, 3rd... notes of 'slur' case
            # j        ie        ie
            # F#4/Gb4  F#4/Gb4   C#4/Db4
            # 0        0         1
            # is_slur = True: repeat the YUNMU to match the 2nd, 3rd... notes.
            for note_idx in range(1, len(note_in_this_word)):
                ph_lst.append(ph_in_this_word[-1])
                note_lst.append(note_in_this_word[note_idx])
                midi_dur_lst.append(midi_dur_in_this_word[note_idx])
                is_slur.append(1)
        ph_seq = ' '.join(ph_lst)

        # The three lists are appended in lockstep above, so their lengths always agree.
        print(len(ph_lst), len(note_lst), len(midi_dur_lst))
        print('Pass word-notes check.')
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_phoneme_level_input(self, inp):
        """Parse already phoneme-aligned input.

        Reads the space-separated ``inp['ph_seq']``, ``inp['note_seq']``,
        ``inp['note_dur_seq']`` and ``inp['is_slur_seq']``.

        :return: ``(ph_seq, note_lst, midi_dur_lst, is_slur)`` or ``None`` when the
            four sequences do not have the same length
        """
        ph_seq = inp['ph_seq']
        note_lst = inp['note_seq'].split()
        midi_dur_lst = inp['note_dur_seq'].split()
        is_slur = [float(x) for x in inp['is_slur_seq'].split()]
        print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
        if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst) == len(is_slur):
            print('Pass word-notes check.')
        else:
            print('The number of words does\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_input(self, inp, input_type='word'):
        """Turn a raw request into a model-ready item.

        :param inp: dict with ``'text'`` plus the keys required by the selected
            input type; ``'item_name'`` and ``'spk_name'`` are optional and
            default to ``'<ITEM_NAME>'`` and ``'Alto-1'``
        :param input_type: ``'word'`` or ``'phoneme'``
        :return: the item dict, or ``None`` if the input could not be parsed
        """
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'Alto-1')

        # single spk
        spk_id = self.spk_map[spk_name]

        # get ph seq, note lst, midi dur lst, is slur lst.
        if input_type == 'word':
            ret = self.preprocess_word_level_input(inp)
        elif input_type == 'phoneme':
            ret = self.preprocess_phoneme_level_input(inp)
        else:
            print('Invalid input type.')
            return None

        if ret:
            ph_seq, note_lst, midi_dur_lst, is_slur = ret
        else:
            print('==========> Preprocess_word_level or phone_level input wrong.')
            return None

        # convert note lst to midi id; convert note dur lst to midi duration
        try:
            # For "X/Y" note names only the part before the slash is used; 'rest' maps to 0.
            midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
                     for x in note_lst]
            midi_dur_lst = [float(x) for x in midi_dur_lst]
        except Exception as e:
            print(e)
            print('Invalid Input Type.')
            return None

        ph_token = self.ph_encoder.encode(ph_seq)
        item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
                'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
                'is_slur': np.asarray(is_slur), }
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        """Wrap a single preprocessed item into a batch of size one on ``self.device``."""
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        spk_ids = torch.LongTensor([item['spk_id']])[:].to(self.device)

        max_frames = hparams['max_frames']
        pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :max_frames].to(self.device)
        midi_dur = torch.FloatTensor(item['midi_dur'])[None, :max_frames].to(self.device)
        # The phoneme-level path yields float slur flags; cast explicitly so the
        # LongTensor constructor never has to deal with a float array.
        slur_flags = np.asarray(item['is_slur']).astype(np.int64)
        is_slur = torch.LongTensor(slur_flags)[None, :max_frames].to(self.device)

        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
            'pitch_midi': pitch_midi,
            'midi_dur': midi_dur,
            'is_slur': is_slur
        }
        return batch

    def postprocess_output(self, output):
        """Hook for subclasses; returns ``output`` unchanged."""
        return output

    def infer_once(self, inp):
        """Preprocess ``inp``, run the model and return the post-processed output.

        ``inp['input_type']`` selects word- or phoneme-level parsing and defaults
        to ``'word'`` when missing or empty.

        :raises ValueError: if the input could not be preprocessed
        """
        item = self.preprocess_input(inp, input_type=inp.get('input_type') or 'word')
        if item is None:
            # Fail here with a clear message rather than deep inside the model code.
            raise ValueError('Invalid input: preprocessing failed, see the messages printed above.')
        output = self.forward_model(item)
        output = self.postprocess_output(output)
        return output

    @classmethod
    def example_run(cls, inp):
        """Synthesise ``inp`` once and save the result under ``infer_out/``."""
        from utils.audio import save_wav
        set_hparams(print_hparams=False)
        infer_ins = cls(hparams)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        f_name = inp['spk_name'] + ' | ' + inp['text']
        save_wav(out, f'infer_out/{f_name}.wav', hparams['audio_sample_rate'])
inference/m4singer/ds_e2e.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # from inference.tts.fs import FastSpeechInfer
3
+ # from modules.tts.fs2_orig import FastSpeech2Orig
4
+ from inference.m4singer.base_svs_infer import BaseSVSInfer
5
+ from utils import load_ckpt
6
+ from utils.hparams import hparams
7
+ from usr.diff.shallow_diffusion_tts import GaussianDiffusion
8
+ from usr.diffsinger_task import DIFF_DECODERS
9
+ from modules.fastspeech.pe import PitchExtractor
10
+ import utils
11
+
12
+
13
class DiffSingerE2EInfer(BaseSVSInfer):
    """End-to-end DiffSinger inference: Gaussian-diffusion acoustic model plus vocoder."""

    def build_model(self):
        """Build the diffusion model, load its checkpoint and, if enabled, the pitch extractor."""
        denoiser = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
        model = GaussianDiffusion(
            phone_encoder=self.ph_encoder,
            out_dims=hparams['audio_num_mel_bins'],
            denoise_fn=denoiser,
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')

        # The pitch extractor is optional and only loaded when 'pe_enable' is set.
        if hparams.get('pe_enable'):
            self.pe = PitchExtractor().to(self.device)
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()
        return model

    def forward_model(self, inp):
        """Synthesise one preprocessed item and return its waveform as a NumPy array."""
        batch = self.input_to_batch(inp)
        tokens = batch['txt_tokens']  # [B, T_t]
        speaker = batch.get('spk_ids')
        with torch.no_grad():
            output = self.model(tokens, spk_embed=speaker, ref_mels=None, infer=True,
                                pitch_midi=batch['pitch_midi'], midi_dur=batch['midi_dur'],
                                is_slur=batch['is_slur'])
            mel_out = output['mel_out']  # [B, T,80]
            if hparams.get('pe_enable'):
                # pe predict from Pred mel
                f0_pred = self.pe(mel_out)['f0_denorm_pred']
            else:
                f0_pred = output['f0_denorm']
            wav_out = self.run_vocoder(mel_out, f0=f0_pred)
        return wav_out.cpu().numpy()[0]
48
+
49
if __name__ == '__main__':
    # Word-level example: one '|'-separated note/duration window per word of the lyrics.
    inp = {
        'spk_name': 'Tenor-1',
        'text': 'AP你要相信AP相信我们会像童话故事里AP',
        'notes': 'rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest',
        'notes_duration': '0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14',
        'input_type': 'word',
    }

    # Phoneme-level example with explicit slur flags. It is not used below;
    # pass it to example_run instead of `inp` to exercise the phoneme-level path.
    c = {
        'spk_name': 'Tenor-1',
        'text': '你要相信相信我们会像童话故事里',
        'ph_seq': '<AP> n i iao iao x iang x in in <AP> x iang iang x in uo uo m en h uei x iang t ong ong h ua g u u sh i l i <AP>',
        'note_seq': 'rest G#3 G#3 A#3 C4 D#4 D#4 D#4 D#4 F4 rest E4 E4 F4 F4 F4 D#4 A#3 A#3 A#3 A#3 A#3 C#4 C#4 B3 B3 C4 C#4 C#4 B3 B3 C4 A#3 A#3 G#3 G#3 rest',
        'note_dur_seq': '0.14 0.47 0.47 0.1905 0.1895 0.41 0.41 0.3005 0.3005 0.3895 0.21 0.2391 0.2391 0.1809 0.32 0.32 0.4105 0.2095 0.35 0.35 0.43 0.43 0.45 0.45 0.2309 0.2309 0.2291 0.48 0.48 0.225 0.225 0.195 0.29 0.29 0.71 0.71 0.14',
        'is_slur_seq': '0 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0',
        'input_type': 'phoneme'
    }
    DiffSingerE2EInfer.example_run(inp)
inference/m4singer/gradio/gradio_settings.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ title: 'M4Singer'
2
+ description: |
3
+ This page aims to display the SVS function of M4Singer. SingerID can be switched freely to preview the timbre of each singer. Click examples below to quickly load scores and audio. (本页面为M4Singer歌声合成功能展示。SingerID可以自由切换用以预览各歌手的音色。点击下方Examples可以快速加载乐谱和音频。)
4
+
5
+ Please assign pitch and duration values to each Chinese character. The corresponding pitch and duration value of each character should be separated by a | separator. It is necessary to ensure that the note window separated by the separator is consistent with the number of Chinese characters (AP or SP is also viewed as a Chinese character). (请给每个汉字分配音高和时值, 每个字对应的音高和时值需要用|分隔符隔开。需要保证分隔符分割出来的音符窗口与汉字个数(AP或SP也算一个汉字)一致。)
6
+
7
+ Note: This page is running on CPU, please refer to github for the local running solutions.(注意:本页面使用CPU推理,本地运行方案请参考<a href='https://github.com/M4Singer/M4Singer' style='color:blue;' target='_blank'>Github REPO</a>。)
8
+ article: |
9
+ Link to <a href='https://github.com/M4Singer/M4Singer' style='color:blue;' target='_blank'>Github REPO</a>
10
+
11
+ example_inputs:
12
+ - |-
13
+ Tenor-1<sep>AP你要相信AP相信我们会像童话故事里AP<sep>rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest<sep>0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14
14
+ - |-
15
+ Tenor-1<sep>AP因为在一千年以后AP世界早已没有我AP<sep>rest | C#4 | D4 | E4 | F#4 | E4 | D4 G#3 | A3 | D4 E4 | rest | F#4 | E4 | D4 | C#4 | B3 F#3 | F#3 | C4 C#4 | rest<sep>0.18 | 0.32 | 0.38 | 0.81 | 0.38 | 0.39 | 0.3155 0.2045 | 0.28 | 0.4609 1.0291 | 0.27 | 0.42 | 0.15 | 0.53 | 0.22 | 0.3059 0.2841 | 0.4 | 0.2909 1.1091 | 0.3
16
+ - |-
17
+ Tenor-2<sep>AP可是你在敲打AP我的窗棂AP<sep>rest | G#3 | B3 | B3 C#4 | E4 | C#4 B3 | G#3 | rest | C3 | E3 | B3 G#3 | F#3 | rest<sep>0.2 | 0.38 | 0.48 | 0.41 0.72 | 0.39 | 0.5195 0.2905 | 0.5 | 0.33 | 0.4 | 0.31 | 0.565 0.265 | 1.15 | 0.24
18
+ - |-
19
+ Tenor-2<sep>SP一杯敬朝阳一杯敬月光AP<sep>rest | G#3 | G#3 | G#3 | G3 | G3 G#3 | G3 | C4 | C4 | A#3 | C4 | rest<sep>0.33 | 0.26 | 0.23 | 0.27 | 0.36 | 0.3159 0.4041 | 0.54 | 0.21 | 0.32 | 0.24 | 0.58 | 0.17
20
+ - |-
21
+ Soprano-1<sep>SP乱石穿空AP惊涛拍岸AP<sep>rest | C#5 | D#5 | F5 D#5 | C#5 | rest | C#5 | C#5 | C#5 G#4 | G#4 | rest<sep>0.325 | 0.75 | 0.54 | 0.48 0.55 | 1.38 | 0.31 | 0.55 | 0.48 | 0.4891 0.4709 | 1.15 | 0.22
22
+ - |-
23
+ Soprano-1<sep>AP点点滴滴染绿了村寨AP<sep>rest | C5 | A#4 | C5 | D#5 F5 D#5 | D#5 | C5 | C5 | C5 | A#4 | rest<sep>0.175 | 0.24 | 0.26 | 1.08 | 0.3541 0.4364 0.2195 | 0.47 | 0.27 | 0.12 | 0.51 | 0.72 | 0.15
24
+ - |-
25
+ Alto-2<sep>AP拒绝声色的张扬AP不拒绝你AP<sep>rest | C4 | C4 | C4 | B3 A3 | C4 | C4 D4 | D4 | rest | D4 | D4 | C4 | G4 E4 | rest<sep>0.49 | 0.31 | 0.18 | 0.48 | 0.3 0.4 | 0.25 | 0.3591 0.2409 | 0.46 | 0.34 | 0.4 | 0.45 | 0.45 | 2.4545 0.9855 | 0.215
26
+ - |-
27
+ Alto-2<sep>AP半醒着AP笑着哭着都快活AP<sep>rest | D4 | B3 | C4 D4 | rest | E4 | D4 | E4 | D4 | E4 | E4 F#4 | F4 F#4 | rest<sep>0.165 | 0.45 | 0.53 | 0.3859 0.2441 | 0.35 | 0.38 | 0.17 | 0.32 | 0.26 | 0.33 | 0.38 0.21 | 0.3309 0.9491 | 0.125
28
+
29
+
30
+ inference_cls: inference.m4singer.ds_e2e.DiffSingerE2EInfer
31
+ exp_name: m4singer_diff_e2e
inference/m4singer/gradio/infer.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import re
3
+
4
+ import gradio as gr
5
+ import yaml
6
+ from gradio.inputs import Textbox, Dropdown
7
+
8
+ from inference.m4singer.base_svs_infer import BaseSVSInfer
9
+ from utils.hparams import set_hparams
10
+ from utils.hparams import hparams as hp
11
+ import numpy as np
12
+ import torch, random
13
+
14
def setup_seed(seed):
    """Seed the PyTorch (CPU and all CUDA devices), NumPy and Python RNGs with ``seed``.

    Also switches cuDNN to deterministic mode.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
20
+
21
class GradioInfer:
    """Gradio front end around an SVS inference class."""

    def __init__(self, exp_name, inference_cls, title, description, article, example_inputs):
        """Store the UI settings and resolve the inference class.

        :param exp_name: experiment name; ``run`` loads ``checkpoints/<exp_name>/config.yaml``
        :param inference_cls: dotted path ``package.module.ClassName`` of the inference class
        :param title: interface title
        :param description: interface description
        :param article: interface article text
        :param example_inputs: list of ``<sep>``-joined strings
            ``singer<sep>text<sep>notes<sep>durations``
        """
        self.exp_name = exp_name
        self.title = title
        self.description = description
        self.article = article
        self.example_inputs = example_inputs
        pkg = ".".join(inference_cls.split(".")[:-1])
        cls_name = inference_cls.split(".")[-1]
        self.inference_cls = getattr(importlib.import_module(pkg), cls_name)

    def greet(self, singer, text, notes, notes_duration):
        """Synthesise ``text`` sentence by sentence and return the concatenated audio.

        Text, notes and durations are each split on the characters in ``PUNCS``.
        Sentences are accumulated until the pending text reaches 400 characters or
        the input is exhausted; every synthesised chunk is followed by
        ``0.3 * audio_sample_rate`` zero samples.

        :return: ``(sample_rate, int16 waveform)``
        """
        PUNCS = '。?;:'
        sents = re.split(rf'([{PUNCS}])', text.replace('\n', ','))
        sents_notes = re.split(rf'([{PUNCS}])', notes.replace('\n', ','))
        sents_notes_dur = re.split(rf'([{PUNCS}])', notes_duration.replace('\n', ','))

        # Pad so that every sentence at an even index has a separator slot at index + 1.
        if sents[-1] not in list(PUNCS):
            sents = sents + ['']
            sents_notes = sents_notes + ['']
            sents_notes_dur = sents_notes_dur + ['']

        audio_outs = []
        s, n, n_dur = "", "", ""
        for i in range(0, len(sents), 2):
            if len(sents[i]) > 0:
                s += sents[i] + sents[i + 1]
                n += sents_notes[i] + sents_notes[i + 1]
                n_dur += sents_notes_dur[i] + sents_notes_dur[i + 1]
            if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0):
                audio_out = self.infer_ins.infer_once({
                    'spk_name': singer,
                    'text': s,
                    'notes': n,
                    'notes_duration': n_dur,
                })
                # Scale to the int16 range expected by the audio output.
                audio_out = audio_out * 32767
                audio_out = audio_out.astype(np.int16)
                audio_outs.append(audio_out)
                audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16))
                # Reset all three accumulators. Leaving n_dur untouched would
                # prepend this chunk's durations to the next chunk.
                s, n, n_dur = "", "", ""
        audio_outs = np.concatenate(audio_outs)
        return hp['audio_sample_rate'], audio_outs

    def run(self):
        """Load hparams, instantiate the inference class and launch the Gradio interface."""
        set_hparams(config=f'checkpoints/{self.exp_name}/config.yaml', exp_name=self.exp_name, print_hparams=False)
        infer_cls = self.inference_cls
        self.infer_ins: BaseSVSInfer = infer_cls(hp)

        # Build a fresh list rather than overwriting self.example_inputs, so the
        # raw strings survive and run() can be called again.
        example_inputs = []
        for raw_example in self.example_inputs:
            singer, text, notes, notes_dur = raw_example.split('<sep>')
            example_inputs.append([singer, text, notes, notes_dur])

        singerList = \
            [
                'Tenor-1', 'Tenor-2', 'Tenor-3', 'Tenor-4', 'Tenor-5', 'Tenor-6', 'Tenor-7',
                'Alto-1', 'Alto-2', 'Alto-3', 'Alto-4', 'Alto-5', 'Alto-6', 'Alto-7',
                'Soprano-1', 'Soprano-2', 'Soprano-3',
                'Bass-1', 'Bass-2', 'Bass-3',
            ]

        iface = gr.Interface(fn=self.greet,
                             inputs=[
                                 Dropdown(choices=singerList, default=example_inputs[0][0], label="SingerID"),
                                 Textbox(lines=2, placeholder=None, default=example_inputs[0][1], label="input text"),
                                 Textbox(lines=2, placeholder=None, default=example_inputs[0][2], label="input note"),
                                 Textbox(lines=2, placeholder=None, default=example_inputs[0][3], label="input duration")]
                             ,
                             outputs="audio",
                             allow_flagging="never",
                             title=self.title,
                             description=self.description,
                             article=self.article,
                             examples=example_inputs,
                             enable_queue=True)
        iface.launch(share=True)
99
+
100
if __name__ == '__main__':
    # Read the settings inside a context manager so the file handle is closed,
    # and decode as UTF-8 explicitly because the settings contain Chinese text.
    with open('inference/m4singer/gradio/gradio_settings.yaml', encoding='utf-8') as settings_file:
        gradio_config = yaml.safe_load(settings_file)
    g = GradioInfer(**gradio_config)
    g.run()
104
+
inference/m4singer/m4singer/m4singer_pinyin2ph.txt ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ | a | a |
2
+ | ai | ai |
3
+ | an | an |
4
+ | ang | ang |
5
+ | ao | ao |
6
+ | ba | b a |
7
+ | bai | b ai |
8
+ | ban | b an |
9
+ | bang | b ang |
10
+ | bao | b ao |
11
+ | bei | b ei |
12
+ | ben | b en |
13
+ | beng | b eng |
14
+ | bi | b i |
15
+ | bian | b ian |
16
+ | biao | b iao |
17
+ | bie | b ie |
18
+ | bin | b in |
19
+ | bing | b ing |
20
+ | bo | b o |
21
+ | bu | b u |
22
+ | ca | c a |
23
+ | cai | c ai |
24
+ | can | c an |
25
+ | cang | c ang |
26
+ | cao | c ao |
27
+ | ce | c e |
28
+ | cei | c ei |
29
+ | cen | c en |
30
+ | ceng | c eng |
31
+ | cha | ch a |
32
+ | chai | ch ai |
33
+ | chan | ch an |
34
+ | chang | ch ang |
35
+ | chao | ch ao |
36
+ | che | ch e |
37
+ | chen | ch en |
38
+ | cheng | ch eng |
39
+ | chi | ch i |
40
+ | chong | ch ong |
41
+ | chou | ch ou |
42
+ | chu | ch u |
43
+ | chua | ch ua |
44
+ | chuai | ch uai |
45
+ | chuan | ch uan |
46
+ | chuang | ch uang |
47
+ | chui | ch uei |
48
+ | chun | ch uen |
49
+ | chuo | ch uo |
50
+ | ci | c i |
51
+ | cong | c ong |
52
+ | cou | c ou |
53
+ | cu | c u |
54
+ | cuan | c uan |
55
+ | cui | c uei |
56
+ | cun | c uen |
57
+ | cuo | c uo |
58
+ | da | d a |
59
+ | dai | d ai |
60
+ | dan | d an |
61
+ | dang | d ang |
62
+ | dao | d ao |
63
+ | de | d e |
64
+ | dei | d ei |
65
+ | den | d en |
66
+ | deng | d eng |
67
+ | di | d i |
68
+ | dia | d ia |
69
+ | dian | d ian |
70
+ | diao | d iao |
71
+ | die | d ie |
72
+ | ding | d ing |
73
+ | diu | d iou |
74
+ | dong | d ong |
75
+ | dou | d ou |
76
+ | du | d u |
77
+ | duan | d uan |
78
+ | dui | d uei |
79
+ | dun | d uen |
80
+ | duo | d uo |
81
+ | e | e |
82
+ | ei | ei |
83
+ | en | en |
84
+ | eng | eng |
85
+ | er | er |
86
+ | fa | f a |
87
+ | fan | f an |
88
+ | fang | f ang |
89
+ | fei | f ei |
90
+ | fen | f en |
91
+ | feng | f eng |
92
+ | fo | f o |
93
+ | fou | f ou |
94
+ | fu | f u |
95
+ | ga | g a |
96
+ | gai | g ai |
97
+ | gan | g an |
98
+ | gang | g ang |
99
+ | gao | g ao |
100
+ | ge | g e |
101
+ | gei | g ei |
102
+ | gen | g en |
103
+ | geng | g eng |
104
+ | gong | g ong |
105
+ | gou | g ou |
106
+ | gu | g u |
107
+ | gua | g ua |
108
+ | guai | g uai |
109
+ | guan | g uan |
110
+ | guang | g uang |
111
+ | gui | g uei |
112
+ | gun | g uen |
113
+ | guo | g uo |
114
+ | ha | h a |
115
+ | hai | h ai |
116
+ | han | h an |
117
+ | hang | h ang |
118
+ | hao | h ao |
119
+ | he | h e |
120
+ | hei | h ei |
121
+ | hen | h en |
122
+ | heng | h eng |
123
+ | hong | h ong |
124
+ | hou | h ou |
125
+ | hu | h u |
126
+ | hua | h ua |
127
+ | huai | h uai |
128
+ | huan | h uan |
129
+ | huang | h uang |
130
+ | hui | h uei |
131
+ | hun | h uen |
132
+ | huo | h uo |
133
+ | ji | j i |
134
+ | jia | j ia |
135
+ | jian | j ian |
136
+ | jiang | j iang |
137
+ | jiao | j iao |
138
+ | jie | j ie |
139
+ | jin | j in |
140
+ | jing | j ing |
141
+ | jiong | j iong |
142
+ | jiu | j iou |
143
+ | ju | j v |
144
+ | juan | j van |
145
+ | jue | j ve |
146
+ | jun | j vn |
147
+ | ka | k a |
148
+ | kai | k ai |
149
+ | kan | k an |
150
+ | kang | k ang |
151
+ | kao | k ao |
152
+ | ke | k e |
153
+ | kei | k ei |
154
+ | ken | k en |
155
+ | keng | k eng |
156
+ | kong | k ong |
157
+ | kou | k ou |
158
+ | ku | k u |
159
+ | kua | k ua |
160
+ | kuai | k uai |
161
+ | kuan | k uan |
162
+ | kuang | k uang |
163
+ | kui | k uei |
164
+ | kun | k uen |
165
+ | kuo | k uo |
166
+ | la | l a |
167
+ | lai | l ai |
168
+ | lan | l an |
169
+ | lang | l ang |
170
+ | lao | l ao |
171
+ | le | l e |
172
+ | lei | l ei |
173
+ | leng | l eng |
174
+ | li | l i |
175
+ | lia | l ia |
176
+ | lian | l ian |
177
+ | liang | l iang |
178
+ | liao | l iao |
179
+ | lie | l ie |
180
+ | lin | l in |
181
+ | ling | l ing |
182
+ | liu | l iou |
183
+ | lo | l o |
184
+ | long | l ong |
185
+ | lou | l ou |
186
+ | lu | l u |
187
+ | luan | l uan |
188
+ | lun | l uen |
189
+ | luo | l uo |
190
+ | lv | l v |
191
+ | lve | l ve |
192
+ | m | m |
193
+ | ma | m a |
194
+ | mai | m ai |
195
+ | man | m an |
196
+ | mang | m ang |
197
+ | mao | m ao |
198
+ | me | m e |
199
+ | mei | m ei |
200
+ | men | m en |
201
+ | meng | m eng |
202
+ | mi | m i |
203
+ | mian | m ian |
204
+ | miao | m iao |
205
+ | mie | m ie |
206
+ | min | m in |
207
+ | ming | m ing |
208
+ | miu | m iou |
209
+ | mo | m o |
210
+ | mou | m ou |
211
+ | mu | m u |
212
+ | n | n |
213
+ | na | n a |
214
+ | nai | n ai |
215
+ | nan | n an |
216
+ | nang | n ang |
217
+ | nao | n ao |
218
+ | ne | n e |
219
+ | nei | n ei |
220
+ | nen | n en |
221
+ | neng | n eng |
222
+ | ni | n i |
223
+ | nian | n ian |
224
+ | niang | n iang |
225
+ | niao | n iao |
226
+ | nie | n ie |
227
+ | nin | n in |
228
+ | ning | n ing |
229
+ | niu | n iou |
230
+ | nong | n ong |
231
+ | nou | n ou |
232
+ | nu | n u |
233
+ | nuan | n uan |
234
+ | nuo | n uo |
235
+ | nv | n v |
236
+ | nve | n ve |
237
+ | o | o |
238
+ | ou | ou |
239
+ | pa | p a |
240
+ | pai | p ai |
241
+ | pan | p an |
242
+ | pang | p ang |
243
+ | pao | p ao |
244
+ | pei | p ei |
245
+ | pen | p en |
246
+ | peng | p eng |
247
+ | pi | p i |
248
+ | pian | p ian |
249
+ | piao | p iao |
250
+ | pie | p ie |
251
+ | pin | p in |
252
+ | ping | p ing |
253
+ | po | p o |
254
+ | pou | p ou |
255
+ | pu | p u |
256
+ | qi | q i |
257
+ | qia | q ia |
258
+ | qian | q ian |
259
+ | qiang | q iang |
260
+ | qiao | q iao |
261
+ | qie | q ie |
262
+ | qin | q in |
263
+ | qing | q ing |
264
+ | qiong | q iong |
265
+ | qiu | q iou |
266
+ | qu | q v |
267
+ | quan | q van |
268
+ | que | q ve |
269
+ | qun | q vn |
270
+ | ran | r an |
271
+ | rang | r ang |
272
+ | rao | r ao |
273
+ | re | r e |
274
+ | ren | r en |
275
+ | reng | r eng |
276
+ | ri | r i |
277
+ | rong | r ong |
278
+ | rou | r ou |
279
+ | ru | r u |
280
+ | rua | r ua |
281
+ | ruan | r uan |
282
+ | rui | r uei |
283
+ | run | r uen |
284
+ | ruo | r uo |
285
+ | sa | s a |
286
+ | sai | s ai |
287
+ | san | s an |
288
+ | sang | s ang |
289
+ | sao | s ao |
290
+ | se | s e |
291
+ | sen | s en |
292
+ | seng | s eng |
293
+ | sha | sh a |
294
+ | shai | sh ai |
295
+ | shan | sh an |
296
+ | shang | sh ang |
297
+ | shao | sh ao |
298
+ | she | sh e |
299
+ | shei | sh ei |
300
+ | shen | sh en |
301
+ | sheng | sh eng |
302
+ | shi | sh i |
303
+ | shou | sh ou |
304
+ | shu | sh u |
305
+ | shua | sh ua |
306
+ | shuai | sh uai |
307
+ | shuan | sh uan |
308
+ | shuang | sh uang |
309
+ | shui | sh uei |
310
+ | shun | sh uen |
311
+ | shuo | sh uo |
312
+ | si | s i |
313
+ | song | s ong |
314
+ | sou | s ou |
315
+ | su | s u |
316
+ | suan | s uan |
317
+ | sui | s uei |
318
+ | sun | s uen |
319
+ | suo | s uo |
320
+ | ta | t a |
321
+ | tai | t ai |
322
+ | tan | t an |
323
+ | tang | t ang |
324
+ | tao | t ao |
325
+ | te | t e |
326
+ | tei | t ei |
327
+ | teng | t eng |
328
+ | ti | t i |
329
+ | tian | t ian |
330
+ | tiao | t iao |
331
+ | tie | t ie |
332
+ | ting | t ing |
333
+ | tong | t ong |
334
+ | tou | t ou |
335
+ | tu | t u |
336
+ | tuan | t uan |
337
+ | tui | t uei |
338
+ | tun | t uen |
339
+ | tuo | t uo |
340
+ | wa | ua |
341
+ | wai | uai |
342
+ | wan | uan |
343
+ | wang | uang |
344
+ | wei | uei |
345
+ | wen | uen |
346
+ | weng | ueng |
347
+ | wo | uo |
348
+ | wu | u |
349
+ | xi | x i |
350
+ | xia | x ia |
351
+ | xian | x ian |
352
+ | xiang | x iang |
353
+ | xiao | x iao |
354
+ | xie | x ie |
355
+ | xin | x in |
356
+ | xing | x ing |
357
+ | xiong | x iong |
358
+ | xiu | x iou |
359
+ | xu | x v |
360
+ | xuan | x van |
361
+ | xue | x ve |
362
+ | xun | x vn |
363
+ | ya | ia |
364
+ | yan | ian |
365
+ | yang | iang |
366
+ | yao | iao |
367
+ | ye | ie |
368
+ | yi | i |
369
+ | yin | in |
370
+ | ying | ing |
371
+ | yong | iong |
372
+ | you | iou |
373
+ | yu | v |
374
+ | yuan | van |
375
+ | yue | ve |
376
+ | yun | vn |
377
+ | za | z a |
378
+ | zai | z ai |
379
+ | zan | z an |
380
+ | zang | z ang |
381
+ | zao | z ao |
382
+ | ze | z e |
383
+ | zei | z ei |
384
+ | zen | z en |
385
+ | zeng | z eng |
386
+ | zha | zh a |
387
+ | zhai | zh ai |
388
+ | zhan | zh an |
389
+ | zhang | zh ang |
390
+ | zhao | zh ao |
391
+ | zhe | zh e |
392
+ | zhei | zh ei |
393
+ | zhen | zh en |
394
+ | zheng | zh eng |
395
+ | zhi | zh i |
396
+ | zhong | zh ong |
397
+ | zhou | zh ou |
398
+ | zhu | zh u |
399
+ | zhua | zh ua |
400
+ | zhuai | zh uai |
401
+ | zhuan | zh uan |
402
+ | zhuang | zh uang |
403
+ | zhui | zh uei |
404
+ | zhun | zh uen |
405
+ | zhuo | zh uo |
406
+ | zi | z i |
407
+ | zong | z ong |
408
+ | zou | z ou |
409
+ | zu | z u |
410
+ | zuan | z uan |
411
+ | zui | z uei |
412
+ | zun | z uen |
413
+ | zuo | z uo |
inference/m4singer/m4singer/map.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
def m4singer_pinyin2ph_func(path='inference/m4singer/m4singer/m4singer_pinyin2ph.txt'):
    """Load the pinyin-to-phoneme table.

    Each useful line has the form ``| pinyin | phonemes |``. Lines that do not
    yield at least two non-empty fields (for example blank lines) are skipped.

    :param path: table file to read; the default is relative to the current
        working directory
    :return: dict mapping a pinyin syllable to its space-separated phoneme string,
        pre-seeded with ``'AP' -> '<AP>'`` and ``'SP' -> '<SP>'``
    """
    pinyin2phs = {'AP': '<AP>', 'SP': '<SP>'}
    with open(path, encoding='utf-8') as rf:
        for line in rf:
            fields = [field.strip() for field in line.split('|')]
            elements = [field for field in fields if field != '']
            if len(elements) < 2:
                # Nothing to map on this line; indexing it would raise IndexError.
                continue
            pinyin2phs[elements[0]] = elements[1]
    return pinyin2phs
modules/__init__.py ADDED
File without changes
modules/commons/common_layers.py ADDED
@@ -0,0 +1,668 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import Parameter
5
+ import torch.onnx.operators
6
+ import torch.nn.functional as F
7
+ import utils
8
+
9
+
10
class Reshape(nn.Module):
    """Module form of ``Tensor.view`` with a target shape fixed at construction."""

    def __init__(self, *args):
        super().__init__()
        # The positional arguments are kept as a tuple and handed to ``view``.
        self.shape = args

    def forward(self, x):
        """Return ``x`` viewed with the stored shape."""
        target_shape = self.shape
        return x.view(target_shape)
17
+
18
+
19
class Permute(nn.Module):
    """Module form of ``Tensor.permute`` with a dimension order fixed at construction."""

    def __init__(self, *args):
        super().__init__()
        # The positional arguments are the new dimension order, kept as a tuple.
        self.args = args

    def forward(self, x):
        """Return ``x`` with its dimensions reordered according to the stored order."""
        dim_order = self.args
        return x.permute(dim_order)
26
+
27
+
28
class LinearNorm(torch.nn.Module):
    """``torch.nn.Linear`` whose weight is Xavier-uniform initialised.

    Args:
        in_dim: Size of each input sample.
        out_dim: Size of each output sample.
        bias: Whether the linear layer has a bias term.
        w_init_gain: Non-linearity name passed to
            ``torch.nn.init.calculate_gain`` to pick the Xavier gain.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        xavier_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(layer.weight, gain=xavier_gain)
        self.linear_layer = layer

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        projected = self.linear_layer(x)
        return projected
39
+
40
+
41
class ConvNorm(torch.nn.Module):
    """``torch.nn.Conv1d`` whose weight is Xavier-uniform initialised.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        padding: Explicit padding. When ``None`` it is derived from
            ``dilation`` and ``kernel_size``; in that case ``kernel_size``
            must be odd.
        dilation: Convolution dilation.
        bias: Whether the convolution has a bias term.
        w_init_gain: Non-linearity name passed to
            ``torch.nn.init.calculate_gain`` to pick the Xavier gain.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()
        if padding is None:
            # The derived padding is only defined for odd kernels.
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, dilation=dilation, bias=bias)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, **conv_kwargs)

        xavier_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=xavier_gain)

    def forward(self, signal):
        """Apply the wrapped 1-D convolution to ``signal``."""
        return self.conv(signal)
60
+
61
+
62
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """Build an ``nn.Embedding`` with scaled-normal initialisation.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``embedding_dim ** -0.5``. When ``padding_idx`` is given, that
    row is then set to zero.

    Returns:
        The initialised ``nn.Embedding`` module.
    """
    table = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(table.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is None:
        return table
    nn.init.constant_(table.weight[padding_idx], 0)
    return table
68
+
69
+
70
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Return a layer-normalisation module, using apex's fused version when possible.

    The fused apex implementation is attempted only when ``export`` is false
    and CUDA is available; if importing or constructing it raises
    ``ImportError``, ``torch.nn.LayerNorm`` is returned instead.
    """
    if export or not torch.cuda.is_available():
        return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
    try:
        from apex.normalization import FusedLayerNorm
        # Construction stays inside the ``try`` so that an ImportError raised
        # while building the fused layer also triggers the fallback.
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    except ImportError:
        return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
78
+
79
+
80
def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Returns:
        The initialised ``nn.Linear`` module; it has no bias when ``bias`` is
        false.
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        return layer
    nn.init.constant_(layer.bias, 0.)
    return layer
86
+
87
+
88
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.

    The embedding table is cached in ``self.weights`` (a plain attribute, not
    a registered buffer) and is rebuilt in ``forward`` whenever a longer
    sequence is seen.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        # Only the dtype/device of this buffer are used (see ``forward``); its
        # value is never read.
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Args:
            num_embeddings: Number of positions (rows) to generate.
            embedding_dim: Width of every embedding. The first half of the
                columns holds sines, the second half cosines; an odd width
                gets one trailing zero column.
            padding_idx: Optional row that is zeroed out.

        Returns:
            Float tensor of shape ``(num_embeddings, embedding_dim)``.
        """
        half_dim = embedding_dim // 2
        # With a single frequency (embedding_dim of 2 or 3) the original
        # ``half_dim - 1`` denominator is zero. The only frequency index is 0,
        # so the scale is irrelevant there; clamp the denominator to avoid
        # ZeroDivisionError.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen].

        Args:
            input: Tensor whose first two dimensions are ``(bsz, seq_len)``.
            incremental_state: When not ``None``, a single embedding row is
                returned, expanded to ``(bsz, 1, dim)``.
            timestep: Optional tensor giving the current step for incremental
                decoding; its first element plus one is used as the position.
            positions: Optional precomputed position indices. When ``None``
                they are derived with ``utils.make_positions``.

        Returns:
            Tensor of shape ``(bsz, seq_len, dim)``, detached from the graph,
            or ``(bsz, 1, dim)`` in the incremental case.
        """
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        # Follow the module's current dtype/device via the registered buffer.
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
148
+
149
+
150
class ConvTBC(nn.Module):
    """Thin module wrapper around ``torch.conv_tbc``.

    The weight has shape ``(kernel_size, in_channels, out_channels)`` and the
    bias has shape ``(out_channels,)``. Both parameters are allocated without
    being initialised here.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.padding = kernel_size, padding

        weight_storage = torch.Tensor(self.kernel_size, in_channels, out_channels)
        bias_storage = torch.Tensor(out_channels)
        self.weight = torch.nn.Parameter(weight_storage)
        self.bias = torch.nn.Parameter(bias_storage)

    def forward(self, input):
        """Run ``torch.conv_tbc`` on a contiguous copy of ``input``."""
        contiguous_input = input.contiguous()
        return torch.conv_tbc(contiguous_input, self.weight, self.bias, self.padding)
164
+
165
+
166
class MultiheadAttention(nn.Module):
    """Multi-head attention with a fused fast path and a manual fallback.

    ``forward`` delegates to ``F.multi_head_attention_forward`` when that
    function exists and none of ``incremental_state``, ``static_kv`` or
    ``reset_attn_weight`` is requested; otherwise attention is computed
    explicitly in this class. Incremental decoding is not supported.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        kdim: Key width (defaults to ``embed_dim``).
        vdim: Value width (defaults to ``embed_dim``).
        dropout: Dropout probability applied to the attention probabilities.
        bias: Whether the input and output projections carry a bias.
        add_bias_kv: Append a learned bias vector to the keys and values.
        add_zero_attn: Append an all-zero key/value slot.
        self_attention: Query, key and value all come from ``query``.
        encoder_decoder_attention: Keys and values both come from ``key``.
    """

    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
                 encoder_decoder_attention=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
                                                             'value to be of the same size'

        if self.qkv_same_dim:
            # One stacked matrix holding the q, k and v projections.
            self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        else:
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        # Use the fused PyTorch implementation whenever this torch build has it.
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        # Attention probabilities cached for the ``reset_attn_weight`` feature.
        self.last_attn_probs = None

    def reset_parameters(self):
        """Xavier-initialise the projections and zero the biases."""
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)
            nn.init.xavier_uniform_(self.q_proj_weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
            self,
            query, key, value,
            key_padding_mask=None,
            incremental_state=None,
            need_weights=True,
            static_kv=False,
            attn_mask=None,
            before_softmax=False,
            need_head_weights=False,
            enc_dec_attn_constraint_mask=None,
            reset_attn_weight=None
    ):
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            incremental_state: must be ``None``; incremental decoding is not
                implemented.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            static_kv (bool, optional): when true the fused fast path is
                skipped and attention is computed manually.
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            enc_dec_attn_constraint_mask (optional): extra mask, unsqueezed
                at dim 2 and applied to the logits with ``-1e9``.
            reset_attn_weight (bool, optional): ``True`` caches the current
                attention probabilities; ``False`` reuses the cached ones;
                ``None`` disables the feature.

        Returns:
            Fast path: whatever ``F.multi_head_attention_forward`` returns.
            Manual path: ``(attn, (attn_weights, attn_logits))``, or
            ``(attn_weights, v)`` when ``before_softmax`` is set.

        Raises:
            NotImplementedError: if ``incremental_state`` is not ``None``.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim

        if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
            if self.qkv_same_dim:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      self.in_proj_weight,
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask)
            else:
                return F.multi_head_attention_forward(query, key, value,
                                                      self.embed_dim, self.num_heads,
                                                      torch.empty([0]),
                                                      self.in_proj_bias, self.bias_k, self.bias_v,
                                                      self.add_zero_attn, self.dropout,
                                                      self.out_proj.weight, self.out_proj.bias,
                                                      self.training, key_padding_mask, need_weights,
                                                      attn_mask, use_separate_proj_weight=True,
                                                      q_proj_weight=self.q_proj_weight,
                                                      k_proj_weight=self.k_proj_weight,
                                                      v_proj_weight=self.v_proj_weight)

        if incremental_state is not None:
            # Raise instead of print + exit(): terminating the interpreter
            # from library code hides the failure from callers.
            raise NotImplementedError('Incremental decoding is not implemented.')

        if self.self_attention:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.in_proj_k(key)
                v = self.in_proj_v(key)

        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            # The masks grow by one column to cover the appended bias slot.
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

        # Fold the heads into the batch dimension: (bsz * heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if len(attn_mask.shape) == 2:
                attn_mask = attn_mask.unsqueeze(0)
            elif len(attn_mask.shape) == 3:
                # Per-batch mask: replicate it for every head.
                attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
                    bsz * self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights + attn_mask

        if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
                -1e9,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                -1e9,
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)

        if reset_attn_weight is not None:
            if reset_attn_weight:
                self.last_attn_probs = attn_probs.detach()
            else:
                assert self.last_attn_probs is not None
                attn_probs = self.last_attn_probs
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)
        else:
            attn_weights = None

        return attn, (attn_weights, attn_logits)

    def in_proj_qkv(self, query):
        """Project ``query`` with the stacked weight and split into q, k, v."""
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_q(self, query):
        """Apply the query projection."""
        if self.qkv_same_dim:
            return self._in_proj(query, end=self.embed_dim)
        else:
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[:self.embed_dim]
            return F.linear(query, self.q_proj_weight, bias)

    def in_proj_k(self, key):
        """Apply the key projection."""
        if self.qkv_same_dim:
            return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
        else:
            weight = self.k_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[self.embed_dim:2 * self.embed_dim]
            return F.linear(key, weight, bias)

    def in_proj_v(self, value):
        """Apply the value projection."""
        if self.qkv_same_dim:
            return self._in_proj(value, start=2 * self.embed_dim)
        else:
            weight = self.v_proj_weight
            bias = self.in_proj_bias
            if bias is not None:
                bias = bias[2 * self.embed_dim:]
            return F.linear(value, weight, bias)

    def _in_proj(self, input, start=0, end=None):
        """Apply rows ``start:end`` of the stacked projection to ``input``."""
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
        """Hook for sparse attention variants; returns the weights unchanged."""
        return attn_weights
465
+
466
+
467
class Swish(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input is saved for backward; the sigmoid is recomputed there.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and remember ``i`` for the backward pass."""
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the forward input."""
        # ``saved_tensors`` replaces the deprecated ``saved_variables`` accessor.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
479
+
480
+
481
class CustomSwish(nn.Module):
    """Module wrapper that applies the ``Swish`` autograd function."""

    def forward(self, input_tensor):
        """Return ``Swish.apply(input_tensor)``."""
        activated = Swish.apply(input_tensor)
        return activated
484
+
485
+
486
class TransformerFFNLayer(nn.Module):
    """Feed-forward sub-layer: 1-D convolution, activation, dropout, linear.

    Args:
        hidden_size: Input and output width.
        filter_size: Width of the hidden convolution output.
        padding: ``'SAME'`` pads ``kernel_size // 2`` on both sides;
            ``'LEFT'`` pads ``kernel_size - 1`` zeros on the left only.
        kernel_size: Convolution kernel width.
        dropout: Dropout probability applied after the activation.
        act: ``'gelu'``, ``'relu'`` or ``'swish'``. Any other value applies
            no activation.

    Raises:
        ValueError: if ``padding`` is neither ``'SAME'`` nor ``'LEFT'``.
    """

    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
        elif padding == 'LEFT':
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size)
            )
        else:
            # Fail at construction time; previously ``ffn_1`` was silently
            # left undefined and forward() crashed with AttributeError.
            raise ValueError("padding must be 'SAME' or 'LEFT', got %r" % (padding,))
        self.ffn_2 = Linear(filter_size, hidden_size)
        if self.act == 'swish':
            self.swish_fn = CustomSwish()

    def forward(self, x, incremental_state=None):
        """Apply the feed-forward block.

        Args:
            x: Input laid out as T x B x C.
            incremental_state: Must be ``None``; incremental decoding is not
                supported by this layer.

        Returns:
            Tensor laid out as T x B x C.

        Raises:
            NotImplementedError: if ``incremental_state`` is not ``None``.
        """
        # x: T x B x C
        if incremental_state is not None:
            # Replaces an always-failing assert followed by exit(1).
            raise NotImplementedError('Nar-generation does not allow this.')

        # Conv1d wants B x C x T; convert there and back.
        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if self.act == 'gelu':
            x = F.gelu(x)
        elif self.act == 'relu':
            x = F.relu(x)
        elif self.act == 'swish':
            x = self.swish_fn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x
523
+
524
+
525
class BatchNorm1dTBC(nn.Module):
    """``nn.BatchNorm1d`` for tensors laid out as [T, B, C]."""

    def __init__(self, c):
        super().__init__()
        self.bn = nn.BatchNorm1d(c)

    def forward(self, x):
        """Normalise over the channel dimension.

        :param x: [T, B, C]
        :return: [T, B, C]
        """
        channels_first = x.permute(1, 2, 0)  # [B, C, T]
        normalised = self.bn(channels_first)  # [B, C, T]
        return normalised.permute(2, 0, 1)  # [T, B, C]
540
+
541
+
542
class EncSALayer(nn.Module):
    """Pre-norm encoder layer: optional self-attention followed by a conv FFN.

    Args:
        c: Hidden width.
        num_heads: Number of attention heads. With ``0`` the self-attention
            sub-layer (and its norm) is not created.
        dropout: Dropout probability applied to each sub-layer output.
        attention_dropout: Dropout probability inside the attention.
        relu_dropout: Dropout probability inside the FFN.
        kernel_size: Kernel width of the FFN convolution.
        padding: Padding mode forwarded to ``TransformerFFNLayer``.
        norm: ``'ln'`` for layer norm or ``'bn'`` for batch norm.
        act: Activation name forwarded to ``TransformerFFNLayer``.
    """

    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            if norm == 'ln':
                self.layer_norm1 = LayerNorm(c)
            elif norm == 'bn':
                self.layer_norm1 = BatchNorm1dTBC(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
            )
        if norm == 'ln':
            self.layer_norm2 = LayerNorm(c)
        elif norm == 'bn':
            self.layer_norm2 = BatchNorm1dTBC(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        """Run the layer.

        Args:
            x: Input tensor; presumably T x B x C, since the padding mask is
                transposed to match it (TODO confirm against callers).
            encoder_padding_mask: Optional mask, non-zero at padded positions.
                When given, those positions are zeroed after each sub-layer.
                When ``None`` no zeroing is applied.
            **kwargs: ``layer_norm_training`` (optional bool) overrides the
                ``training`` flag of the normalisation modules.

        Returns:
            Tensor with the same layout as ``x``.
        """
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            # layer_norm1 only exists when the attention sub-layer was built.
            if self.num_heads > 0:
                self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training

        # Multiplicative keep-mask; the default ``None`` mask used to crash on
        # ``.float()``, so it is now treated as "nothing is padded".
        keep = None
        if encoder_padding_mask is not None:
            keep = (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            if keep is not None:
                x = x * keep

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        if keep is not None:
            x = x * keep
        return x
589
+
590
+
591
class DecSALayer(nn.Module):
    """Pre-norm decoder layer: self-attention, encoder attention, conv FFN.

    Each of the three sub-layers is preceded by its own layer norm and
    followed by dropout and a residual connection.
    """

    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        # Sub-modules are created in the order they are used in forward().
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
        )
        self.layer_norm3 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            **kwargs,
    ):
        """Run the decoder layer.

        Either ``encoder_out`` or ``attn_out`` must be provided. With
        ``encoder_out`` the encoder attention is computed; otherwise the
        transposed ``attn_out`` is passed through the encoder attention's
        value projection.

        Returns:
            ``(x, attn_logits)``; ``attn_logits`` is the second element of the
            encoder attention's auxiliary output, or ``None`` when
            ``encoder_out`` was not given.
        """
        norm_mode = kwargs.get('layer_norm_training', None)
        if norm_mode is not None:
            for norm in (self.layer_norm1, self.layer_norm2, self.layer_norm3):
                norm.training = norm_mode

        # --- self-attention sub-layer ---
        skip = x
        h = self.layer_norm1(x)
        h, _ = self.self_attn(
            query=h,
            key=h,
            value=h,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask
        )
        h = F.dropout(h, self.dropout, training=self.training)
        x = skip + h

        # --- encoder attention sub-layer ---
        skip = x
        h = self.layer_norm2(x)
        if encoder_out is None:
            assert attn_out is not None
            h = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
            attn_logits = None
        else:
            # The constraint mask is always disabled here.
            h, attn = self.encoder_attn(
                query=h,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                enc_dec_attn_constraint_mask=None,
                reset_attn_weight=reset_attn_weight
            )
            attn_logits = attn[1]
        h = F.dropout(h, self.dropout, training=self.training)
        x = skip + h

        # --- feed-forward sub-layer ---
        skip = x
        h = self.layer_norm3(x)
        h = self.ffn(h, incremental_state=incremental_state)
        h = F.dropout(h, self.dropout, training=self.training)
        x = skip + h
        return x, attn_logits
modules/commons/espnet_positional_embedding.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+
5
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal positional encoding added to a scaled input.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Number of positions precomputed at construction.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # A dummy (1, max_len) tensor is enough to size the initial table.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Make sure the cached table covers ``x.size(1)`` positions.

        A table that is already long enough is only moved to the dtype and
        device of ``x``; otherwise it is rebuilt from scratch.
        """
        length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length:
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        table = torch.zeros(length, self.d_model)
        if self.reverse:
            position = torch.arange(length - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, length, dtype=torch.float32)
        position = position.unsqueeze(1)
        exponent = torch.arange(0, self.d_model, 2, dtype=torch.float32)
        div_term = torch.exp(exponent * -(math.log(10000.0) / self.d_model))
        angles = position * div_term
        # Even columns carry sines, odd columns cosines.
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        encoded = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(encoded)
57
+
58
+
59
class ScaledPositionalEncoding(PositionalEncoding):
    """Positional encoding weighted by a learnable scalar ``alpha``.

    See Sec. 3.2 https://arxiv.org/abs/1809.08895

    Unlike the base class, the input itself is not multiplied by ``xscale``.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        initial_alpha = torch.tensor(1.0)
        self.alpha = torch.nn.Parameter(initial_alpha)

    def reset_parameters(self):
        """Set ``alpha`` back to 1."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add the alpha-weighted positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        scaled_pe = self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x + scaled_pe)
87
+
88
+
89
class RelPositionalEncoding(PositionalEncoding):
    """Positional encoding built on a reversed position table.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: A single tensor (batch, time, `*`): the sum of the
                dropped-out scaled input and the dropped-out positional
                embedding. The two are not returned separately.
        """
        self.extend_pe(x)
        scaled = x * self.xscale
        pos_emb = self.pe[:, : scaled.size(1)]
        # Dropout is applied to each term independently, input first.
        return self.dropout(scaled) + self.dropout(pos_emb)
modules/commons/ssim.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # '''
2
+ # https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
3
+ # '''
4
+ #
5
+ # import torch
6
+ # import torch.jit
7
+ # import torch.nn.functional as F
8
+ #
9
+ #
10
+ # @torch.jit.script
11
+ # def create_window(window_size: int, sigma: float, channel: int):
12
+ # '''
13
+ # Create 1-D gauss kernel
14
+ # :param window_size: the size of gauss kernel
15
+ # :param sigma: sigma of normal distribution
16
+ # :param channel: input channel
17
+ # :return: 1D kernel
18
+ # '''
19
+ # coords = torch.arange(window_size, dtype=torch.float)
20
+ # coords -= window_size // 2
21
+ #
22
+ # g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
23
+ # g /= g.sum()
24
+ #
25
+ # g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
26
+ # return g
27
+ #
28
+ #
29
+ # @torch.jit.script
30
+ # def _gaussian_filter(x, window_1d, use_padding: bool):
31
+ # '''
32
+ # Blur input with 1-D kernel
33
+ # :param x: batch of tensors to be blurred
34
+ # :param window_1d: 1-D gauss kernel
35
+ # :param use_padding: padding image before conv
36
+ # :return: blurred tensors
37
+ # '''
38
+ # C = x.shape[1]
39
+ # padding = 0
40
+ # if use_padding:
41
+ # window_size = window_1d.shape[3]
42
+ # padding = window_size // 2
43
+ # out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
44
+ # out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
45
+ # return out
46
+ #
47
+ #
48
+ # @torch.jit.script
49
+ # def ssim(X, Y, window, data_range: float, use_padding: bool = False):
50
+ # '''
51
+ # Calculate ssim index for X and Y
52
+ # :param X: images [B, C, H, N_bins]
53
+ # :param Y: images [B, C, H, N_bins]
54
+ # :param window: 1-D gauss kernel
55
+ # :param data_range: value range of input images. (usually 1.0 or 255)
56
+ # :param use_padding: padding image before conv
57
+ # :return:
58
+ # '''
59
+ #
60
+ # K1 = 0.01
61
+ # K2 = 0.03
62
+ # compensation = 1.0
63
+ #
64
+ # C1 = (K1 * data_range) ** 2
65
+ # C2 = (K2 * data_range) ** 2
66
+ #
67
+ # mu1 = _gaussian_filter(X, window, use_padding)
68
+ # mu2 = _gaussian_filter(Y, window, use_padding)
69
+ # sigma1_sq = _gaussian_filter(X * X, window, use_padding)
70
+ # sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
71
+ # sigma12 = _gaussian_filter(X * Y, window, use_padding)
72
+ #
73
+ # mu1_sq = mu1.pow(2)
74
+ # mu2_sq = mu2.pow(2)
75
+ # mu1_mu2 = mu1 * mu2
76
+ #
77
+ # sigma1_sq = compensation * (sigma1_sq - mu1_sq)
78
+ # sigma2_sq = compensation * (sigma2_sq - mu2_sq)
79
+ # sigma12 = compensation * (sigma12 - mu1_mu2)
80
+ #
81
+ # cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
82
+ # # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
83
+ # cs_map = cs_map.clamp_min(0.)
84
+ # ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
85
+ #
86
+ # ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
87
+ # cs = cs_map.mean(dim=(1, 2, 3))
88
+ #
89
+ # return ssim_val, cs
90
+ #
91
+ #
92
+ # @torch.jit.script
93
+ # def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
94
+ # '''
95
+ # interface of ms-ssim
96
+ # :param X: a batch of images, (N,C,H,W)
97
+ # :param Y: a batch of images, (N,C,H,W)
98
+ # :param window: 1-D gauss kernel
99
+ # :param data_range: value range of input images. (usually 1.0 or 255)
100
+ # :param weights: weights for different levels
101
+ # :param use_padding: padding image before conv
102
+ # :param eps: use for avoid grad nan.
103
+ # :return:
104
+ # '''
105
+ # levels = weights.shape[0]
106
+ # cs_vals = []
107
+ # ssim_vals = []
108
+ # for _ in range(levels):
109
+ # ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
110
+ # # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
111
+ # ssim_val = ssim_val.clamp_min(eps)
112
+ # cs = cs.clamp_min(eps)
113
+ # cs_vals.append(cs)
114
+ #
115
+ # ssim_vals.append(ssim_val)
116
+ # padding = (X.shape[2] % 2, X.shape[3] % 2)
117
+ # X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
118
+ # Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
119
+ #
120
+ # cs_vals = torch.stack(cs_vals, dim=0)
121
+ # ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
122
+ # return ms_ssim_val
123
+ #
124
+ #
125
+ # class SSIM(torch.jit.ScriptModule):
126
+ # __constants__ = ['data_range', 'use_padding']
127
+ #
128
+ # def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
129
+ # '''
130
+ # :param window_size: the size of gauss kernel
131
+ # :param window_sigma: sigma of normal distribution
132
+ # :param data_range: value range of input images. (usually 1.0 or 255)
133
+ # :param channel: input channels (default: 3)
134
+ # :param use_padding: padding image before conv
135
+ # '''
136
+ # super().__init__()
137
+ # assert window_size % 2 == 1, 'Window size must be odd.'
138
+ # window = create_window(window_size, window_sigma, channel)
139
+ # self.register_buffer('window', window)
140
+ # self.data_range = data_range
141
+ # self.use_padding = use_padding
142
+ #
143
+ # @torch.jit.script_method
144
+ # def forward(self, X, Y):
145
+ # r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
146
+ # return r[0]
147
+ #
148
+ #
149
+ # class MS_SSIM(torch.jit.ScriptModule):
150
+ # __constants__ = ['data_range', 'use_padding', 'eps']
151
+ #
152
+ # def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
153
+ # levels=None, eps=1e-8):
154
+ # '''
155
+ # class for ms-ssim
156
+ # :param window_size: the size of gauss kernel
157
+ # :param window_sigma: sigma of normal distribution
158
+ # :param data_range: value range of input images. (usually 1.0 or 255)
159
+ # :param channel: input channels
160
+ # :param use_padding: padding image before conv
161
+ # :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
162
+ # :param levels: number of downsampling
163
+ # :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
164
+ # '''
165
+ # super().__init__()
166
+ # assert window_size % 2 == 1, 'Window size must be odd.'
167
+ # self.data_range = data_range
168
+ # self.use_padding = use_padding
169
+ # self.eps = eps
170
+ #
171
+ # window = create_window(window_size, window_sigma, channel)
172
+ # self.register_buffer('window', window)
173
+ #
174
+ # if weights is None:
175
+ # weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
176
+ # weights = torch.tensor(weights, dtype=torch.float)
177
+ #
178
+ # if levels is not None:
179
+ # weights = weights[:levels]
180
+ # weights = weights / weights.sum()
181
+ #
182
+ # self.register_buffer('weights', weights)
183
+ #
184
+ # @torch.jit.script_method
185
+ # def forward(self, X, Y):
186
+ # return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
187
+ # use_padding=self.use_padding, eps=self.eps)
188
+ #
189
+ #
190
+ # if __name__ == '__main__':
191
+ # print('Simple Test')
192
+ # im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
193
+ # img1 = im / 255
194
+ # img2 = img1 * 0.5
195
+ #
196
+ # losser = SSIM(data_range=1.).cuda()
197
+ # loss = losser(img1, img2).mean()
198
+ #
199
+ # losser2 = MS_SSIM(data_range=1.).cuda()
200
+ # loss2 = losser2(img1, img2).mean()
201
+ #
202
+ # print(loss.item())
203
+ # print(loss2.item())
204
+ #
205
+ # if __name__ == '__main__':
206
+ # print('Training Test')
207
+ # import cv2
208
+ # import torch.optim
209
+ # import numpy as np
210
+ # import imageio
211
+ # import time
212
+ #
213
+ # out_test_video = False
214
+ # # Better not to write a GIF directly (it gets very large); write an mkv file first, then convert it to GIF with ffmpeg
215
+ # video_use_gif = False
216
+ #
217
+ # im = cv2.imread('test_img1.jpg', 1)
218
+ # t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
219
+ #
220
+ # if out_test_video:
221
+ # if video_use_gif:
222
+ # fps = 0.5
223
+ # out_wh = (im.shape[1] // 2, im.shape[0] // 2)
224
+ # suffix = '.gif'
225
+ # else:
226
+ # fps = 5
227
+ # out_wh = (im.shape[1], im.shape[0])
228
+ # suffix = '.mkv'
229
+ # video_last_time = time.perf_counter()
230
+ # video = imageio.get_writer('ssim_test' + suffix, fps=fps)
231
+ #
232
+ # # Test ssim
233
+ # print('Training SSIM')
234
+ # rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
235
+ # rand_im.requires_grad = True
236
+ # optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
237
+ # losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
238
+ # ssim_score = 0
239
+ # while ssim_score < 0.999:
240
+ # optim.zero_grad()
241
+ # loss = losser(rand_im, t_im)
242
+ # (-loss).sum().backward()
243
+ # ssim_score = loss.item()
244
+ # optim.step()
245
+ # r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
246
+ # r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
247
+ #
248
+ # if out_test_video:
249
+ # if time.perf_counter() - video_last_time > 1. / fps:
250
+ # video_last_time = time.perf_counter()
251
+ # out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
252
+ # out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
253
+ # if isinstance(out_frame, cv2.UMat):
254
+ # out_frame = out_frame.get()
255
+ # video.append_data(out_frame)
256
+ #
257
+ # cv2.imshow('ssim', r_im)
258
+ # cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
259
+ # cv2.waitKey(1)
260
+ #
261
+ # if out_test_video:
262
+ # video.close()
263
+ #
264
+ # # Test ms_ssim
265
+ # if out_test_video:
266
+ # if video_use_gif:
267
+ # fps = 0.5
268
+ # out_wh = (im.shape[1] // 2, im.shape[0] // 2)
269
+ # suffix = '.gif'
270
+ # else:
271
+ # fps = 5
272
+ # out_wh = (im.shape[1], im.shape[0])
273
+ # suffix = '.mkv'
274
+ # video_last_time = time.perf_counter()
275
+ # video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
276
+ #
277
+ # print('Training MS_SSIM')
278
+ # rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
279
+ # rand_im.requires_grad = True
280
+ # optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
281
+ # losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
282
+ # ssim_score = 0
283
+ # while ssim_score < 0.999:
284
+ # optim.zero_grad()
285
+ # loss = losser(rand_im, t_im)
286
+ # (-loss).sum().backward()
287
+ # ssim_score = loss.item()
288
+ # optim.step()
289
+ # r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
290
+ # r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
291
+ #
292
+ # if out_test_video:
293
+ # if time.perf_counter() - video_last_time > 1. / fps:
294
+ # video_last_time = time.perf_counter()
295
+ # out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
296
+ # out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
297
+ # if isinstance(out_frame, cv2.UMat):
298
+ # out_frame = out_frame.get()
299
+ # video.append_data(out_frame)
300
+ #
301
+ # cv2.imshow('ms_ssim', r_im)
302
+ # cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
303
+ # cv2.waitKey(1)
304
+ #
305
+ # if out_test_video:
306
+ # video.close()
307
+
308
+ """
309
+ Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
310
+ """
311
+
312
+ import torch
313
+ import torch.nn.functional as F
314
+ from torch.autograd import Variable
315
+ import numpy as np
316
+ from math import exp
317
+
318
+
319
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    The kernel is centred on index ``window_size // 2`` and its entries sum
    to one.

    :param window_size: number of taps in the kernel
    :param sigma: standard deviation of the Gaussian
    :return: 1-D ``torch.Tensor`` with ``window_size`` elements
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = [exp(-(idx - center) ** 2 / denom) for idx in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
322
+
323
+
324
def create_window(window_size, channel, sigma=1.5):
    """Build a 2-D Gaussian window replicated once per channel.

    The 2-D kernel is the outer product of the 1-D kernel from ``gaussian``
    with itself, so it can be used directly as a depthwise (grouped)
    convolution weight.

    :param window_size: side length of the square window
    :param channel: number of channels the window is replicated for
    :param sigma: standard deviation of the Gaussian (defaults to the
        previously hard-coded 1.5)
    :return: contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # No Variable wrapper: it has been a no-op alias for Tensor since PyTorch 0.4.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
329
+
330
+
331
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
332
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
333
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
334
+
335
+ mu1_sq = mu1.pow(2)
336
+ mu2_sq = mu2.pow(2)
337
+ mu1_mu2 = mu1 * mu2
338
+
339
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
340
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
341
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
342
+
343
+ C1 = 0.01 ** 2
344
+ C2 = 0.03 ** 2
345
+
346
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
347
+
348
+ if size_average:
349
+ return ssim_map.mean()
350
+ else:
351
+ return ssim_map.mean(1)
352
+
353
+
354
class SSIM(torch.nn.Module):
    """Module wrapper around ``_ssim`` that caches its Gaussian window.

    :param window_size: side length of the Gaussian window
    :param size_average: forwarded to ``_ssim``; True yields a scalar
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2``.

        :param img1: 4-D tensor; dim 1 is taken as the channel count
        :param img2: tensor compared against ``img1``
        :return: whatever ``_ssim`` returns for the configured ``size_average``
        """
        (_, channel, _, _) = img1.size()

        cached = self.window
        # Compare device and dtype directly: Tensor.type() strings do not
        # include the CUDA device index, so a window cached on another GPU
        # would otherwise be reused and fail inside conv2d.
        reusable = (
            channel == self.channel
            and cached.device == img1.device
            and cached.dtype == img1.dtype
        )
        if reusable:
            window = cached
        else:
            window = create_window(self.window_size, channel)
            window = window.to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
378
+
379
+
380
# Module-level cache of the most recently built Gaussian window.
window = None


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM with a module-level cached window.

    The cached window is rebuilt whenever it no longer matches the request
    (different channel count or window size, or a different device/dtype than
    ``img1``). Previously it was built once on the first call and reused
    unconditionally, which broke or mis-weighted later calls with other
    inputs.

    :param img1: 4-D tensor; dim 1 is taken as the channel count
    :param img2: tensor compared against ``img1``
    :param window_size: side length of the Gaussian window
    :param size_average: forwarded to ``_ssim``; True yields a scalar
    :return: whatever ``_ssim`` returns for the given ``size_average``
    """
    (_, channel, _, _) = img1.size()
    global window
    expected_shape = (channel, 1, window_size, window_size)
    stale = (
        window is None
        or tuple(window.shape) != expected_shape
        or window.device != img1.device
        or window.dtype != img1.dtype
    )
    if stale:
        window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    return _ssim(img1, img2, window, window_size, channel, size_average)
modules/diffsinger_midi/fs2.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.commons.common_layers import *
2
+ from modules.commons.common_layers import Embedding
3
+ from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
4
+ EnergyPredictor, FastspeechEncoder
5
+ from utils.cwt import cwt2f0
6
+ from utils.hparams import hparams
7
+ from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
8
+ from modules.fastspeech.fs2 import FastSpeech2
9
+
10
+
11
class FastspeechMIDIEncoder(FastspeechEncoder):
    """FastSpeech encoder whose token embedding is augmented with MIDI features."""

    def forward_embedding(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
        """Embed tokens, add the MIDI-related embeddings and positional information.

        :param txt_tokens: token ids, [B, T]
        :param midi_embedding: note-pitch embedding added to the token embedding
        :param midi_dur_embedding: note-duration embedding (or 0)
        :param slur_embedding: slur-flag embedding (or 0)
        :return: embedded input after dropout
        """
        token_emb = self.embed_scale * self.embed_tokens(txt_tokens)
        x = token_emb + midi_embedding + midi_dur_embedding + slur_embedding
        if hparams['use_pos_embed']:
            if hparams.get('rel_pos'):
                # Relative variant consumes the features themselves.
                x = self.embed_positions(x)
            else:
                x = x + self.embed_positions(txt_tokens)
        return F.dropout(x, p=self.dropout, training=self.training)

    def forward(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
        """Encode the token sequence together with its MIDI features.

        :param txt_tokens: [B, T]
        :return: the output of the grandparent class's ``forward`` applied to
            the embedded input and the padding mask
        """
        padding_mask = txt_tokens.eq(self.padding_idx).data
        embedded = self.forward_embedding(
            txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding)  # [B, T, H]
        # super(FastspeechEncoder, self) starts the lookup *after*
        # FastspeechEncoder, deliberately bypassing FastspeechEncoder.forward.
        return super(FastspeechEncoder, self).forward(embedded, padding_mask)
37
+
38
+
39
def _build_fft_midi_encoder(hp, embed_tokens, d):
    """Create a ``FastspeechMIDIEncoder`` from hyper-parameters; ``d`` is accepted but unused."""
    return FastspeechMIDIEncoder(
        embed_tokens,
        hp['hidden_size'],
        hp['enc_layers'],
        hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads'],
    )


# Encoder factories keyed by hparams['encoder_type'].
FS_ENCODERS = {
    'fft': _build_fft_midi_encoder,
}
44
+
45
+
46
class FastSpeech2MIDI(FastSpeech2):
    """FastSpeech2 variant whose encoder additionally consumes MIDI note features."""

    def __init__(self, dictionary, out_dims=None):
        """Build the base model, then swap in the MIDI-aware encoder and MIDI embeddings."""
        super().__init__(dictionary, out_dims)
        # Replace the encoder created by the parent with the one registered
        # in this module's FS_ENCODERS.
        del self.encoder
        self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
        self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx)
        self.midi_dur_layer = Linear(1, self.hidden_size)
        self.is_slur_embed = Embedding(2, self.hidden_size)

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
        """Run the acoustic model.

        Required keyword: ``pitch_midi``. Optional keywords: ``midi_dur`` and
        ``is_slur`` (each contributes 0 when absent or None).

        :return: dict with the intermediate results written by the helper
            methods, ``decoder_inp`` and, unless ``skip_decoder`` is set,
            ``mel_out``
        """
        ret = {}

        # MIDI-derived embeddings fed to the encoder.
        midi_embedding = self.midi_embed(kwargs['pitch_midi'])
        midi_dur = kwargs.get('midi_dur')
        is_slur = kwargs.get('is_slur')
        # [B, T] -> [B, T, 1] -> [B, T, H]
        midi_dur_embedding = 0 if midi_dur is None else self.midi_dur_layer(midi_dur[:, :, None])
        slur_embedding = 0 if is_slur is None else self.is_slur_embed(is_slur)

        encoder_out = self.encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        # Reference-style embedding is not implemented; variance embedding is a constant 0.
        var_embed = 0

        # Speaker conditioning: duration and f0 predictors may use their own ids.
        if hparams['use_spk_embed']:
            spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
        elif hparams['use_spk_id']:
            spk_embed_id = spk_embed
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
        else:
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # Duration prediction / alignment.
        dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)

        # Prepend one zero step so index 0 in mel2ph selects padding, then
        # expand token features to frame level.
        padded_encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
        gather_index = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp_origin = decoder_inp = torch.gather(padded_encoder_out, 1, gather_index)  # [B, T, H]

        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]

        # Pitch and energy embeddings.
        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
        ret['decoder_inp'] = decoder_inp

        if not skip_decoder:
            ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
        return ret
modules/fastspeech/fs2.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.commons.common_layers import *
2
+ from modules.commons.common_layers import Embedding
3
+ from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
4
+ EnergyPredictor, FastspeechEncoder
5
+ from utils.cwt import cwt2f0
6
+ from utils.hparams import hparams
7
+ from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
8
+
9
def _build_fft_encoder(hp, embed_tokens, d):
    """Create a ``FastspeechEncoder`` from hyper-parameters; ``d`` is accepted but unused."""
    return FastspeechEncoder(
        embed_tokens,
        hp['hidden_size'],
        hp['enc_layers'],
        hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads'],
    )


def _build_fft_decoder(hp):
    """Create a ``FastspeechDecoder`` from hyper-parameters."""
    return FastspeechDecoder(
        hp['hidden_size'],
        hp['dec_layers'],
        hp['dec_ffn_kernel_size'],
        hp['num_heads'],
    )


# Encoder factories keyed by hparams['encoder_type'].
FS_ENCODERS = {
    'fft': _build_fft_encoder,
}

# Decoder factories keyed by hparams['decoder_type'].
FS_DECODERS = {
    'fft': _build_fft_decoder,
}
19
+
20
+
21
class FastSpeech2(nn.Module):
    """FastSpeech2-style acoustic model: encoder, variance predictors and decoder.

    All architecture options are read from the global ``hparams`` dict.

    :param dictionary: token dictionary; must provide ``pad()`` and ``len()``
    :param out_dims: output feature size; defaults to
        ``hparams['audio_num_mel_bins']`` when None
    """

    def __init__(self, dictionary, out_dims=None):
        super().__init__()
        self.dictionary = dictionary
        self.padding_idx = dictionary.pad()
        self.enc_layers = hparams['enc_layers']
        self.dec_layers = hparams['dec_layers']
        self.hidden_size = hparams['hidden_size']
        self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
        self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
        self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
        self.out_dims = out_dims
        if out_dims is None:
            self.out_dims = hparams['audio_num_mel_bins']
        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)

        # Speaker conditioning: id lookup tables, or a projection of a
        # 256-dim speaker embedding.
        if hparams['use_spk_id']:
            self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
            if hparams['use_split_spk_id']:
                self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
                self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        elif hparams['use_spk_embed']:
            self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)

        # A non-positive predictor_hidden means "use the model hidden size".
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        self.dur_predictor = DurationPredictor(
            self.hidden_size,
            n_chans=predictor_hidden,
            n_layers=hparams['dur_predictor_layers'],
            dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
            kernel_size=hparams['dur_predictor_kernel'])
        self.length_regulator = LengthRegulator()

        if hparams['use_pitch_embed']:
            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
            if hparams['pitch_type'] == 'cwt':
                h = hparams['cwt_hidden_size']
                # 10 CWT components, plus one unvoiced logit when use_uv is on.
                cwt_out_dims = 10
                if hparams['use_uv']:
                    cwt_out_dims = cwt_out_dims + 1
                self.cwt_predictor = nn.Sequential(
                    nn.Linear(self.hidden_size, h),
                    PitchPredictor(
                        h,
                        n_chans=predictor_hidden,
                        n_layers=hparams['predictor_layers'],
                        dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
                        padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
                self.cwt_stats_layers = nn.Sequential(
                    nn.Linear(self.hidden_size, h), nn.ReLU(),
                    nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
                )
            else:
                self.pitch_predictor = PitchPredictor(
                    self.hidden_size,
                    n_chans=predictor_hidden,
                    n_layers=hparams['predictor_layers'],
                    dropout_rate=hparams['predictor_dropout'],
                    odim=2 if hparams['pitch_type'] == 'frame' else 1,
                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

        if hparams['use_energy_embed']:
            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
            self.energy_predictor = EnergyPredictor(
                self.hidden_size,
                n_chans=predictor_hidden,
                n_layers=hparams['predictor_layers'],
                dropout_rate=hparams['predictor_dropout'], odim=1,
                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    def build_embedding(self, dictionary, embed_dim):
        """Create the token embedding table sized to ``dictionary``."""
        num_embeddings = len(dictionary)
        emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
        return emb

    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
        """Run the acoustic model.

        :param txt_tokens: token ids, [B, T_txt]; id 0 is treated as padding
        :param mel2ph: frame-to-token alignment, [B, T_mel]; predicted when None
        :param spk_embed: speaker embedding or speaker id, depending on hparams
        :param f0: optional ground-truth f0 (predicted when None)
        :param uv: optional ground-truth unvoiced flags
        :param energy: optional ground-truth energy (predicted when None)
        :param skip_decoder: return before running the decoder
        :return: dict with the intermediate results written by the helper
            methods, ``decoder_inp`` and, unless ``skip_decoder`` is set,
            ``mel_out``
        """
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        # add ref style embed
        # Not implemented
        # variance encoder
        var_embed = 0

        # encoder_out_dur denotes encoder outputs for duration predictor
        # in speech adaptation, duration predictor use old speaker embedding
        if hparams['use_spk_embed']:
            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        elif hparams['use_spk_id']:
            spk_embed_id = spk_embed
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
        else:
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # add dur
        dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding

        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)

        # Prepend one zero step so that index 0 in mel2ph selects padding.
        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])

        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]

        # add pitch and energy embed
        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding

        if skip_decoder:
            return ret
        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        return ret

    def add_dur(self, dur_input, mel2ph, txt_tokens, ret):
        """Predict durations and, when no alignment is given, derive ``mel2ph``.

        :param dur_input: [B, T_txt, H]
        :param mel2ph: [B, T_mel] or None
        :param txt_tokens: [B, T_txt]
        :param ret: result dict; receives ``dur``, ``mel2ph`` and, when the
            alignment is predicted, ``dur_choice``
        :return: the given or predicted ``mel2ph``
        """
        src_padding = txt_tokens == 0
        # Forward value is unchanged; the gradient flowing back into the
        # encoder is scaled by hparams['predictor_grad'].
        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
        if mel2ph is None:
            dur, xs = self.dur_predictor.inference(dur_input, src_padding)
            ret['dur'] = xs
            ret['dur_choice'] = dur
            mel2ph = self.length_regulator(dur, src_padding).detach()
        else:
            ret['dur'] = self.dur_predictor(dur_input, src_padding)
        ret['mel2ph'] = mel2ph
        return mel2ph

    def add_energy(self, decoder_inp, energy, ret):
        """Predict energy, quantize it into 256 bins and return its embedding.

        :param decoder_inp: frame-level predictor input
        :param energy: ground-truth energy, or None to use the prediction
        :param ret: result dict; receives ``energy_pred``
        :return: energy embedding
        """
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        if energy is None:
            energy = energy_pred
        energy = torch.clamp(energy * 256 // 4, max=255).long()
        energy_embed = self.energy_embed(energy)
        return energy_embed

    def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        """Predict pitch and return the embedding of the coarse pitch.

        Behaviour depends on ``hparams['pitch_type']``: ``'ph'`` predicts one
        value per token and expands it with ``mel2ph``; ``'cwt'`` predicts a
        CWT spectrum plus mean/std; anything else predicts per frame.

        :param decoder_inp: frame-level predictor input
        :param f0: ground-truth f0, or None to use the prediction
        :param uv: ground-truth unvoiced flags, or None
        :param mel2ph: [B, T_mel] alignment
        :param ret: result dict; receives the predictions and ``f0_denorm``
        :param encoder_out: token-level predictor input (needed by ``'ph'``
            and ``'cwt'``)
        :return: pitch embedding
        """
        if hparams['pitch_type'] == 'ph':
            pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
            # Per-token padding mask [B, T_txt]: padded rows are all-zero.
            # (A global .sum() collapsed this to one scalar that was almost
            # never True, so padding was never masked.)
            pitch_padding = encoder_out.abs().sum(-1) == 0
            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
            ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
            pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
            pitch = F.pad(pitch, [1, 0])
            pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
            pitch_embed = self.pitch_embed(pitch)
            return pitch_embed
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())

        pitch_padding = mel2ph == 0

        if hparams['pitch_type'] == 'cwt':
            pitch_padding = None
            ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
            stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
            mean = ret['f0_mean'] = stats_out[:, 0]
            std = ret['f0_std'] = stats_out[:, 1]
            cwt_spec = cwt_out[:, :, :10]
            if f0 is None:
                std = std * hparams['cwt_std_scale']
                f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
                if hparams['use_uv']:
                    assert cwt_out.shape[-1] == 11
                    uv = cwt_out[:, :, -1] > 0
        elif hparams['pitch_ar']:
            # Ground-truth f0 is passed to the predictor only while training.
            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
        else:
            ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
            if f0 is None:
                f0 = pitch_pred[:, :, 0]
            if hparams['use_uv'] and uv is None:
                uv = pitch_pred[:, :, 1] > 0
        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
        if pitch_padding is not None:
            f0[pitch_padding] = 0

        pitch = f0_to_coarse(f0_denorm)  # start from 0
        pitch_embed = self.pitch_embed(pitch)
        return pitch_embed

    def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
        """Decode frame-level features into the output, zeroing padded frames."""
        x = decoder_inp  # [B, T, H]
        x = self.decoder(x)
        x = self.mel_out(x)
        return x * tgt_nonpadding

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        """Convert a CWT spectrum to normalized f0, padded to ``mel2ph``'s length.

        When the reconstructed f0 is shorter than ``mel2ph``, the last frame
        is repeated to fill the gap.
        """
        f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
        f0 = torch.cat(
            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
        f0_norm = norm_f0(f0, None, hparams)
        return f0_norm

    def out2mel(self, out):
        """Identity mapping from the model output to the mel representation."""
        return out

    @staticmethod
    def mel_norm(x):
        """Affinely map ``x`` so that -5.5 -> -1 and 0.8 -> 1."""
        return (x + 5.5) / (6.3 / 2) - 1

    @staticmethod
    def mel_denorm(x):
        """Inverse of ``mel_norm``."""
        return (x + 1) * (6.3 / 2) - 5.5
modules/fastspeech/pe.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.commons.common_layers import *
2
+ from utils.hparams import hparams
3
+ from modules.fastspeech.tts_modules import PitchPredictor
4
+ from utils.pitch_utils import denorm_f0
5
+
6
+
7
class Prenet(nn.Module):
    """Conv1d -> ReLU -> BatchNorm1d stack followed by a linear projection.

    All-zero input frames are treated as padding and zeroed in every output.
    """

    def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
        super().__init__()
        self.strides = [1] * n_layers if strides is None else strides
        # Channel plan: in_dim -> out_dim -> out_dim -> ...
        channels = [in_dim] + [out_dim] * n_layers
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(channels[idx], channels[idx + 1], kernel_size=kernel,
                          padding=kernel // 2, stride=self.strides[idx]),
                nn.ReLU(),
                nn.BatchNorm1d(channels[idx + 1]),
            )
            for idx in range(n_layers)
        ])
        self.out_proj = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """
        :param x: [B, T, 80]
        :return: [L, B, T, H], [B, T, H]
        """
        is_pad = x.abs().sum(-1).eq(0).data  # [B, T]
        keep = 1 - is_pad.float()[:, None, :]  # [B, 1, T]
        h = x.transpose(1, 2)
        per_layer = []
        for idx, layer in enumerate(self.layers):
            # Subsample the mask with the same stride as the convolution.
            keep = keep[:, :, ::self.strides[idx]]
            h = layer(h) * keep
            per_layer.append(h)
        hiddens = torch.stack(per_layer, 0).transpose(2, 3)  # [L, B, T, H]
        out = self.out_proj(h.transpose(1, 2))  # [B, T, H]
        out = out * keep.transpose(1, 2)
        return hiddens, out
42
+
43
+
44
class ConvBlock(nn.Module):
    """ConvNorm convolution followed by optional normalisation, ReLU and dropout.

    ``norm`` selects the normalisation: ``'bn'`` (BatchNorm1d), ``'in'``
    (InstanceNorm1d), ``'gn'`` (GroupNorm with ``n_chans // 16`` groups),
    ``'ln'`` (LayerNorm over channels) or ``'wn'`` (weight-normalised
    convolution, no separate norm layer). Any other string disables
    normalisation.

    Fixes over the previous version: the ``'ln'`` layer is built over the
    channel dimension instead of receiving GroupNorm-style arguments, and
    ``forward`` dispatches on the module type, because comparing the norm
    module to a string was always false and left the transposing branch
    unreachable.
    """

    def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
        super().__init__()
        self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
        # self.norm stays a plain string when no norm module is used
        # ('wn', 'none', ...); forward relies on that to skip normalisation.
        self.norm = norm
        if self.norm == 'bn':
            self.norm = nn.BatchNorm1d(n_chans)
        elif self.norm == 'in':
            self.norm = nn.InstanceNorm1d(n_chans, affine=True)
        elif self.norm == 'gn':
            self.norm = nn.GroupNorm(n_chans // 16, n_chans)
        elif self.norm == 'ln':
            self.norm = nn.LayerNorm(n_chans)
        elif self.norm == 'wn':
            self.conv = torch.nn.utils.weight_norm(self.conv.conv)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        :param x: [B, C, T]
        :return: [B, C, T]
        """
        x = self.conv(x)
        if not isinstance(self.norm, str):
            if isinstance(self.norm, nn.LayerNorm):
                # LayerNorm normalises the last dim, so move channels there.
                x = self.norm(x.transpose(1, 2)).transpose(1, 2)
            else:
                x = self.norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x
79
+
80
+
81
class ConvStacks(nn.Module):
    """Linear in-projection, a stack of ConvBlocks (optionally residual) and a linear out-projection."""

    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
                 dropout=0, strides=None, res=True):
        super().__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.res = res
        self.in_proj = Linear(idim, n_chans)
        if strides is not None:
            assert len(strides) == n_layers
        else:
            strides = [1] * n_layers
        blocks = [
            ConvBlock(n_chans, n_chans, kernel_size, stride=strides[layer_idx], norm=norm, dropout=dropout)
            for layer_idx in range(n_layers)
        ]
        self.conv.extend(blocks)
        self.out_proj = Linear(n_chans, odim)

    def forward(self, x, return_hiddens=False):
        """
        :param x: [B, T, H]
        :param return_hiddens: also return the stacked per-layer activations
        :return: [B, T, H], or ``(output, hiddens)`` with hiddens [B, L, C, T]
        """
        h = self.in_proj(x).transpose(1, -1)  # (B, C, Tmax)
        hiddens = []
        for block in self.conv:
            out = block(h)
            h = h + out if self.res else out  # (B, C, Tmax)
            hiddens.append(h)
        y = self.out_proj(h.transpose(1, -1))  # (B, Tmax, H)
        if not return_hiddens:
            return y
        return y, torch.stack(hiddens, 1)  # [B, L, C, T]
117
+
118
+
119
class PitchExtractor(nn.Module):
    """Predict frame-level pitch from a mel-spectrogram.

    Pipeline: Prenet -> optional ConvStacks encoder -> PitchPredictor with
    two output channels, then ``denorm_f0`` on channel 0.
    """

    def __init__(self, n_mel_bins=80, conv_layers=2):
        super().__init__()
        hidden = self.hidden_size = 256
        predictor_hidden = hparams['predictor_hidden']
        # A non-positive configured size falls back to the model hidden size.
        self.predictor_hidden = predictor_hidden if predictor_hidden > 0 else hidden
        self.conv_layers = conv_layers

        self.mel_prenet = Prenet(n_mel_bins, hidden, strides=[1, 1, 1])
        if conv_layers > 0:
            self.mel_encoder = ConvStacks(
                idim=hidden, n_chans=hidden, odim=hidden, n_layers=conv_layers)
        self.pitch_predictor = PitchPredictor(
            hidden, n_chans=self.predictor_hidden,
            n_layers=5, dropout_rate=0.5, odim=2,
            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    def forward(self, mel_input=None):
        """
        :param mel_input: mel-spectrogram; all-zero frames count as padding
        :return: dict with ``'pitch_pred'`` (raw predictor output) and
            ``'f0_denorm_pred'`` (output of ``denorm_f0``)
        """
        _, mel_hidden = self.mel_prenet(mel_input)
        if self.conv_layers > 0:
            mel_hidden = self.mel_encoder(mel_hidden)

        pitch_pred = self.pitch_predictor(mel_hidden)
        ret = {'pitch_pred': pitch_pred}

        pitch_padding = mel_input.abs().sum(-1) == 0
        # Channel 1 is only used as a voiced/unvoiced flag for frame-level
        # pitch with use_uv enabled.
        uv = None
        if hparams['pitch_type'] == 'frame' and hparams['use_uv']:
            uv = pitch_pred[:, :, 1] > 0

        ret['f0_denorm_pred'] = denorm_f0(
            pitch_pred[:, :, 0], uv, hparams, pitch_padding=pitch_padding)
        return ret
modules/fastspeech/tts_modules.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from modules.commons.espnet_positional_embedding import RelPositionalEncoding
9
+ from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
10
+ from utils.hparams import hparams
11
+
12
+ DEFAULT_MAX_SOURCE_POSITIONS = 2000
13
+ DEFAULT_MAX_TARGET_POSITIONS = 2000
14
+
15
+
16
class TransformerEncoderLayer(nn.Module):
    """Thin wrapper that configures an ``EncSALayer`` from ``hparams``."""

    def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        # Fall back to the configured encoder FFN kernel size when none is given.
        ffn_kernel = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
        self.op = EncSALayer(
            hidden_size,
            num_heads,
            dropout=dropout,
            attention_dropout=0.0,
            relu_dropout=dropout,
            kernel_size=ffn_kernel,
            padding=hparams['ffn_padding'],
            norm=norm,
            act=hparams['ffn_act'],
        )

    def forward(self, x, **kwargs):
        """Forward ``x`` and all keyword arguments to the wrapped layer."""
        return self.op(x, **kwargs)
32
+
33
+
34
+ ######################
35
+ # fastspeech modules
36
+ ######################
37
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalisation that can act on dim 1 instead of the last dim.

    :param int nout: size of the normalised dimension
    :param int dim: ``-1`` normalises the last dim; any other value swaps
        dim 1 with the last dim before and after normalising
    """

    def __init__(self, nout, dim=-1):
        """Build the layer with ``eps=1e-12``."""
        super().__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalisation.

        :param torch.Tensor x: input tensor
        :return: normalised tensor with the same layout as ``x``
        :rtype: torch.Tensor
        """
        if self.dim != -1:
            swapped = x.transpose(1, -1)
            return super().forward(swapped).transpose(1, -1)
        return super().forward(x)
57
+
58
+
59
class DurationPredictor(torch.nn.Module):
    """Duration predictor module.

    This is the duration predictor described in `FastSpeech: Fast, Robust and
    Controllable Text to Speech`_. It predicts a duration for each input
    position from the encoder hidden states.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    Note:
        ``forward`` returns the raw network output (log domain for the
        ``'mse'`` loss), while ``inference`` additionally converts it to
        integer durations in the linear domain.
    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
        """Initialize duration predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            offset (float, optional): Offset value to avoid nan in log domain.
            padding (str, optional): ``'SAME'`` pads both sides; any other
                value pads only on the left.

        Raises:
            ValueError: If ``hparams['dur_loss']`` is not a supported loss.
        """
        super(DurationPredictor, self).__init__()
        self.offset = offset
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        if padding == 'SAME':
            pad = ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
        else:
            pad = (kernel_size - 1, 0)
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(
                torch.nn.ConstantPad1d(pad, 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate)
            )]
        dur_loss = hparams['dur_loss']
        if dur_loss in ['mse', 'huber']:
            odims = 1
        elif dur_loss == 'mog':
            odims = 15
        elif dur_loss == 'crf':
            odims = 32
            from torchcrf import CRF
            self.crf = CRF(odims, batch_first=True)
        else:
            # Previously this fell through to an unbound `odims` NameError.
            raise ValueError(f"Unsupported dur_loss: {dur_loss!r}")
        self.linear = torch.nn.Linear(n_chans, odims)

    def _forward(self, xs, x_masks=None, is_inference=False):
        """Shared implementation of ``forward`` and ``inference``."""
        # x_masks marks padded positions; it is optional, so build the keep
        # mask once and skip masking entirely when it is absent.
        keep = None if x_masks is None else 1 - x_masks.float()  # (B, Tmax)

        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)
            if keep is not None:
                xs = xs * keep[:, None, :]

        xs = self.linear(xs.transpose(1, -1))  # [B, T, C]
        if keep is not None:
            xs = xs * keep[:, :, None]  # (B, T, C)
        if is_inference:
            return self.out2dur(xs), xs
        if hparams['dur_loss'] in ['mse']:
            xs = xs.squeeze(-1)  # (B, Tmax)
        return xs

    def out2dur(self, xs):
        """Convert raw predictor output to integer durations.

        Raises:
            NotImplementedError: If the configured ``dur_loss`` has no
                decoding rule here (e.g. ``'mog'``).
        """
        dur_loss = hparams['dur_loss']
        if dur_loss in ['mse']:
            # NOTE: calculate in log domain
            xs = xs.squeeze(-1)  # (B, Tmax)
            # Clamp to avoid negative durations.
            return torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()
        if dur_loss == 'crf':
            # Keep the result on the input's device instead of forcing CUDA.
            return torch.tensor(self.crf.decode(xs), dtype=torch.long, device=xs.device)
        raise NotImplementedError(f"out2dur is not implemented for dur_loss={dur_loss!r}")

    def forward(self, xs, x_masks=None):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).

        Returns:
            Tensor: Batch of predicted durations in log domain (B, Tmax).
        """
        return self._forward(xs, x_masks, False)

    def inference(self, xs, x_masks=None):
        """Inference duration.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).

        Returns:
            tuple: ``(durations, raw_output)`` where durations is a LongTensor
            in the linear domain (B, Tmax).
        """
        return self._forward(xs, x_masks, True)
152
+
153
+
154
class LengthRegulator(torch.nn.Module):
    """Expand per-token durations into a frame-to-token index map."""

    def __init__(self, pad_value=0.0):
        super().__init__()
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Example (no batch dim version):
            1. dur = [2,2,3]
            2. token ids = [1,2,3], span ends = [2,4,7], span starts = [0,2,4]
            3. frame t belongs to token i when starts[i] <= t < ends[i]
            4. result = [1,1,2,2,3,3,3]

        :param dur: Batch of durations of each token (B, T_txt)
        :param dur_padding: Batch of padding flags of each token (B, T_txt)
        :param alpha: duration rescale coefficient
        :return:
            mel2ph (B, T_speech); frames past a sequence's total duration are 0
        """
        assert alpha > 0
        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())
        device = dur.device

        span_end = torch.cumsum(dur, 1)
        # Shift right by one (dropping the last entry) to get each span's start.
        span_start = F.pad(span_end, [1, -1], mode='constant', value=0)

        frame_idx = torch.arange(dur.sum(-1).max())[None, None].to(device)
        token_ids = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(device)
        in_span = (frame_idx >= span_start[:, :, None]) & (frame_idx < span_end[:, :, None])
        return (token_ids * in_span.long()).sum(1)
190
+
191
+
192
class PitchPredictor(torch.nn.Module):
    """Convolutional predictor with a scaled sinusoidal position signal on its input."""

    def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
                 dropout_rate=0.1, padding='SAME'):
        """Initialize pitch predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            odim (int, optional): Output dimension of the final linear layer.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            padding (str, optional): ``'SAME'`` pads both sides; any other
                value pads only on the left.
        """
        super().__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        if padding == 'SAME':
            pad = ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
        else:
            pad = (kernel_size - 1, 0)
        for layer_idx in range(n_layers):
            in_chans = n_chans if layer_idx else idim
            self.conv.append(torch.nn.Sequential(
                torch.nn.ConstantPad1d(pad, 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate),
            ))
        self.linear = torch.nn.Linear(n_chans, odim)
        self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
        self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))

    def forward(self, xs):
        """
        :param xs: [B, T, H]
        :return: [B, T, odim]
        """
        # Add the position signal, scaled by a learned scalar.
        pos = self.embed_positions(xs[..., 0])
        h = xs + self.pos_embed_alpha * pos
        h = h.transpose(1, -1)  # (B, idim, Tmax)
        for block in self.conv:
            h = block(h)  # (B, C, Tmax)
        return self.linear(h.transpose(1, -1))  # (B, Tmax, odim)
236
+
237
+
238
class EnergyPredictor(PitchPredictor):
    """Predictor with exactly the same architecture and behaviour as ``PitchPredictor``."""
    pass
240
+
241
+
242
def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
    """Count how many frames map to each token.

    :param mel2ph: (B, T_speech) integer tensor of 1-based token indices;
        index 0 is counted in a scratch column that is dropped
    :param T_txt: number of tokens
    :param max_dur: optional upper clamp applied to every duration
    :return: (B, T_txt) tensor of per-token frame counts
    """
    batch_size, _n_frames = mel2ph.shape
    counts = mel2ph.new_zeros(batch_size, T_txt + 1)
    counts = counts.scatter_add(1, mel2ph, torch.ones_like(mel2ph))
    dur = counts[:, 1:]  # drop the column that collected index 0
    if max_dur is None:
        return dur
    return dur.clamp(max=max_dur)
249
+
250
+
251
class FFTBlocks(nn.Module):
    """Stack of ``TransformerEncoderLayer`` blocks with optional positional
    embedding and optional final normalisation.

    Padding positions are zeroed after the input dropout and after every layer.
    """

    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
                 use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
        """
        :param hidden_size: model width
        :param num_layers: number of transformer layers
        :param ffn_kernel_size: kernel size forwarded to every layer
        :param dropout: dropout rate; ``None`` reads ``hparams['dropout']``
        :param num_heads: attention heads per layer
        :param use_pos_embed: add a sinusoidal positional embedding to the input
        :param use_last_norm: apply a final norm layer after the stack
        :param norm: ``'ln'`` or ``'bn'``; selects the final norm layer
        :param use_pos_embed_alpha: scale the positional embedding by a
            learned scalar instead of the constant 1
        """
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout if dropout is not None else hparams['dropout']
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
            for _ in range(self.num_layers)
        ])
        # NOTE(review): with use_last_norm=True and a `norm` other than
        # 'ln'/'bn', no layer_norm attribute is set and forward() will fail;
        # confirm whether other values are meant to be supported.
        if self.use_last_norm:
            if norm == 'ln':
                self.layer_norm = nn.LayerNorm(embed_dim)
            elif norm == 'bn':
                self.layer_norm = BatchNorm1dTBC(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]; when None, all-zero frames of ``x`` are
            treated as padding
        :param attn_mask: forwarded unchanged to every layer
        :param return_hiddens: return the stacked per-layer outputs instead
            of the final output
        :return: [B, T, C] or [L, B, T, C]
        """
        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            # The stacked hiddens are the per-layer outputs collected above,
            # i.e. before the final norm.
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]
        return x
308
+
309
+
310
class FastspeechEncoder(FFTBlocks):
    """Text encoder: token embedding (+ positional embedding) followed by
    the ``FFTBlocks`` transformer stack.
    """

    def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
        """
        :param embed_tokens: token embedding module, called on ``txt_tokens``
        :param hidden_size: model width; ``None`` reads ``hparams['hidden_size']``
        :param num_layers: number of layers; ``None`` reads ``hparams['dec_layers']``
        :param kernel_size: FFN kernel size; ``None`` reads
            ``hparams['enc_ffn_kernel_size']``
        :param num_heads: attention heads per layer
        """
        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
        # NOTE(review): the encoder's default depth comes from 'dec_layers',
        # not an encoder-specific key -- confirm this is intended before
        # changing it, since it determines checkpoint layout.
        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False)  # use_pos_embed_alpha for compatibility
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        # The parent is built without positional embedding; this class adds
        # its own in forward_embedding().
        if hparams.get('rel_pos') is not None and hparams['rel_pos']:
            self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
        else:
            self.embed_positions = SinusoidalPositionalEmbedding(
                hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

    def forward(self, txt_tokens):
        """
        :param txt_tokens: [B, T]; tokens equal to ``padding_idx`` (0) are padding
        :return: output of ``FFTBlocks.forward`` for the embedded tokens,
            [B, T, C]
        """
        encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
        x = self.forward_embedding(txt_tokens)  # [B, T, H]
        x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
        return x

    def forward_embedding(self, txt_tokens):
        """Embed tokens (scaled by sqrt(hidden_size)), optionally add
        positions when ``hparams['use_pos_embed']`` is set, then apply dropout.
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(txt_tokens)
        if hparams['use_pos_embed']:
            positions = self.embed_positions(txt_tokens)
            x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x
348
+
349
+
350
class FastspeechDecoder(FFTBlocks):
    """``FFTBlocks`` stack whose unspecified sizes are read from ``hparams``."""

    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
        """
        :param hidden_size: model width; ``None`` reads ``hparams['hidden_size']``
        :param num_layers: number of layers; ``None`` reads ``hparams['dec_layers']``
        :param kernel_size: FFN kernel size; ``None`` reads
            ``hparams['dec_ffn_kernel_size']``
        :param num_heads: attention heads; ``None`` reads ``hparams['num_heads']``
        """
        if num_heads is None:
            num_heads = hparams['num_heads']
        if hidden_size is None:
            hidden_size = hparams['hidden_size']
        if kernel_size is None:
            kernel_size = hparams['dec_ffn_kernel_size']
        if num_layers is None:
            num_layers = hparams['dec_layers']
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
357
+
modules/hifigan/hifigan.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+
7
+ from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
8
+ from modules.parallel_wavegan.models.source import SourceModuleHnNSF
9
+ import numpy as np
10
+
11
+ LRELU_SLOPE = 0.1
12
+
13
+
14
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise the weight of any module whose class name contains "Conv".

    :param m: module to inspect (other modules are left untouched)
    :param mean: mean of the normal distribution
    :param std: standard deviation of the normal distribution
    """
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        m.weight.data.normal_(mean, std)
18
+
19
+
20
def apply_weight_norm(m):
    """Wrap ``m`` with weight normalisation if its class name contains "Conv".

    :param m: module to inspect (other modules are left untouched)
    """
    if "Conv" not in m.__class__.__name__:
        return
    weight_norm(m)
24
+
25
+
26
def get_padding(kernel_size, dilation=1):
    """Return the padding ``int((kernel_size * dilation - dilation) / 2)``.

    :param kernel_size: convolution kernel size
    :param dilation: convolution dilation
    :return: integer padding (truncated towards zero)
    """
    span = kernel_size * dilation - dilation
    return int(span / 2)
28
+
29
+
30
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, plain conv) pairs.

    Each pair computes ``x + conv2(lrelu(conv1(lrelu(x))))``.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h

        def make_conv(d):
            # Weight-normalised conv whose padding follows get_padding.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d,
                                      padding=get_padding(kernel_size, d)))

        # Exactly three dilated convolutions, one per entry of `dilation`.
        self.convs1 = nn.ModuleList(
            [make_conv(d) for d in (dilation[0], dilation[1], dilation[2])])
        self.convs1.apply(init_weights)

        # Three matching convolutions with dilation 1.
        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = dilated(F.leaky_relu(x, LRELU_SLOPE))
            branch = plain(F.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        for conv in list(self.convs1) + list(self.convs2):
            remove_weight_norm(conv)
68
+
69
+
70
class ResBlock2(torch.nn.Module):
    """Residual block of two dilated convolutions.

    Each convolution computes ``x + conv(lrelu(x))``.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        # Exactly two dilated convolutions, one per entry of `dilation`.
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=d,
                               padding=get_padding(kernel_size, d)))
            for d in (dilation[0], dilation[1])
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for conv in self.convs:
            branch = conv(F.leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        for conv in self.convs:
            remove_weight_norm(conv)
92
+
93
+
94
class Conv1d1x1(Conv1d):
    """Conv1d fixed to kernel size 1, no padding and dilation 1."""

    def __init__(self, in_channels, out_channels, bias):
        """Build the pointwise convolution.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param bias: whether to add a learnable bias
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            dilation=1,
            bias=bias,
        )
102
+
103
+
104
class HifiGanGenerator(torch.nn.Module):
    """HiFi-GAN style generator: transposed-conv upsampling stages, each
    followed by a bank of residual blocks, with an optional f0-driven
    source signal added after every upsampling stage.
    """

    def __init__(self, h, c_out=1):
        """
        :param h: config mapping; reads 'resblock_kernel_sizes',
            'resblock_dilation_sizes', 'upsample_rates',
            'upsample_kernel_sizes', 'upsample_initial_channel', 'resblock',
            'use_pitch_embed' and (with pitch) 'audio_sample_rate'
        :param c_out: number of output channels of the final convolution
        """
        super(HifiGanGenerator, self).__init__()
        self.h = h
        self.num_kernels = len(h['resblock_kernel_sizes'])
        self.num_upsamples = len(h['upsample_rates'])

        if h['use_pitch_embed']:
            self.harmonic_num = 8
            # Upsample f0 by the product of all upsample rates.
            self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
            self.m_source = SourceModuleHnNSF(
                sampling_rate=h['audio_sample_rate'],
                harmonic_num=self.harmonic_num)
            self.noise_convs = nn.ModuleList()
        # The input convolution is fixed to 80 input channels.
        self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
        resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
            # Channel count halves at every upsampling stage.
            c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
            if h['use_pitch_embed']:
                if i + 1 < len(h['upsample_rates']):
                    # Stride by the product of the remaining upsample rates
                    # to bring the source signal down to this stage.
                    stride_f0 = np.prod(h['upsample_rates'][i + 1:])
                    self.noise_convs.append(Conv1d(
                        1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
                else:
                    # Last stage: a pointwise convolution, no striding.
                    self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h['upsample_initial_channel'] // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
                self.resblocks.append(resblock(h, ch, k, d))

        # `ch` is the channel count of the last upsampling stage.
        self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, f0=None):
        """
        :param x: input features fed to the 80-channel input convolution
        :param f0: optional pitch contour; when given, the modules created
            for h['use_pitch_embed'] are used to add a source signal
        :return: ``tanh`` of the final convolution output
        """
        if f0 is not None:
            # harmonic-source signal, noise-source signal, uv flag
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            if f0 is not None:
                x_source = self.noise_convs[i](har_source)
                x_source = torch.nn.functional.relu(x_source)
                # Layer-normalise the source signal over its channel dim.
                tmp_shape = x_source.shape[1]
                x_source = torch.nn.functional.layer_norm(x_source.transpose(1, -1), (tmp_shape, )).transpose(1, -1)
                x = x + x_source
            # Average the outputs of this stage's residual blocks.
            xs = None
            for j in range(self.num_kernels):
                xs_ = self.resblocks[i * self.num_kernels + j](x)
                if xs is None:
                    xs = xs_
                else:
                    xs += xs_
            x = xs / self.num_kernels
        # No slope is passed here, so the function's default is used rather
        # than LRELU_SLOPE.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from all upsampling, residual, input and output layers."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
183
+
184
+
185
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the waveform into a 2-D [T // period, period]
    grid and applies a stack of (kernel_size x 1) convolutions.

    The temporal padding of every convolution is derived from ``kernel_size``.
    It used to be hard-coded for ``kernel_size=5``, so any other kernel size
    got mismatched padding. The default configuration is unchanged.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
        """
        :param period: fold width used to reshape the 1-D signal
        :param kernel_size: temporal kernel size of the main convolutions
        :param stride: temporal stride of the first four convolutions
        :param use_spectral_norm: use spectral norm instead of weight norm
        :param use_cond: condition on a mel-spectrogram upsampled by
            ``hparams['hop_size']`` (forces two input channels)
        :param c_in: number of input channels when ``use_cond`` is False
        """
        super(DiscriminatorP, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            from utils.hparams import hparams
            t = hparams['hop_size']
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2

        self.period = period
        # Kept as an equality test so non-bool values behave as before.
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        pad = get_padding(kernel_size, 1)
        channels = [c_in, 32, 128, 512, 1024]
        convs = [
            norm_f(Conv2d(ch_in, ch_out, (kernel_size, 1), (stride, 1), padding=(pad, 0)))
            for ch_in, ch_out in zip(channels[:-1], channels[1:])
        ]
        # Final feature convolution keeps the temporal resolution (stride 1).
        convs.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(pad, 0))))
        self.convs = nn.ModuleList(convs)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x, mel):
        """
        :param x: waveform batch [B, C, T]
        :param mel: conditioning mel-spectrogram; only read when ``use_cond``
        :return: ``(flattened score, list of feature maps)``
        """
        fmap = []
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
228
+
229
+
230
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorP`` modules with periods 2, 3, 5, 7 and 11."""

    def __init__(self, use_cond=False, c_in=1):
        super().__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(period, use_cond=use_cond, c_in=c_in)
            for period in (2, 3, 5, 7, 11)
        ])

    def forward(self, y, y_hat, mel=None):
        """
        :param y: real waveform batch
        :param y_hat: generated waveform batch
        :param mel: optional conditioning passed to every discriminator
        :return: ``(real scores, generated scores, real feature maps,
            generated feature maps)``, one list entry per discriminator
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_r, fmap_r = disc(y, mel)
            score_g, fmap_g = disc(y_hat, mel)
            real_scores.append(score_r)
            real_fmaps.append(fmap_r)
            fake_scores.append(score_g)
            fake_fmaps.append(fmap_g)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
255
+
256
+
257
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of strided, grouped 1-D convolutions."""

    def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
        """
        :param use_spectral_norm: use spectral norm instead of weight norm
        :param use_cond: condition on a mel-spectrogram upsampled by the
            product of ``upsample_rates`` (forces two input channels)
        :param upsample_rates: upsampling factors, only read when ``use_cond``
        :param c_in: number of input channels when ``use_cond`` is False
        """
        super().__init__()
        self.use_cond = use_cond
        if use_cond:
            t = np.prod(upsample_rates)
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2
        if use_spectral_norm == False:  # noqa: E712 - equality test kept on purpose
            norm_f = weight_norm
        else:
            norm_f = spectral_norm
        # (in_channels, out_channels, kernel, stride, groups, padding)
        layer_specs = [
            (c_in, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(ch_in, ch_out, kernel, stride, groups=groups, padding=pad))
            for ch_in, ch_out, kernel, stride, groups, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x, mel):
        """
        :param x: waveform batch
        :param mel: conditioning mel-spectrogram; only read when ``use_cond``
        :return: ``(flattened score, list of feature maps)``
        """
        if self.use_cond:
            x = torch.cat([self.cond_net(mel), x], 1)
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
291
+
292
+
293
class MultiScaleDiscriminator(torch.nn.Module):
    """Three ``DiscriminatorS`` modules applied to progressively average-pooled audio.

    The first discriminator uses spectral norm, the other two weight norm.
    """

    def __init__(self, use_cond=False, c_in=1):
        super().__init__()
        from utils.hparams import hparams
        hop_size = hparams['hop_size']
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=(scale_idx == 0), use_cond=use_cond,
                           upsample_rates=[4, 4, hop_size // divisor],
                           c_in=c_in)
            for scale_idx, divisor in enumerate((16, 32, 64))
        ])
        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=1) for _ in range(2)])

    def forward(self, y, y_hat, mel=None):
        """
        :param y: real waveform batch
        :param y_hat: generated waveform batch
        :param mel: optional conditioning passed to every discriminator
        :return: ``(real scores, generated scores, real feature maps,
            generated feature maps)``, one list entry per scale
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for scale_idx, disc in enumerate(self.discriminators):
            if scale_idx > 0:
                # Every scale after the first sees a further-pooled signal.
                pool = self.meanpools[scale_idx - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            score_r, fmap_r = disc(y, mel)
            score_g, fmap_g = disc(y_hat, mel)
            real_scores.append(score_r)
            real_fmaps.append(fmap_r)
            fake_scores.append(score_g)
            fake_fmaps.append(fmap_g)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
330
+
331
+
332
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss: twice the summed mean absolute difference.

    :param fmap_r: per-discriminator lists of feature maps for real audio
    :param fmap_g: per-discriminator lists of feature maps for generated audio
    :return: ``2 * sum(mean(|real - generated|))`` over all paired maps
    """
    total = 0
    for real_maps, fake_maps in zip(fmap_r, fmap_g):
        for real, fake in zip(real_maps, fake_maps):
            total = total + torch.mean(torch.abs(real - fake))
    return total * 2
339
+
340
+
341
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss, averaged over discriminators.

    :param disc_real_outputs: discriminator scores for real audio
    :param disc_generated_outputs: discriminator scores for generated audio
    :return: ``(mean of mean((1 - real)^2), mean of mean(generated^2))``;
        both are divided by ``len(disc_real_outputs)``
    """
    real_total = 0
    fake_total = 0
    for real, fake in zip(disc_real_outputs, disc_generated_outputs):
        real_total = real_total + torch.mean((1 - real) ** 2)
        fake_total = fake_total + torch.mean(fake ** 2)
    n_real = len(disc_real_outputs)
    return real_total / n_real, fake_total / n_real
352
+
353
+
354
def cond_discriminator_loss(outputs):
    """Average of ``mean(output ** 2)`` over all discriminator outputs.

    :param outputs: discriminator scores
    :return: scalar loss
    """
    total = 0
    for score in outputs:
        total = total + torch.mean(score ** 2)
    return total / len(outputs)
361
+
362
+
363
def generator_loss(disc_outputs):
    """Least-squares generator loss: average of ``mean((1 - output) ** 2)``.

    :param disc_outputs: discriminator scores for generated audio
    :return: scalar loss
    """
    total = 0
    for score in disc_outputs:
        total = total + torch.mean((1 - score) ** 2)
    return total / len(disc_outputs)
370
+
modules/hifigan/mel_utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from scipy.io.wavfile import read
6
+
7
+ MAX_WAV_VALUE = 32768.0
8
+
9
+
10
def load_wav(full_path):
    """Read a WAV file and return ``(data, sampling_rate)``.

    Wraps ``scipy.io.wavfile.read`` and swaps the order of the pair it
    returns, so the samples come first and the rate second.
    """
    wav = read(full_path)
    return wav[1], wav[0]
13
+
14
+
15
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (NumPy): ``log(max(x, clip_val) * C)``.

    Args:
        x: Array-like input.
        C: Multiplier applied before the logarithm.
        clip_val: Lower bound applied to ``x`` before scaling.
    """
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
17
+
18
+
19
def dynamic_range_decompression(x, C=1):
    """Invert the NumPy log compression: ``exp(x) / C``."""
    expanded = np.exp(x)
    return expanded / C
21
+
22
+
23
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress tensor ``x``: ``log(clamp(x, min=clip_val) * C)``.

    Args:
        x: Input tensor.
        C: Multiplier applied before the logarithm.
        clip_val: Lower bound applied to ``x`` before scaling.
    """
    floored = x.clamp(min=clip_val)
    return torch.log(floored * C)
25
+
26
+
27
def dynamic_range_decompression_torch(x, C=1):
    """Invert the torch log compression: ``exp(x) / C``."""
    expanded = torch.exp(x)
    return expanded / C
29
+
30
+
31
def spectral_normalize_torch(magnitudes):
    """Apply ``dynamic_range_compression_torch`` with its default settings."""
    return dynamic_range_compression_torch(magnitudes)
34
+
35
+
36
def spectral_de_normalize_torch(magnitudes):
    """Apply ``dynamic_range_decompression_torch`` with its default settings."""
    return dynamic_range_decompression_torch(magnitudes)
39
+
40
+
41
# Module-level caches filled lazily by mel_spectrogram().
# mel_basis:   (sample_rate, n_fft, n_mels, fmin, fmax, device) -> filterbank tensor
# hann_window: (win_size, device) -> window tensor
mel_basis = {}
hann_window = {}


def mel_spectrogram(y, hparams, center=False, complex=False):
    """Compute a log-mel spectrogram (or the raw STFT) of a waveform batch.

    Args:
        y: Waveform tensor; values are clamped to ``[-1, 1]`` first.
            Presumably shaped ``[B, T]`` -- TODO confirm against callers.
        hparams: Mapping providing ``'fft_size'``, ``'audio_num_mel_bins'``,
            ``'audio_sample_rate'``, ``'hop_size'``, ``'win_size'``,
            ``'fmin'`` and ``'fmax'``.
        center: Forwarded to ``torch.stft``. The signal is already
            reflect-padded by ``(fft_size - hop_size) / 2`` on each side
            before the STFT.
        complex: When true, skip the mel projection and return the STFT
            with its axes 1 and 2 swapped. (The name shadows the builtin
            but is part of the public signature.)

    Returns:
        With ``complex=False``: the mel-projected magnitude spectrogram
        after ``spectral_normalize_torch``. With ``complex=True``: the
        transposed STFT tensor.
    """
    n_fft = hparams['fft_size']
    num_mels = hparams['audio_num_mel_bins']
    sampling_rate = hparams['audio_sample_rate']
    hop_size = hparams['hop_size']
    win_size = hparams['win_size']
    fmin = hparams['fmin']
    fmax = hparams['fmax']
    y = y.clamp(min=-1., max=1.)

    # The caches are keyed on every value that influences the cached tensor.
    # The previous check (`fmax not in mel_basis`) never matched the stored
    # key, so the filterbank and the window were rebuilt on every call.
    device = str(y.device)
    basis_key = (sampling_rate, n_fft, num_mels, fmin, fmax, device)
    if basis_key not in mel_basis:
        # Keyword arguments: recent librosa releases reject positional ones.
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels,
                             fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).float().to(y.device)
    window_key = (win_size, device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # NOTE(review): the code below relies on torch.stft returning a real
    # tensor whose last axis holds (real, imag). Newer PyTorch releases
    # require an explicit `return_complex` argument -- confirm against the
    # torch version this project pins before upgrading.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size,
                      window=hann_window[window_key], center=center,
                      pad_mode='reflect', normalized=False, onesided=True)

    if not complex:
        # Magnitude; the small epsilon keeps sqrt differentiable at zero.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
        spec = torch.matmul(mel_basis[basis_key], spec)
        spec = spectral_normalize_torch(spec)
    else:
        # Swap the frequency and frame axes; presumably yields
        # [B, T, n_freq, 2] given the legacy stft layout -- TODO confirm.
        spec = spec.transpose(1, 2)
    return spec
81
+
modules/parallel_wavegan/__init__.py ADDED
File without changes
modules/parallel_wavegan/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .causal_conv import * # NOQA
2
+ from .pqmf import * # NOQA
3
+ from .residual_block import * # NOQA
4
+ from modules.parallel_wavegan.layers.residual_stack import * # NOQA
5
+ from .upsample import * # NOQA
modules/parallel_wavegan/layers/causal_conv.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2020 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Causal convolution layer modules."""
7
+
8
+
9
+ import torch
10
+
11
+
12
class CausalConv1d(torch.nn.Module):
    """1-D convolution whose output is trimmed to the input length.

    The input is padded by ``(kernel_size - 1) * dilation`` using the
    padding module named by ``pad``, convolved, and only the first ``T``
    output frames are kept. With the default ``ConstantPad1d`` (an integer
    padding pads both ends) this leaves outputs that depend only on the
    current and earlier input frames.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 dilation=1, bias=True, pad="ConstantPad1d", pad_params=None):
        """Initialize CausalConv1d module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Convolution kernel size.
            dilation (int): Convolution dilation.
            bias (bool): Whether the convolution has a bias term.
            pad (str): Name of the padding class looked up on ``torch.nn``.
            pad_params (dict, optional): Extra keyword arguments for the
                padding class. Defaults to ``{"value": 0.0}``.

        """
        super(CausalConv1d, self).__init__()
        # A None sentinel replaces the former mutable dict default so that
        # no dict object is shared between calls.
        if pad_params is None:
            pad_params = {"value": 0.0}
        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
                                    dilation=dilation, bias=bias)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T).

        """
        # Drop the trailing frames produced by the padding on the right.
        return self.conv(self.pad(x))[:, :, :x.size(2)]
34
+
35
+
36
class CausalConvTranspose1d(torch.nn.Module):
    """Transposed 1-D convolution that discards its last ``stride`` frames."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        """Build the wrapped ``ConvTranspose1d`` and remember the stride.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the transposed convolution.
            stride (int): Stride; also the number of trailing frames removed.
            bias (bool): Whether the transposed convolution has a bias term.

        """
        super(CausalConvTranspose1d, self).__init__()
        self.deconv = torch.nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride, bias=bias)
        self.stride = stride

    def forward(self, x):
        """Upsample ``x`` and trim the tail.

        Args:
            x (Tensor): Input tensor (B, in_channels, T_in).

        Returns:
            Tensor: Output tensor (B, out_channels, T_out).

        """
        upsampled = self.deconv(x)
        trim = -self.stride
        return upsampled[:, :, :trim]
modules/parallel_wavegan/layers/pqmf.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2020 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Pseudo QMF modules."""
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from scipy.signal import kaiser
13
+
14
+
15
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps. Must be even.
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or ``cutoff_ratio`` is outside (0, 1).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # Imported locally from its canonical home: the `scipy.signal.kaiser`
    # alias was removed in SciPy 1.13. The module-level
    # `from scipy.signal import kaiser` should be migrated as well.
    from scipy.signal.windows import kaiser

    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps mush be even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # ideal low-pass impulse response, centred on index taps // 2
    omega_c = np.pi * cutoff_ratio
    offsets = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid='ignore'):
        h_i = np.sin(omega_c * offsets) / (np.pi * offsets)
    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form

    # apply kaiser window
    return h_i * kaiser(taps + 1, beta)
49
+
50
+
51
class PQMF(torch.nn.Module):
    """Pseudo-QMF analysis/synthesis filter bank.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Initialize PQMF module.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super(PQMF, self).__init__()

        # cosine-modulate the prototype filter once per subband
        prototype = design_prototype_filter(taps, cutoff_ratio, beta)
        num_coefs = len(prototype)
        h_analysis = np.zeros((subbands, num_coefs))
        h_synthesis = np.zeros((subbands, num_coefs))
        band_step = np.pi / (2 * subbands)
        positions = np.arange(taps + 1) - ((taps - 1) / 2)
        for band in range(subbands):
            modulation = (2 * band + 1) * band_step * positions
            phase = (-1) ** band * np.pi / 4
            h_analysis[band] = 2 * prototype * np.cos(modulation + phase)
            h_synthesis[band] = 2 * prototype * np.cos(modulation - phase)

        # register the coefficients as buffers:
        # analysis is (subbands, 1, taps + 1), synthesis is (1, subbands, taps + 1)
        self.register_buffer(
            "analysis_filter", torch.from_numpy(h_analysis).float().unsqueeze(1))
        self.register_buffer(
            "synthesis_filter", torch.from_numpy(h_synthesis).float().unsqueeze(0))

        # filter for downsampling & upsampling: a one at [k, k, 0] per subband
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        diag = torch.arange(subbands)
        updown_filter[diag, diag, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # zero padding of taps // 2 used before both filtering convolutions
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        filtered = F.conv1d(self.pad_fn(x), self.analysis_filter)
        # strided convolution with the selector filter performs the decimation
        return F.conv1d(filtered, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # the upsampling kernel is scaled by the number of subbands
        upsampled = F.conv_transpose1d(
            x, self.updown_filter * self.subbands, stride=self.subbands)
        return F.conv1d(self.pad_fn(upsampled), self.synthesis_filter)