10kwon committed on
Commit 2bfc29a
1 Parent(s): 10f66bf
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. batch.py +43 -0
  2. flask_api.py +54 -0
  3. infer.py +98 -0
  4. infer_tools/__init__.py +0 -0
  5. infer_tools/infer_tool.py +334 -0
  6. infer_tools/slicer.py +158 -0
  7. modules/commons/common_layers.py +671 -0
  8. modules/commons/espnet_positional_embedding.py +113 -0
  9. modules/commons/ssim.py +391 -0
  10. modules/fastspeech/fs2.py +255 -0
  11. modules/fastspeech/pe.py +149 -0
  12. modules/fastspeech/tts_modules.py +364 -0
  13. modules/hifigan/hifigan.py +365 -0
  14. modules/hifigan/mel_utils.py +80 -0
  15. modules/nsf_hifigan/env.py +15 -0
  16. modules/nsf_hifigan/models.py +549 -0
  17. modules/nsf_hifigan/nvSTFT.py +111 -0
  18. modules/nsf_hifigan/utils.py +67 -0
  19. modules/parallel_wavegan/__init__.py +0 -0
  20. modules/parallel_wavegan/layers/__init__.py +5 -0
  21. modules/parallel_wavegan/layers/causal_conv.py +56 -0
  22. modules/parallel_wavegan/layers/pqmf.py +129 -0
  23. modules/parallel_wavegan/layers/residual_block.py +129 -0
  24. modules/parallel_wavegan/layers/residual_stack.py +75 -0
  25. modules/parallel_wavegan/layers/tf_layers.py +129 -0
  26. modules/parallel_wavegan/layers/upsample.py +183 -0
  27. modules/parallel_wavegan/losses/__init__.py +1 -0
  28. modules/parallel_wavegan/losses/stft_loss.py +153 -0
  29. modules/parallel_wavegan/models/__init__.py +2 -0
  30. modules/parallel_wavegan/models/melgan.py +427 -0
  31. modules/parallel_wavegan/models/parallel_wavegan.py +434 -0
  32. modules/parallel_wavegan/models/source.py +538 -0
  33. modules/parallel_wavegan/optimizers/__init__.py +2 -0
  34. modules/parallel_wavegan/optimizers/radam.py +91 -0
  35. modules/parallel_wavegan/stft_loss.py +100 -0
  36. modules/parallel_wavegan/utils/__init__.py +1 -0
  37. modules/parallel_wavegan/utils/utils.py +169 -0
  38. network/diff/candidate_decoder.py +98 -0
  39. network/diff/diffusion.py +332 -0
  40. network/diff/net.py +135 -0
  41. network/hubert/hubert_model.py +276 -0
  42. network/hubert/vec_model.py +60 -0
  43. network/vocoders/__init__.py +2 -0
  44. network/vocoders/base_vocoder.py +39 -0
  45. network/vocoders/hifigan.py +83 -0
  46. network/vocoders/nsf_hifigan.py +92 -0
  47. network/vocoders/pwg.py +137 -0
  48. network/vocoders/vocoder_utils.py +15 -0
  49. preprocessing/SVCpre.py +63 -0
  50. preprocessing/base_binarizer.py +237 -0
batch.py ADDED
@@ -0,0 +1,43 @@
+ import soundfile
+
+ from infer_tools import infer_tool
+ from infer_tools.infer_tool import Svc
+
+
+ def run_clip(svc_model, key, acc, use_pe, use_crepe, thre, use_gt_mel, add_noise_step, project_name='', f_name=None,
+              file_path=None, out_path=None):
+     raw_audio_path = f_name
+     infer_tool.format_wav(raw_audio_path)
+     _f0_tst, _f0_pred, _audio = svc_model.infer(raw_audio_path, key=key, acc=acc, singer=True, use_pe=use_pe,
+                                                 use_crepe=use_crepe,
+                                                 thre=thre, use_gt_mel=use_gt_mel, add_noise_step=add_noise_step)
+     # note: the out_path argument is ignored; the output path is always derived from f_name
+     out_path = f'./singer_data/{f_name.split("/")[-1]}'
+     soundfile.write(out_path, _audio, 44100, 'PCM_16')
+
+
+ if __name__ == '__main__':
+     # project folder name (the one used during training)
+     project_name = "firefox"
+     model_path = f'./checkpoints/{project_name}/clean_model_ckpt_steps_100000.ckpt'
+     config_path = f'./checkpoints/{project_name}/config.yaml'
+
+     # multiple wav files are supported; place them in the ./batch folder, with file extensions
+     file_names = infer_tool.get_end_file("./batch", "wav")
+     trans = [-6]  # pitch shift in semitones (positive or negative), one entry per file; padded with the first value if fewer are given
+     # acceleration (PNDM speedup) factor
+     accelerate = 50
+     hubert_gpu = True
+     cut_time = 30
+
+     # no changes needed below this line
+     infer_tool.mkdir(["./batch", "./singer_data"])
+     infer_tool.fill_a_to_b(trans, file_names)
+
+     model = Svc(project_name, config_path, hubert_gpu, model_path)
+     count = 0
+     for f_name, tran in zip(file_names, trans):
+         print(f_name)
+         run_clip(model, key=tran, acc=accelerate, use_crepe=False, thre=0.05, use_pe=False, use_gt_mel=False,
+                  add_noise_step=500, f_name=f_name, project_name=project_name)
+         count += 1
+         print(f"process:{round(count * 100 / len(file_names), 2)}%")
flask_api.py ADDED
@@ -0,0 +1,54 @@
+ import io
+ import logging
+
+ import librosa
+ import soundfile
+ from flask import Flask, request, send_file
+ from flask_cors import CORS
+
+ from infer_tools.infer_tool import Svc
+ from utils.hparams import hparams
+
+ app = Flask(__name__)
+
+ CORS(app)
+
+ logging.getLogger('numba').setLevel(logging.WARNING)
+
+
+ @app.route("/voiceChangeModel", methods=["POST"])
+ def voice_change_model():
+     request_form = request.form
+     wave_file = request.files.get("sample", None)
+     # pitch-shift amount
+     f_pitch_change = float(request_form.get("fPitchChange", 0))
+     # sample rate required by the DAW
+     daw_sample = int(float(request_form.get("sampleRate", 0)))
+     speaker_id = int(float(request_form.get("sSpeakId", 0)))
+     # read the uploaded wav file from the HTTP request
+     input_wav_path = io.BytesIO(wave_file.read())
+     # model inference
+     _f0_tst, _f0_pred, _audio = model.infer(input_wav_path, key=f_pitch_change, acc=accelerate, use_pe=False,
+                                             use_crepe=False)
+     # note: positional sample-rate arguments require librosa < 0.10; newer versions need orig_sr=/target_sr= keywords
+     tar_audio = librosa.resample(_audio, hparams["audio_sample_rate"], daw_sample)
+     # return the audio
+     out_wav_path = io.BytesIO()
+     soundfile.write(out_wav_path, tar_audio, daw_sample, format="wav")
+     out_wav_path.seek(0)
+     return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
+
+
+ if __name__ == '__main__':
+     # project folder name (the one used during training)
+     project_name = "firefox"
+     model_path = f'./checkpoints/{project_name}/model_ckpt_steps_188000.ckpt'
+     config_path = f'./checkpoints/{project_name}/config.yaml'
+
+     # acceleration (PNDM speedup) factor
+     accelerate = 50
+     hubert_gpu = True
+
+     model = Svc(project_name, config_path, hubert_gpu, model_path)
+
+     # the port must match the VST plugin; changing it is not recommended
+     app.run(port=6842, host="0.0.0.0", debug=False, threaded=False)
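A client for this endpoint only needs a multipart POST carrying the sample file and the three form fields read above. A minimal sketch using the requests library (the input and output paths are placeholders):

    import requests

    with open("raw/input.wav", "rb") as f:
        response = requests.post(
            "http://127.0.0.1:6842/voiceChangeModel",
            files={"sample": f},  # picked up via request.files.get("sample")
            data={"fPitchChange": 0, "sampleRate": 44100, "sSpeakId": 0},
        )
    with open("converted.wav", "wb") as f:
        f.write(response.content)  # the server streams back the converted wav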
infer.py ADDED
@@ -0,0 +1,98 @@
+ import io
+ import time
+ from pathlib import Path
+
+ import librosa
+ import numpy as np
+ import soundfile
+
+ from infer_tools import infer_tool
+ from infer_tools import slicer
+ from infer_tools.infer_tool import Svc
+ from utils.hparams import hparams
+
+ chunks_dict = infer_tool.read_temp("./infer_tools/new_chunks_temp.json")
+
+
+ def run_clip(svc_model, key, acc, use_pe, use_crepe, thre, use_gt_mel, add_noise_step, project_name='', f_name=None,
+              file_path=None, out_path=None, slice_db=-40, **kwargs):
+     print('code version: 2022-12-04')
+     use_pe = use_pe if hparams['audio_sample_rate'] == 24000 else False
+     if file_path is None:
+         raw_audio_path = f"./raw/{f_name}"
+         clean_name = f_name[:-4]
+     else:
+         raw_audio_path = file_path
+         clean_name = str(Path(file_path).name)[:-4]
+     infer_tool.format_wav(raw_audio_path)
+     wav_path = Path(raw_audio_path).with_suffix('.wav')
+     global chunks_dict
+     audio, sr = librosa.load(wav_path, mono=True, sr=None)
+     wav_hash = infer_tool.get_md5(audio)
+     if wav_hash in chunks_dict.keys():
+         print("load chunks from temp")
+         chunks = chunks_dict[wav_hash]["chunks"]
+     else:
+         chunks = slicer.cut(wav_path, db_thresh=slice_db)
+         chunks_dict[wav_hash] = {"chunks": chunks, "time": int(time.time())}
+         infer_tool.write_temp("./infer_tools/new_chunks_temp.json", chunks_dict)
+     audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
+
+     count = 0
+     f0_tst = []
+     f0_pred = []
+     audio = []
+     for (slice_tag, data) in audio_data:
+         print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
+         length = int(np.ceil(len(data) / audio_sr * hparams['audio_sample_rate']))
+         raw_path = io.BytesIO()
+         soundfile.write(raw_path, data, audio_sr, format="wav")
+         if hparams['debug']:
+             print(np.mean(data), np.var(data))
+         raw_path.seek(0)
+         if slice_tag:
+             print('skip empty segment')
+             _f0_tst, _f0_pred, _audio = (
+                 np.zeros(int(np.ceil(length / hparams['hop_size']))),
+                 np.zeros(int(np.ceil(length / hparams['hop_size']))),
+                 np.zeros(length))
+         else:
+             _f0_tst, _f0_pred, _audio = svc_model.infer(raw_path, key=key, acc=acc, use_pe=use_pe, use_crepe=use_crepe,
+                                                         thre=thre, use_gt_mel=use_gt_mel, add_noise_step=add_noise_step)
+         fix_audio = np.zeros(length)
+         fix_audio[:] = np.mean(_audio)
+         fix_audio[:len(_audio)] = _audio[0 if len(_audio) < len(fix_audio) else len(_audio) - len(fix_audio):]
+         f0_tst.extend(_f0_tst)
+         f0_pred.extend(_f0_pred)
+         audio.extend(list(fix_audio))
+         count += 1
+     if out_path is None:
+         # note: step and accelerate are module-level globals set in the __main__ block below
+         out_path = f'./results/{clean_name}_{key}key_{project_name}_{hparams["residual_channels"]}_{hparams["residual_layers"]}_{int(step / 1000)}k_{accelerate}x.{kwargs["format"]}'
+     soundfile.write(out_path, audio, hparams["audio_sample_rate"], 'PCM_16', format=out_path.split('.')[-1])
+     return np.array(f0_tst), np.array(f0_pred), audio
+
+
+ if __name__ == '__main__':
+     # project folder name (the one used during training)
+     project_name = "yilanqiu"
+     model_path = f'./checkpoints/{project_name}/model_ckpt_steps_246000.ckpt'
+     config_path = f'./checkpoints/{project_name}/config.yaml'
+
+     # multiple wav/ogg files are supported; place them in the raw folder, with file extensions
+     file_names = ["青花瓷.wav"]
+     trans = [0]  # pitch shift in semitones (positive or negative), one entry per file; padded with the first value if fewer are given
+     # acceleration (PNDM speedup) factor
+     accelerate = 20
+     hubert_gpu = True
+     format = 'flac'
+     step = int(model_path.split("_")[-1].split(".")[0])
+
+     # no changes needed below this line
+     infer_tool.mkdir(["./raw", "./results"])
+     infer_tool.fill_a_to_b(trans, file_names)
+
+     model = Svc(project_name, config_path, hubert_gpu, model_path)
+     for f_name, tran in zip(file_names, trans):
+         if "." not in f_name:
+             f_name += ".wav"
+         run_clip(model, key=tran, acc=accelerate, use_crepe=True, thre=0.05, use_pe=True, use_gt_mel=False,
+                  add_noise_step=500, f_name=f_name, project_name=project_name, format=format)
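run_clip can also be driven programmatically with an explicit file_path and out_path; passing out_path directly sidesteps the default output-name f-string and hence the dependence on the step and accelerate globals set in the __main__ block. A hedged sketch, assuming the checkpoint paths above exist:

    # hypothetical programmatic use; project name and paths are placeholders
    model = Svc("yilanqiu", "./checkpoints/yilanqiu/config.yaml", True,
                "./checkpoints/yilanqiu/model_ckpt_steps_246000.ckpt")
    f0_tst, f0_pred, audio = run_clip(model, key=0, acc=20, use_pe=True, use_crepe=True,
                                      thre=0.05, use_gt_mel=False, add_noise_step=500,
                                      file_path="./raw/song.wav", out_path="./results/song.flac",
                                      project_name="yilanqiu")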
infer_tools/__init__.py ADDED
File without changes
infer_tools/infer_tool.py ADDED
@@ -0,0 +1,334 @@
+ import hashlib
+ import json
+ import os
+ import time
+ from io import BytesIO
+ from pathlib import Path
+
+ import librosa
+ import numpy as np
+ import soundfile
+ import torch
+
+ import utils
+ from modules.fastspeech.pe import PitchExtractor
+ from network.diff.candidate_decoder import FFT
+ from network.diff.diffusion import GaussianDiffusion
+ from network.diff.net import DiffNet
+ from network.vocoders.base_vocoder import VOCODERS, get_vocoder_cls
+ from preprocessing.data_gen_utils import get_pitch_parselmouth, get_pitch_crepe
+ from preprocessing.hubertinfer import Hubertencoder
+ from utils.hparams import hparams, set_hparams
+ from utils.pitch_utils import denorm_f0, norm_interp_f0
+
+ if os.path.exists("chunks_temp.json"):
+     os.remove("chunks_temp.json")
+
+
+ def read_temp(file_name):
+     if not os.path.exists(file_name):
+         with open(file_name, "w") as f:
+             f.write(json.dumps({"info": "temp_dict"}))
+         return {}
+     else:
+         try:
+             with open(file_name, "r") as f:
+                 data = f.read()
+             data_dict = json.loads(data)
+             if os.path.getsize(file_name) > 50 * 1024 * 1024:
+                 f_name = file_name.split("/")[-1]
+                 print(f"clean {f_name}")
+                 # note: the "info" entry has no "time" key, so this cleanup raises and the
+                 # dict is rebuilt through the except branch once the file exceeds 50 MB
+                 for wav_hash in list(data_dict.keys()):
+                     if int(time.time()) - int(data_dict[wav_hash]["time"]) > 14 * 24 * 3600:
+                         del data_dict[wav_hash]
+         except Exception as e:
+             print(e)
+             print(f"{file_name} error, rebuilding file automatically")
+             data_dict = {"info": "temp_dict"}
+         return data_dict
+
+
+ f0_dict = read_temp("./infer_tools/f0_temp.json")
+
+
+ def write_temp(file_name, data):
+     with open(file_name, "w") as f:
+         f.write(json.dumps(data))
+
+
+ def timeit(func):
+     def run(*args, **kwargs):
+         t = time.time()
+         res = func(*args, **kwargs)
+         print('executing \'%s\' took %.3fs' % (func.__name__, time.time() - t))
+         return res
+
+     return run
+
+
+ def format_wav(audio_path):
+     if Path(audio_path).suffix == '.wav':
+         return
+     raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
+     soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
+
+
+ def fill_a_to_b(a, b):
+     if len(a) < len(b):
+         for _ in range(0, len(b) - len(a)):
+             a.append(a[0])
+
+
+ def get_end_file(dir_path, end):
+     file_lists = []
+     for root, dirs, files in os.walk(dir_path):
+         files = [f for f in files if f[0] != '.']
+         dirs[:] = [d for d in dirs if d[0] != '.']
+         for f_file in files:
+             if f_file.endswith(end):
+                 file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
+     return file_lists
+
+
+ def mkdir(paths: list):
+     for path in paths:
+         if not os.path.exists(path):
+             os.mkdir(path)
+
+
+ def get_md5(content):
+     return hashlib.new("md5", content).hexdigest()
+
+
+ class Svc:
+     def __init__(self, project_name, config_name, hubert_gpu, model_path):
+         self.project_name = project_name
+         self.DIFF_DECODERS = {
+             'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+             'fft': lambda hp: FFT(
+                 hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
+         }
+
+         self.model_path = model_path
+         self.dev = torch.device("cuda")
+
+         self._ = set_hparams(config=config_name, exp_name=self.project_name, infer=True,
+                              reset=True,
+                              hparams_str='',
+                              print_hparams=False)
+
+         self.mel_bins = hparams['audio_num_mel_bins']
+         self.model = GaussianDiffusion(
+             phone_encoder=Hubertencoder(hparams['hubert_path']),
+             out_dims=self.mel_bins, denoise_fn=self.DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+             timesteps=hparams['timesteps'],
+             K_step=hparams['K_step'],
+             loss_type=hparams['diff_loss_type'],
+             spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+         )
+         self.load_ckpt()
+         self.model.cuda()
+         hparams['hubert_gpu'] = hubert_gpu
+         self.hubert = Hubertencoder(hparams['hubert_path'])
+         self.pe = PitchExtractor().cuda()
+         utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
+         self.pe.eval()
+         self.vocoder = get_vocoder_cls(hparams)()
+
+     def load_ckpt(self, model_name='model', force=True, strict=True):
+         utils.load_ckpt(self.model, self.model_path, model_name, force, strict)
+
+     def infer(self, in_path, key, acc, use_pe=True, use_crepe=True, thre=0.05, singer=False, **kwargs):
+         batch = self.pre(in_path, acc, use_crepe, thre)
+         spk_embed = batch.get('spk_embed') if not hparams['use_spk_id'] else batch.get('spk_ids')
+         hubert = batch['hubert']
+         ref_mels = batch["mels"]
+         energy = batch['energy']
+         mel2ph = batch['mel2ph']
+         batch['f0'] = batch['f0'] + (key / 12)
+         batch['f0'][batch['f0'] > np.log2(hparams['f0_max'])] = 0
+         f0 = batch['f0']
+         uv = batch['uv']
+
+         @timeit
+         def diff_infer():
+             outputs = self.model(
+                 hubert.cuda(), spk_embed=spk_embed, mel2ph=mel2ph.cuda(), f0=f0.cuda(), uv=uv.cuda(), energy=energy.cuda(),
+                 ref_mels=ref_mels.cuda(),
+                 infer=True, **kwargs)
+             return outputs
+
+         outputs = diff_infer()
+         batch['outputs'] = self.model.out2mel(outputs['mel_out'])
+         batch['mel2ph_pred'] = outputs['mel2ph']
+         batch['f0_gt'] = denorm_f0(batch['f0'], batch['uv'], hparams)
+         if use_pe:
+             batch['f0_pred'] = self.pe(outputs['mel_out'])['f0_denorm_pred'].detach()
+         else:
+             batch['f0_pred'] = outputs.get('f0_denorm')
+         return self.after_infer(batch, singer, in_path)
+
+     @timeit
+     def after_infer(self, prediction, singer, in_path):
+         for k, v in prediction.items():
+             if type(v) is torch.Tensor:
+                 prediction[k] = v.cpu().numpy()
+
+         # remove paddings
+         mel_gt = prediction["mels"]
+         mel_gt_mask = np.abs(mel_gt).sum(-1) > 0
+
+         mel_pred = prediction["outputs"]
+         mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
+         mel_pred = mel_pred[mel_pred_mask]
+         mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])
+
+         f0_gt = prediction.get("f0_gt")
+         f0_pred = prediction.get("f0_pred")
+         if f0_pred is not None:
+             f0_gt = f0_gt[mel_gt_mask]
+             if len(f0_pred) > len(mel_pred_mask):
+                 f0_pred = f0_pred[:len(mel_pred_mask)]
+             f0_pred = f0_pred[mel_pred_mask]
+         torch.cuda.is_available() and torch.cuda.empty_cache()
+
+         if singer:
+             data_path = in_path.replace("batch", "singer_data")
+             mel_path = data_path[:-4] + "_mel.npy"
+             f0_path = data_path[:-4] + "_f0.npy"
+             np.save(mel_path, mel_pred)
+             np.save(f0_path, f0_pred)
+         wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
+         return f0_gt, f0_pred, wav_pred
+
+     def temporary_dict2processed_input(self, item_name, temp_dict, use_crepe=True, thre=0.05):
+         '''
+         process data in temporary_dicts
+         '''
+
+         binarization_args = hparams['binarization_args']
+
+         @timeit
+         def get_pitch(wav, mel):
+             # get the ground-truth f0 with the selected pitch extraction algorithm
+             global f0_dict
+             if use_crepe:
+                 md5 = get_md5(wav)
+                 if f"{md5}_gt" in f0_dict.keys():
+                     print("load temp crepe f0")
+                     gt_f0 = np.array(f0_dict[f"{md5}_gt"]["f0"])
+                     coarse_f0 = np.array(f0_dict[f"{md5}_coarse"]["f0"])
+                 else:
+                     torch.cuda.is_available() and torch.cuda.empty_cache()
+                     gt_f0, coarse_f0 = get_pitch_crepe(wav, mel, hparams, thre)
+                     f0_dict[f"{md5}_gt"] = {"f0": gt_f0.tolist(), "time": int(time.time())}
+                     f0_dict[f"{md5}_coarse"] = {"f0": coarse_f0.tolist(), "time": int(time.time())}
+                     write_temp("./infer_tools/f0_temp.json", f0_dict)
+             else:
+                 gt_f0, coarse_f0 = get_pitch_parselmouth(wav, mel, hparams)
+             processed_input['f0'] = gt_f0
+             processed_input['pitch'] = coarse_f0
+
+         def get_align(mel, phone_encoded):
+             mel2ph = np.zeros([mel.shape[0]], int)
+             start_frame = 0
+             ph_durs = mel.shape[0] / phone_encoded.shape[0]
+             if hparams['debug']:
+                 print(mel.shape, phone_encoded.shape, mel.shape[0] / phone_encoded.shape[0])
+             for i_ph in range(phone_encoded.shape[0]):
+                 end_frame = int(i_ph * ph_durs + ph_durs + 0.5)
+                 mel2ph[start_frame:end_frame + 1] = i_ph + 1
+                 start_frame = end_frame + 1
+
+             processed_input['mel2ph'] = mel2ph
+
+         if hparams['vocoder'] in VOCODERS:
+             wav, mel = VOCODERS[hparams['vocoder']].wav2spec(temp_dict['wav_fn'])
+         else:
+             wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(temp_dict['wav_fn'])
+         processed_input = {
+             'item_name': item_name, 'mel': mel,
+             'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]
+         }
+         processed_input = {**temp_dict, **processed_input}  # merge two dicts
+
+         if binarization_args['with_f0']:
+             get_pitch(wav, mel)
+         if binarization_args['with_hubert']:
+             st = time.time()
+             hubert_encoded = processed_input['hubert'] = self.hubert.encode(temp_dict['wav_fn'])
+             et = time.time()
+             dev = 'cuda' if hparams['hubert_gpu'] and torch.cuda.is_available() else 'cpu'
+             print(f'hubert (on {dev}) time used {et - st}')
+
+             if binarization_args['with_align']:
+                 get_align(mel, hubert_encoded)
+         return processed_input
+
+     def pre(self, wav_fn, accelerate, use_crepe=True, thre=0.05):
+         if isinstance(wav_fn, BytesIO):
+             item_name = self.project_name
+         else:
+             song_info = wav_fn.split('/')
+             item_name = song_info[-1].split('.')[-2]
+         temp_dict = {'wav_fn': wav_fn, 'spk_id': self.project_name}
+
+         temp_dict = self.temporary_dict2processed_input(item_name, temp_dict, use_crepe, thre)
+         hparams['pndm_speedup'] = accelerate
+         batch = processed_input2batch([getitem(temp_dict)])
+         return batch
+
+
+ def getitem(item):
+     max_frames = hparams['max_frames']
+     spec = torch.Tensor(item['mel'])[:max_frames]
+     energy = (spec.exp() ** 2).sum(-1).sqrt()
+     mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
+     f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
+     hubert = torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
+     pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
+     sample = {
+         "item_name": item['item_name'],
+         "hubert": hubert,
+         "mel": spec,
+         "pitch": pitch,
+         "energy": energy,
+         "f0": f0,
+         "uv": uv,
+         "mel2ph": mel2ph,
+         "mel_nonpadding": spec.abs().sum(-1) > 0,
+     }
+     return sample
+
+
+ def processed_input2batch(samples):
+     '''
+     Args:
+         samples: one batch of processed_input
+     NOTE:
+         the batch size is controlled by hparams['max_sentences']
+     '''
+     if len(samples) == 0:
+         return {}
+     item_names = [s['item_name'] for s in samples]
+     hubert = utils.collate_2d([s['hubert'] for s in samples], 0.0)
+     f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
+     pitch = utils.collate_1d([s['pitch'] for s in samples])
+     uv = utils.collate_1d([s['uv'] for s in samples])
+     energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
+     mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
+         if samples[0]['mel2ph'] is not None else None
+     mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
+     mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
+
+     batch = {
+         'item_name': item_names,
+         'nsamples': len(samples),
+         'hubert': hubert,
+         'mels': mels,
+         'mel_lengths': mel_lengths,
+         'mel2ph': mel2ph,
+         'energy': energy,
+         'pitch': pitch,
+         'f0': f0,
+         'uv': uv,
+     }
+     return batch
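get_md5 accepts any bytes-like object; the callers above pass the raw numpy sample buffer of a loaded wav, so identical audio always maps to the same cache key in f0_temp.json and new_chunks_temp.json. A small sketch (the array is a placeholder):

    import numpy as np
    audio = np.zeros(16000, dtype=np.float32)  # any C-contiguous array works via the buffer protocol
    print(get_md5(audio))  # stable hash of the underlying sample bytes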
infer_tools/slicer.py ADDED
@@ -0,0 +1,158 @@
+ import time
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from scipy.ndimage import maximum_filter1d, uniform_filter1d
+
+
+ def timeit(func):
+     def run(*args, **kwargs):
+         t = time.time()
+         res = func(*args, **kwargs)
+         print('executing \'%s\' took %.3fs' % (func.__name__, time.time() - t))
+         return res
+
+     return run
+
+
+ # @timeit
+ def _window_maximum(arr, win_sz):
+     return maximum_filter1d(arr, size=win_sz)[win_sz // 2: win_sz // 2 + arr.shape[0] - win_sz + 1]
+
+
+ # @timeit
+ def _window_rms(arr, win_sz):
+     filtered = np.sqrt(uniform_filter1d(np.power(arr, 2), win_sz) - np.power(uniform_filter1d(arr, win_sz), 2))
+     return filtered[win_sz // 2: win_sz // 2 + arr.shape[0] - win_sz + 1]
+
+
+ def level2db(levels, eps=1e-12):
+     return 20 * np.log10(np.clip(levels, a_min=eps, a_max=1))
+
+
+ def _apply_slice(audio, begin, end):
+     if len(audio.shape) > 1:
+         return audio[:, begin: end]
+     else:
+         return audio[begin: end]
+
+
+ class Slicer:
+     def __init__(self,
+                  sr: int,
+                  db_threshold: float = -40,
+                  min_length: int = 5000,
+                  win_l: int = 300,
+                  win_s: int = 20,
+                  max_silence_kept: int = 500):
+         self.db_threshold = db_threshold
+         self.min_samples = round(sr * min_length / 1000)
+         self.win_ln = round(sr * win_l / 1000)
+         self.win_sn = round(sr * win_s / 1000)
+         self.max_silence = round(sr * max_silence_kept / 1000)
+         if not self.min_samples >= self.win_ln >= self.win_sn:
+             raise ValueError('The following condition must be satisfied: min_length >= win_l >= win_s')
+         if not self.max_silence >= self.win_sn:
+             raise ValueError('The following condition must be satisfied: max_silence_kept >= win_s')
+
+     @timeit
+     def slice(self, audio):
+         samples = audio
+         if samples.shape[0] <= self.min_samples:
+             return {"0": {"slice": False, "split_time": f"0,{len(audio)}"}}
+         # get absolute amplitudes
+         abs_amp = np.abs(samples - np.mean(samples))
+         # calculate local maximum with large window
+         win_max_db = level2db(_window_maximum(abs_amp, win_sz=self.win_ln))
+         sil_tags = []
+         left = right = 0
+         while right < win_max_db.shape[0]:
+             if win_max_db[right] < self.db_threshold:
+                 right += 1
+             elif left == right:
+                 left += 1
+                 right += 1
+             else:
+                 if left == 0:
+                     split_loc_l = left
+                 else:
+                     sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
+                     rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
+                     split_win_l = left + np.argmin(rms_db_left)
+                     split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
+                 if len(sil_tags) != 0 and split_loc_l - sil_tags[-1][1] < self.min_samples and right < win_max_db.shape[0] - 1:
+                     right += 1
+                     left = right
+                     continue
+                 if right == win_max_db.shape[0] - 1:
+                     split_loc_r = right + self.win_ln
+                 else:
+                     sil_right_n = min(self.max_silence, (right + self.win_ln - left) // 2)
+                     rms_db_right = level2db(_window_rms(samples[right + self.win_ln - sil_right_n: right + self.win_ln],
+                                                         win_sz=self.win_sn))
+                     split_win_r = right + self.win_ln - sil_right_n + np.argmin(rms_db_right)
+                     split_loc_r = split_win_r + np.argmin(abs_amp[split_win_r: split_win_r + self.win_sn])
+                 sil_tags.append((split_loc_l, split_loc_r))
+                 right += 1
+                 left = right
+         if left != right:
+             sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
+             rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
+             split_win_l = left + np.argmin(rms_db_left)
+             split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
+             sil_tags.append((split_loc_l, samples.shape[0]))
+         if len(sil_tags) == 0:
+             return {"0": {"slice": False, "split_time": f"0,{len(audio)}"}}
+         else:
+             chunks = []
+             # the first silent segment does not start at the very beginning; prepend the voiced segment
+             if sil_tags[0][0]:
+                 chunks.append({"slice": False, "split_time": f"0,{sil_tags[0][0]}"})
+             for i in range(0, len(sil_tags)):
+                 # mark the voiced segments (skipping the first one)
+                 if i:
+                     chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1]},{sil_tags[i][0]}"})
+                 # mark every silent segment
+                 chunks.append({"slice": True, "split_time": f"{sil_tags[i][0]},{sil_tags[i][1]}"})
+             # the last silent segment does not reach the end; append the trailing voiced segment
+             if sil_tags[-1][1] != len(audio):
+                 chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1]},{len(audio)}"})
+             chunk_dict = {}
+             for i in range(len(chunks)):
+                 chunk_dict[str(i)] = chunks[i]
+             return chunk_dict
+
+
+ def cut(audio_path, db_thresh=-30, min_len=5000, win_l=300, win_s=20, max_sil_kept=500):
+     audio, sr = torchaudio.load(audio_path)
+     # downmix to mono (torchaudio returns [channels, samples])
+     if len(audio.shape) == 2 and audio.shape[1] >= 2:
+         audio = torch.mean(audio, dim=0).unsqueeze(0)
+     audio = audio.cpu().numpy()[0]
+
+     slicer = Slicer(
+         sr=sr,
+         db_threshold=db_thresh,
+         min_length=min_len,
+         win_l=win_l,
+         win_s=win_s,
+         max_silence_kept=max_sil_kept
+     )
+     chunks = slicer.slice(audio)
+     return chunks
+
+
+ def chunks2audio(audio_path, chunks):
+     chunks = dict(chunks)
+     audio, sr = torchaudio.load(audio_path)
+     # downmix to mono (torchaudio returns [channels, samples])
+     if len(audio.shape) == 2 and audio.shape[1] >= 2:
+         audio = torch.mean(audio, dim=0).unsqueeze(0)
+     audio = audio.cpu().numpy()[0]
+     result = []
+     for k, v in chunks.items():
+         tag = v["split_time"].split(",")
+         result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
+     return result, sr
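cut and chunks2audio are designed to round-trip: cut returns a JSON-serializable dict of split points (which infer.py caches by wav hash), and chunks2audio re-reads the file and materializes the segments. A minimal sketch with a placeholder path:

    chunks = cut("raw/input.wav", db_thresh=-40)  # {"0": {"slice": ..., "split_time": "begin,end"}, ...}
    segments, sr = chunks2audio("raw/input.wav", chunks)
    for is_silence, data in segments:
        print(is_silence, round(len(data) / sr, 3))  # chunks tagged slice=True are silence and skipped at inference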
modules/commons/common_layers.py ADDED
@@ -0,0 +1,671 @@
+ import math
+ import torch
+ from torch import nn
+ from torch.nn import Parameter
+ import torch.onnx.operators
+ import torch.nn.functional as F
+ import utils
+
+
+ class Reshape(nn.Module):
+     def __init__(self, *args):
+         super(Reshape, self).__init__()
+         self.shape = args
+
+     def forward(self, x):
+         return x.view(self.shape)
+
+
+ class Permute(nn.Module):
+     def __init__(self, *args):
+         super(Permute, self).__init__()
+         self.args = args
+
+     def forward(self, x):
+         return x.permute(self.args)
+
+
+ class LinearNorm(torch.nn.Module):
+     def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+         super(LinearNorm, self).__init__()
+         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+         torch.nn.init.xavier_uniform_(
+             self.linear_layer.weight,
+             gain=torch.nn.init.calculate_gain(w_init_gain))
+
+     def forward(self, x):
+         return self.linear_layer(x)
+
+
+ class ConvNorm(torch.nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+                  padding=None, dilation=1, bias=True, w_init_gain='linear'):
+         super(ConvNorm, self).__init__()
+         if padding is None:
+             assert (kernel_size % 2 == 1)
+             padding = int(dilation * (kernel_size - 1) / 2)
+
+         self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                     kernel_size=kernel_size, stride=stride,
+                                     padding=padding, dilation=dilation,
+                                     bias=bias)
+
+         torch.nn.init.xavier_uniform_(
+             self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+     def forward(self, signal):
+         conv_signal = self.conv(signal)
+         return conv_signal
+
+
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+     nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+     if padding_idx is not None:
+         nn.init.constant_(m.weight[padding_idx], 0)
+     return m
+
+
+ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+     if not export and torch.cuda.is_available():
+         try:
+             from apex.normalization import FusedLayerNorm
+             return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+         except ImportError:
+             pass
+     return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+ def Linear(in_features, out_features, bias=True):
+     m = nn.Linear(in_features, out_features, bias)
+     nn.init.xavier_uniform_(m.weight)
+     if bias:
+         nn.init.constant_(m.bias, 0.)
+     return m
+
+
+ class SinusoidalPositionalEmbedding(nn.Module):
+     """This module produces sinusoidal positional embeddings of any length.
+
+     Padding symbols are ignored.
+     """
+
+     def __init__(self, embedding_dim, padding_idx, init_size=1024):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.padding_idx = padding_idx
+         self.weights = SinusoidalPositionalEmbedding.get_embedding(
+             init_size,
+             embedding_dim,
+             padding_idx,
+         )
+         self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+     @staticmethod
+     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+         """Build sinusoidal embeddings.
+
+         This matches the implementation in tensor2tensor, but differs slightly
+         from the description in Section 3.5 of "Attention Is All You Need".
+         """
+         half_dim = embedding_dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+         emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+         if embedding_dim % 2 == 1:
+             # zero pad
+             emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+         if padding_idx is not None:
+             emb[padding_idx, :] = 0
+         return emb
+
+     def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+         """Input is expected to be of size [bsz x seqlen]."""
+         bsz, seq_len = input.shape[:2]
+         max_pos = self.padding_idx + 1 + seq_len
+         if self.weights is None or max_pos > self.weights.size(0):
+             # recompute/expand embeddings if needed
+             self.weights = SinusoidalPositionalEmbedding.get_embedding(
+                 max_pos,
+                 self.embedding_dim,
+                 self.padding_idx,
+             )
+         self.weights = self.weights.to(self._float_tensor)
+
+         if incremental_state is not None:
+             # positions is the same for every token when decoding a single step
+             pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+             return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+         positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
+         return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+     def max_positions(self):
+         """Maximum number of supported positions."""
+         return int(1e5)  # an arbitrary large number
+
+
+ class ConvTBC(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+         super(ConvTBC, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.padding = padding
+
+         self.weight = torch.nn.Parameter(torch.Tensor(
+             self.kernel_size, in_channels, out_channels))
+         self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+
+     def forward(self, input):
+         return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
+
+
+ class MultiheadAttention(nn.Module):
+     def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+                  add_bias_kv=False, add_zero_attn=False, self_attention=False,
+                  encoder_decoder_attention=False):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.kdim = kdim if kdim is not None else embed_dim
+         self.vdim = vdim if vdim is not None else embed_dim
+         self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.head_dim = embed_dim // num_heads
+         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+         self.scaling = self.head_dim ** -0.5
+
+         self.self_attention = self_attention
+         self.encoder_decoder_attention = encoder_decoder_attention
+
+         assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+                                                              'value to be of the same size'
+
+         if self.qkv_same_dim:
+             self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+         else:
+             self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+             self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+             self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+         if bias:
+             self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+         else:
+             self.register_parameter('in_proj_bias', None)
+
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+         if add_bias_kv:
+             self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+             self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+         else:
+             self.bias_k = self.bias_v = None
+
+         self.add_zero_attn = add_zero_attn
+
+         self.reset_parameters()
+
+         self.enable_torch_version = False
+         if hasattr(F, "multi_head_attention_forward"):
+             self.enable_torch_version = True
+         else:
+             self.enable_torch_version = False
+         self.last_attn_probs = None
+
+     def reset_parameters(self):
+         if self.qkv_same_dim:
+             nn.init.xavier_uniform_(self.in_proj_weight)
+         else:
+             nn.init.xavier_uniform_(self.k_proj_weight)
+             nn.init.xavier_uniform_(self.v_proj_weight)
+             nn.init.xavier_uniform_(self.q_proj_weight)
+
+         nn.init.xavier_uniform_(self.out_proj.weight)
+         if self.in_proj_bias is not None:
+             nn.init.constant_(self.in_proj_bias, 0.)
+             nn.init.constant_(self.out_proj.bias, 0.)
+         if self.bias_k is not None:
+             nn.init.xavier_normal_(self.bias_k)
+         if self.bias_v is not None:
+             nn.init.xavier_normal_(self.bias_v)
+
+     def forward(
+             self,
+             query, key, value,
+             key_padding_mask=None,
+             incremental_state=None,
+             need_weights=True,
+             static_kv=False,
+             attn_mask=None,
+             before_softmax=False,
+             need_head_weights=False,
+             enc_dec_attn_constraint_mask=None,
+             reset_attn_weight=None
+     ):
+         """Input shape: Time x Batch x Channel
+
+         Args:
+             key_padding_mask (ByteTensor, optional): mask to exclude
+                 keys that are pads, of shape `(batch, src_len)`, where
+                 padding elements are indicated by 1s.
+             need_weights (bool, optional): return the attention weights,
+                 averaged over heads (default: False).
+             attn_mask (ByteTensor, optional): typically used to
+                 implement causal attention, where the mask prevents the
+                 attention from looking forward in time (default: None).
+             before_softmax (bool, optional): return the raw attention
+                 weights and values before the attention softmax.
+             need_head_weights (bool, optional): return the attention
+                 weights for each head. Implies *need_weights*. Default:
+                 return the average attention weights over all heads.
+         """
+         if need_head_weights:
+             need_weights = True
+
+         tgt_len, bsz, embed_dim = query.size()
+         assert embed_dim == self.embed_dim
+         assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+         if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+             if self.qkv_same_dim:
+                 return F.multi_head_attention_forward(query, key, value,
+                                                       self.embed_dim, self.num_heads,
+                                                       self.in_proj_weight,
+                                                       self.in_proj_bias, self.bias_k, self.bias_v,
+                                                       self.add_zero_attn, self.dropout,
+                                                       self.out_proj.weight, self.out_proj.bias,
+                                                       self.training, key_padding_mask, need_weights,
+                                                       attn_mask)
+             else:
+                 return F.multi_head_attention_forward(query, key, value,
+                                                       self.embed_dim, self.num_heads,
+                                                       torch.empty([0]),
+                                                       self.in_proj_bias, self.bias_k, self.bias_v,
+                                                       self.add_zero_attn, self.dropout,
+                                                       self.out_proj.weight, self.out_proj.bias,
+                                                       self.training, key_padding_mask, need_weights,
+                                                       attn_mask, use_separate_proj_weight=True,
+                                                       q_proj_weight=self.q_proj_weight,
+                                                       k_proj_weight=self.k_proj_weight,
+                                                       v_proj_weight=self.v_proj_weight)
+
+         if incremental_state is not None:
+             print('Not implemented error.')
+             exit()
+         else:
+             saved_state = None
+
+         if self.self_attention:
+             # self-attention
+             q, k, v = self.in_proj_qkv(query)
+         elif self.encoder_decoder_attention:
+             # encoder-decoder attention
+             q = self.in_proj_q(query)
+             if key is None:
+                 assert value is None
+                 k = v = None
+             else:
+                 k = self.in_proj_k(key)
+                 v = self.in_proj_v(key)
+
+         else:
+             q = self.in_proj_q(query)
+             k = self.in_proj_k(key)
+             v = self.in_proj_v(value)
+         q *= self.scaling
+
+         if self.bias_k is not None:
+             assert self.bias_v is not None
+             k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+             v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+             if attn_mask is not None:
+                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+         q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if k is not None:
+             k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+         if v is not None:
+             v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+         if saved_state is not None:
+             print('Not implemented error.')
+             exit()
+
+         src_len = k.size(1)
+
+         # This is part of a workaround to get around fork/join parallelism
+         # not supporting Optional types.
+         if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+             key_padding_mask = None
+
+         if key_padding_mask is not None:
+             assert key_padding_mask.size(0) == bsz
+             assert key_padding_mask.size(1) == src_len
+
+         if self.add_zero_attn:
+             src_len += 1
+             k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+             v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+             if attn_mask is not None:
+                 attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+             if key_padding_mask is not None:
+                 key_padding_mask = torch.cat(
+                     [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+         attn_weights = torch.bmm(q, k.transpose(1, 2))
+         attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+         if attn_mask is not None:
+             if len(attn_mask.shape) == 2:
+                 attn_mask = attn_mask.unsqueeze(0)
+             elif len(attn_mask.shape) == 3:
+                 attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+                     bsz * self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights + attn_mask
+
+         if enc_dec_attn_constraint_mask is not None:  # bs x head x L_kv
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.masked_fill(
+                 enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+                 -1e9,
+             )
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         if key_padding_mask is not None:
+             # don't attend to padding symbols
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+             attn_weights = attn_weights.masked_fill(
+                 key_padding_mask.unsqueeze(1).unsqueeze(2),
+                 -1e9,
+             )
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+         attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+         if before_softmax:
+             return attn_weights, v
+
+         attn_weights_float = utils.softmax(attn_weights, dim=-1)
+         attn_weights = attn_weights_float.type_as(attn_weights)
+         attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+         if reset_attn_weight is not None:
+             if reset_attn_weight:
+                 self.last_attn_probs = attn_probs.detach()
+             else:
+                 assert self.last_attn_probs is not None
+                 attn_probs = self.last_attn_probs
+         attn = torch.bmm(attn_probs, v)
+         assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+         attn = self.out_proj(attn)
+
+         if need_weights:
+             attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+             if not need_head_weights:
+                 # average attention weights over heads
+                 attn_weights = attn_weights.mean(dim=0)
+         else:
+             attn_weights = None
+
+         return attn, (attn_weights, attn_logits)
+
+     def in_proj_qkv(self, query):
+         return self._in_proj(query).chunk(3, dim=-1)
+
+     def in_proj_q(self, query):
+         if self.qkv_same_dim:
+             return self._in_proj(query, end=self.embed_dim)
+         else:
+             bias = self.in_proj_bias
+             if bias is not None:
+                 bias = bias[:self.embed_dim]
+             return F.linear(query, self.q_proj_weight, bias)
+
+     def in_proj_k(self, key):
+         if self.qkv_same_dim:
+             return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+         else:
+             weight = self.k_proj_weight
+             bias = self.in_proj_bias
+             if bias is not None:
+                 bias = bias[self.embed_dim:2 * self.embed_dim]
+             return F.linear(key, weight, bias)
+
+     def in_proj_v(self, value):
+         if self.qkv_same_dim:
+             return self._in_proj(value, start=2 * self.embed_dim)
+         else:
+             weight = self.v_proj_weight
+             bias = self.in_proj_bias
+             if bias is not None:
+                 bias = bias[2 * self.embed_dim:]
+             return F.linear(value, weight, bias)
+
+     def _in_proj(self, input, start=0, end=None):
+         weight = self.in_proj_weight
+         bias = self.in_proj_bias
+         weight = weight[start:end, :]
+         if bias is not None:
+             bias = bias[start:end]
+         return F.linear(input, weight, bias)
+
+     def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+         return attn_weights
+
+
+ class Swish(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, i):
+         result = i * torch.sigmoid(i)
+         ctx.save_for_backward(i)
+         return result
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         i = ctx.saved_variables[0]
+         sigmoid_i = torch.sigmoid(i)
+         return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+ class CustomSwish(nn.Module):
+     def forward(self, input_tensor):
+         return Swish.apply(input_tensor)
+
+
+ class Mish(nn.Module):
+     def forward(self, x):
+         return x * torch.tanh(F.softplus(x))
+
+
+ class TransformerFFNLayer(nn.Module):
+     def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.dropout = dropout
+         self.act = act
+         if padding == 'SAME':
+             self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+         elif padding == 'LEFT':
+             self.ffn_1 = nn.Sequential(
+                 nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+                 nn.Conv1d(hidden_size, filter_size, kernel_size)
+             )
+         self.ffn_2 = Linear(filter_size, hidden_size)
+         if self.act == 'swish':
+             self.swish_fn = CustomSwish()
+
+     def forward(self, x, incremental_state=None):
+         # x: T x B x C
+         if incremental_state is not None:
+             assert incremental_state is None, 'Nar-generation does not allow this.'
+             exit(1)
+
+         x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+         x = x * self.kernel_size ** -0.5
+
+         if incremental_state is not None:
+             x = x[-1:]
+         if self.act == 'gelu':
+             x = F.gelu(x)
+         if self.act == 'relu':
+             x = F.relu(x)
+         if self.act == 'swish':
+             x = self.swish_fn(x)
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = self.ffn_2(x)
+         return x
+
+
+ class BatchNorm1dTBC(nn.Module):
+     def __init__(self, c):
+         super(BatchNorm1dTBC, self).__init__()
+         self.bn = nn.BatchNorm1d(c)
+
+     def forward(self, x):
+         """
+
+         :param x: [T, B, C]
+         :return: [T, B, C]
+         """
+         x = x.permute(1, 2, 0)  # [B, C, T]
+         x = self.bn(x)  # [B, C, T]
+         x = x.permute(2, 0, 1)  # [T, B, C]
+         return x
+
+
+ class EncSALayer(nn.Module):
+     def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+                  relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
+         super().__init__()
+         self.c = c
+         self.dropout = dropout
+         self.num_heads = num_heads
+         if num_heads > 0:
+             if norm == 'ln':
+                 self.layer_norm1 = LayerNorm(c)
+             elif norm == 'bn':
+                 self.layer_norm1 = BatchNorm1dTBC(c)
+             self.self_attn = MultiheadAttention(
+                 self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
+             )
+         if norm == 'ln':
+             self.layer_norm2 = LayerNorm(c)
+         elif norm == 'bn':
+             self.layer_norm2 = BatchNorm1dTBC(c)
+         self.ffn = TransformerFFNLayer(
+             c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+     def forward(self, x, encoder_padding_mask=None, **kwargs):
+         layer_norm_training = kwargs.get('layer_norm_training', None)
+         if layer_norm_training is not None:
+             self.layer_norm1.training = layer_norm_training
+             self.layer_norm2.training = layer_norm_training
+         if self.num_heads > 0:
+             residual = x
+             x = self.layer_norm1(x)
+             x, _ = self.self_attn(
+                 query=x,
+                 key=x,
+                 value=x,
+                 key_padding_mask=encoder_padding_mask
+             )
+             x = F.dropout(x, self.dropout, training=self.training)
+             x = residual + x
+             x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+         residual = x
+         x = self.layer_norm2(x)
+         x = self.ffn(x)
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+         x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+         return x
+
+
+ class DecSALayer(nn.Module):
+     def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
+         super().__init__()
+         self.c = c
+         self.dropout = dropout
+         self.layer_norm1 = LayerNorm(c)
+         self.self_attn = MultiheadAttention(
+             c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+         )
+         self.layer_norm2 = LayerNorm(c)
+         self.encoder_attn = MultiheadAttention(
+             c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+         )
+         self.layer_norm3 = LayerNorm(c)
+         self.ffn = TransformerFFNLayer(
+             c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+     def forward(
+             self,
+             x,
+             encoder_out=None,
+             encoder_padding_mask=None,
+             incremental_state=None,
+             self_attn_mask=None,
+             self_attn_padding_mask=None,
+             attn_out=None,
+             reset_attn_weight=None,
+             **kwargs,
+     ):
+         layer_norm_training = kwargs.get('layer_norm_training', None)
+         if layer_norm_training is not None:
+             self.layer_norm1.training = layer_norm_training
+             self.layer_norm2.training = layer_norm_training
+             self.layer_norm3.training = layer_norm_training
+         residual = x
+         x = self.layer_norm1(x)
+         x, _ = self.self_attn(
+             query=x,
+             key=x,
+             value=x,
+             key_padding_mask=self_attn_padding_mask,
+             incremental_state=incremental_state,
+             attn_mask=self_attn_mask
+         )
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+
+         residual = x
+         x = self.layer_norm2(x)
+         if encoder_out is not None:
+             x, attn = self.encoder_attn(
+                 query=x,
+                 key=encoder_out,
+                 value=encoder_out,
+                 key_padding_mask=encoder_padding_mask,
+                 incremental_state=incremental_state,
+                 static_kv=True,
+                 enc_dec_attn_constraint_mask=None,  # utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
+                 reset_attn_weight=reset_attn_weight
+             )
+             attn_logits = attn[1]
+         else:
+             assert attn_out is not None
+             x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
+             attn_logits = None
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+
+         residual = x
+         x = self.layer_norm3(x)
+         x = self.ffn(x, incremental_state=incremental_state)
+         x = F.dropout(x, self.dropout, training=self.training)
+         x = residual + x
+         # if len(attn_logits.size()) > 3:
+         #     indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
+         #     attn_logits = attn_logits.gather(1,
+         #         indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
+         return x, attn_logits
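These layers follow the [T, B, C] (or [bsz, seqlen] for embeddings) conventions noted in their docstrings. As a shape sanity check for SinusoidalPositionalEmbedding, here is a small sketch; passing positions explicitly avoids the utils.make_positions dependency:

    import torch
    emb = SinusoidalPositionalEmbedding(embedding_dim=256, padding_idx=0)
    tokens = torch.randint(1, 100, (2, 10))               # [bsz, seqlen]
    positions = torch.arange(1, 11)[None].expand(2, -1)   # valid positions start after padding_idx
    out = emb(tokens, positions=positions)                # -> [2, 10, 256], detached table lookups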
modules/commons/espnet_positional_embedding.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+
5
+ class PositionalEncoding(torch.nn.Module):
6
+ """Positional encoding.
7
+ Args:
8
+ d_model (int): Embedding dimension.
9
+ dropout_rate (float): Dropout rate.
10
+ max_len (int): Maximum input length.
11
+ reverse (bool): Whether to reverse the input position.
12
+ """
13
+
14
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
15
+ """Construct an PositionalEncoding object."""
16
+ super(PositionalEncoding, self).__init__()
17
+ self.d_model = d_model
18
+ self.reverse = reverse
19
+ self.xscale = math.sqrt(self.d_model)
20
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
21
+ self.pe = None
22
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
23
+
24
+ def extend_pe(self, x):
25
+ """Reset the positional encodings."""
26
+ if self.pe is not None:
27
+ if self.pe.size(1) >= x.size(1):
28
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
29
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
30
+ return
31
+ pe = torch.zeros(x.size(1), self.d_model)
32
+ if self.reverse:
33
+ position = torch.arange(
34
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
35
+ ).unsqueeze(1)
36
+ else:
37
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
38
+ div_term = torch.exp(
39
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
40
+ * -(math.log(10000.0) / self.d_model)
41
+ )
42
+ pe[:, 0::2] = torch.sin(position * div_term)
43
+ pe[:, 1::2] = torch.cos(position * div_term)
44
+ pe = pe.unsqueeze(0)
45
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
46
+
47
+ def forward(self, x: torch.Tensor):
48
+ """Add positional encoding.
49
+ Args:
50
+ x (torch.Tensor): Input tensor (batch, time, `*`).
51
+ Returns:
52
+ torch.Tensor: Encoded tensor (batch, time, `*`).
53
+ """
54
+ self.extend_pe(x)
55
+ x = x * self.xscale + self.pe[:, : x.size(1)]
56
+ return self.dropout(x)
57
+
58
+
59
+ class ScaledPositionalEncoding(PositionalEncoding):
60
+ """Scaled positional encoding module.
61
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
62
+ Args:
63
+ d_model (int): Embedding dimension.
64
+ dropout_rate (float): Dropout rate.
65
+ max_len (int): Maximum input length.
66
+ """
67
+
68
+ def __init__(self, d_model, dropout_rate, max_len=5000):
69
+ """Initialize class."""
70
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
71
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
72
+
73
+ def reset_parameters(self):
74
+ """Reset parameters."""
75
+ self.alpha.data = torch.tensor(1.0)
76
+
77
+ def forward(self, x):
78
+ """Add positional encoding.
79
+ Args:
80
+ x (torch.Tensor): Input tensor (batch, time, `*`).
81
+ Returns:
82
+ torch.Tensor: Encoded tensor (batch, time, `*`).
83
+ """
84
+ self.extend_pe(x)
85
+ x = x + self.alpha * self.pe[:, : x.size(1)]
86
+ return self.dropout(x)
87
+
88
+
89
+ class RelPositionalEncoding(PositionalEncoding):
90
+ """Relative positional encoding module.
91
+ See : Appendix B in https://arxiv.org/abs/1901.02860
92
+ Args:
93
+ d_model (int): Embedding dimension.
94
+ dropout_rate (float): Dropout rate.
95
+ max_len (int): Maximum input length.
96
+ """
97
+
98
+ def __init__(self, d_model, dropout_rate, max_len=5000):
99
+ """Initialize class."""
100
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
101
+
102
+ def forward(self, x):
103
+ """Compute positional encoding.
104
+ Args:
105
+ x (torch.Tensor): Input tensor (batch, time, `*`).
106
+ Returns:
107
+ torch.Tensor: Encoded tensor with the positional embedding added (batch, time, `*`).
109
+ """
110
+ self.extend_pe(x)
111
+ x = x * self.xscale
112
+ pos_emb = self.pe[:, : x.size(1)]
113
+ return self.dropout(x) + self.dropout(pos_emb)
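
A quick, hedged usage sketch for the encodings above; batch size, length, and d_model are arbitrary, and it assumes the repo root is on PYTHONPATH.

import torch
from modules.commons.espnet_positional_embedding import (
    PositionalEncoding, RelPositionalEncoding)

x = torch.randn(2, 100, 256)                        # (batch, time, d_model)
abs_pe = PositionalEncoding(d_model=256, dropout_rate=0.1)
y = abs_pe(x)                                       # x * sqrt(d_model) + PE, then dropout
rel_pe = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
z = rel_pe(x)                                       # scaled input with the PE summed in
print(y.shape, z.shape)                             # both (2, 100, 256)
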
modules/commons/ssim.py ADDED
@@ -0,0 +1,391 @@
1
+ # '''
2
+ # https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
3
+ # '''
4
+ #
5
+ # import torch
6
+ # import torch.jit
7
+ # import torch.nn.functional as F
8
+ #
9
+ #
10
+ # @torch.jit.script
11
+ # def create_window(window_size: int, sigma: float, channel: int):
12
+ # '''
13
+ # Create 1-D gauss kernel
14
+ # :param window_size: the size of gauss kernel
15
+ # :param sigma: sigma of normal distribution
16
+ # :param channel: input channel
17
+ # :return: 1D kernel
18
+ # '''
19
+ # coords = torch.arange(window_size, dtype=torch.float)
20
+ # coords -= window_size // 2
21
+ #
22
+ # g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
23
+ # g /= g.sum()
24
+ #
25
+ # g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
26
+ # return g
27
+ #
28
+ #
29
+ # @torch.jit.script
30
+ # def _gaussian_filter(x, window_1d, use_padding: bool):
31
+ # '''
32
+ # Blur input with 1-D kernel
33
+ # :param x: batch of tensors to be blured
34
+ # :param window_1d: 1-D gauss kernel
35
+ # :param use_padding: padding image before conv
36
+ # :return: blured tensors
37
+ # '''
38
+ # C = x.shape[1]
39
+ # padding = 0
40
+ # if use_padding:
41
+ # window_size = window_1d.shape[3]
42
+ # padding = window_size // 2
43
+ # out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
44
+ # out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
45
+ # return out
46
+ #
47
+ #
48
+ # @torch.jit.script
49
+ # def ssim(X, Y, window, data_range: float, use_padding: bool = False):
50
+ # '''
51
+ # Calculate ssim index for X and Y
52
+ # :param X: images [B, C, H, N_bins]
53
+ # :param Y: images [B, C, H, N_bins]
54
+ # :param window: 1-D gauss kernel
55
+ # :param data_range: value range of input images. (usually 1.0 or 255)
56
+ # :param use_padding: padding image before conv
57
+ # :return:
58
+ # '''
59
+ #
60
+ # K1 = 0.01
61
+ # K2 = 0.03
62
+ # compensation = 1.0
63
+ #
64
+ # C1 = (K1 * data_range) ** 2
65
+ # C2 = (K2 * data_range) ** 2
66
+ #
67
+ # mu1 = _gaussian_filter(X, window, use_padding)
68
+ # mu2 = _gaussian_filter(Y, window, use_padding)
69
+ # sigma1_sq = _gaussian_filter(X * X, window, use_padding)
70
+ # sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
71
+ # sigma12 = _gaussian_filter(X * Y, window, use_padding)
72
+ #
73
+ # mu1_sq = mu1.pow(2)
74
+ # mu2_sq = mu2.pow(2)
75
+ # mu1_mu2 = mu1 * mu2
76
+ #
77
+ # sigma1_sq = compensation * (sigma1_sq - mu1_sq)
78
+ # sigma2_sq = compensation * (sigma2_sq - mu2_sq)
79
+ # sigma12 = compensation * (sigma12 - mu1_mu2)
80
+ #
81
+ # cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
82
+ # # Fix: clamp cs_map at 0, since negative values caused ms_ssim to output NaN.
83
+ # cs_map = cs_map.clamp_min(0.)
84
+ # ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
85
+ #
86
+ # ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
87
+ # cs = cs_map.mean(dim=(1, 2, 3))
88
+ #
89
+ # return ssim_val, cs
90
+ #
91
+ #
92
+ # @torch.jit.script
93
+ # def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
94
+ # '''
95
+ # interface of ms-ssim
96
+ # :param X: a batch of images, (N,C,H,W)
97
+ # :param Y: a batch of images, (N,C,H,W)
98
+ # :param window: 1-D gauss kernel
99
+ # :param data_range: value range of input images. (usually 1.0 or 255)
100
+ # :param weights: weights for different levels
101
+ # :param use_padding: padding image before conv
102
+ # :param eps: used to avoid NaN gradients.
103
+ # :return:
104
+ # '''
105
+ # levels = weights.shape[0]
106
+ # cs_vals = []
107
+ # ssim_vals = []
108
+ # for _ in range(levels):
109
+ # ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
110
+ # # Fix for an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
111
+ # ssim_val = ssim_val.clamp_min(eps)
112
+ # cs = cs.clamp_min(eps)
113
+ # cs_vals.append(cs)
114
+ #
115
+ # ssim_vals.append(ssim_val)
116
+ # padding = (X.shape[2] % 2, X.shape[3] % 2)
117
+ # X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
118
+ # Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
119
+ #
120
+ # cs_vals = torch.stack(cs_vals, dim=0)
121
+ # ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
122
+ # return ms_ssim_val
123
+ #
124
+ #
125
+ # class SSIM(torch.jit.ScriptModule):
126
+ # __constants__ = ['data_range', 'use_padding']
127
+ #
128
+ # def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
129
+ # '''
130
+ # :param window_size: the size of gauss kernel
131
+ # :param window_sigma: sigma of normal distribution
132
+ # :param data_range: value range of input images. (usually 1.0 or 255)
133
+ # :param channel: input channels (default: 3)
134
+ # :param use_padding: padding image before conv
135
+ # '''
136
+ # super().__init__()
137
+ # assert window_size % 2 == 1, 'Window size must be odd.'
138
+ # window = create_window(window_size, window_sigma, channel)
139
+ # self.register_buffer('window', window)
140
+ # self.data_range = data_range
141
+ # self.use_padding = use_padding
142
+ #
143
+ # @torch.jit.script_method
144
+ # def forward(self, X, Y):
145
+ # r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
146
+ # return r[0]
147
+ #
148
+ #
149
+ # class MS_SSIM(torch.jit.ScriptModule):
150
+ # __constants__ = ['data_range', 'use_padding', 'eps']
151
+ #
152
+ # def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
153
+ # levels=None, eps=1e-8):
154
+ # '''
155
+ # class for ms-ssim
156
+ # :param window_size: the size of gauss kernel
157
+ # :param window_sigma: sigma of normal distribution
158
+ # :param data_range: value range of input images. (usually 1.0 or 255)
159
+ # :param channel: input channels
160
+ # :param use_padding: padding image before conv
161
+ # :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
162
+ # :param levels: number of downsampling
163
+ # :param eps: fix for an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
164
+ # '''
165
+ # super().__init__()
166
+ # assert window_size % 2 == 1, 'Window size must be odd.'
167
+ # self.data_range = data_range
168
+ # self.use_padding = use_padding
169
+ # self.eps = eps
170
+ #
171
+ # window = create_window(window_size, window_sigma, channel)
172
+ # self.register_buffer('window', window)
173
+ #
174
+ # if weights is None:
175
+ # weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
176
+ # weights = torch.tensor(weights, dtype=torch.float)
177
+ #
178
+ # if levels is not None:
179
+ # weights = weights[:levels]
180
+ # weights = weights / weights.sum()
181
+ #
182
+ # self.register_buffer('weights', weights)
183
+ #
184
+ # @torch.jit.script_method
185
+ # def forward(self, X, Y):
186
+ # return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
187
+ # use_padding=self.use_padding, eps=self.eps)
188
+ #
189
+ #
190
+ # if __name__ == '__main__':
191
+ # print('Simple Test')
192
+ # im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
193
+ # img1 = im / 255
194
+ # img2 = img1 * 0.5
195
+ #
196
+ # losser = SSIM(data_range=1.).cuda()
197
+ # loss = losser(img1, img2).mean()
198
+ #
199
+ # losser2 = MS_SSIM(data_range=1.).cuda()
200
+ # loss2 = losser2(img1, img2).mean()
201
+ #
202
+ # print(loss.item())
203
+ # print(loss2.item())
204
+ #
205
+ # if __name__ == '__main__':
206
+ # print('Training Test')
207
+ # import cv2
208
+ # import torch.optim
209
+ # import numpy as np
210
+ # import imageio
211
+ # import time
212
+ #
213
+ # out_test_video = False
214
+ # # Avoid writing GIF output directly (it gets very large); write an MKV first and convert it to GIF with ffmpeg
215
+ # video_use_gif = False
216
+ #
217
+ # im = cv2.imread('test_img1.jpg', 1)
218
+ # t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
219
+ #
220
+ # if out_test_video:
221
+ # if video_use_gif:
222
+ # fps = 0.5
223
+ # out_wh = (im.shape[1] // 2, im.shape[0] // 2)
224
+ # suffix = '.gif'
225
+ # else:
226
+ # fps = 5
227
+ # out_wh = (im.shape[1], im.shape[0])
228
+ # suffix = '.mkv'
229
+ # video_last_time = time.perf_counter()
230
+ # video = imageio.get_writer('ssim_test' + suffix, fps=fps)
231
+ #
232
+ # # Test SSIM
233
+ # print('Training SSIM')
234
+ # rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
235
+ # rand_im.requires_grad = True
236
+ # optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
237
+ # losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
238
+ # ssim_score = 0
239
+ # while ssim_score < 0.999:
240
+ # optim.zero_grad()
241
+ # loss = losser(rand_im, t_im)
242
+ # (-loss).sum().backward()
243
+ # ssim_score = loss.item()
244
+ # optim.step()
245
+ # r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
246
+ # r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
247
+ #
248
+ # if out_test_video:
249
+ # if time.perf_counter() - video_last_time > 1. / fps:
250
+ # video_last_time = time.perf_counter()
251
+ # out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
252
+ # out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
253
+ # if isinstance(out_frame, cv2.UMat):
254
+ # out_frame = out_frame.get()
255
+ # video.append_data(out_frame)
256
+ #
257
+ # cv2.imshow('ssim', r_im)
258
+ # cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
259
+ # cv2.waitKey(1)
260
+ #
261
+ # if out_test_video:
262
+ # video.close()
263
+ #
264
+ # # Test MS-SSIM
265
+ # if out_test_video:
266
+ # if video_use_gif:
267
+ # fps = 0.5
268
+ # out_wh = (im.shape[1] // 2, im.shape[0] // 2)
269
+ # suffix = '.gif'
270
+ # else:
271
+ # fps = 5
272
+ # out_wh = (im.shape[1], im.shape[0])
273
+ # suffix = '.mkv'
274
+ # video_last_time = time.perf_counter()
275
+ # video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
276
+ #
277
+ # print('Training MS_SSIM')
278
+ # rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
279
+ # rand_im.requires_grad = True
280
+ # optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
281
+ # losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
282
+ # ssim_score = 0
283
+ # while ssim_score < 0.999:
284
+ # optim.zero_grad()
285
+ # loss = losser(rand_im, t_im)
286
+ # (-loss).sum().backward()
287
+ # ssim_score = loss.item()
288
+ # optim.step()
289
+ # r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
290
+ # r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
291
+ #
292
+ # if out_test_video:
293
+ # if time.perf_counter() - video_last_time > 1. / fps:
294
+ # video_last_time = time.perf_counter()
295
+ # out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
296
+ # out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
297
+ # if isinstance(out_frame, cv2.UMat):
298
+ # out_frame = out_frame.get()
299
+ # video.append_data(out_frame)
300
+ #
301
+ # cv2.imshow('ms_ssim', r_im)
302
+ # cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
303
+ # cv2.waitKey(1)
304
+ #
305
+ # if out_test_video:
306
+ # video.close()
307
+
308
+ """
309
+ Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
310
+ """
311
+
312
+ import torch
313
+ import torch.nn.functional as F
314
+ from torch.autograd import Variable
315
+ import numpy as np
316
+ from math import exp
317
+
318
+
319
+ def gaussian(window_size, sigma):
320
+ gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
321
+ return gauss / gauss.sum()
322
+
323
+
324
+ def create_window(window_size, channel):
325
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
326
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
327
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
328
+ return window
329
+
330
+
331
+ def _ssim(img1, img2, window, window_size, channel, size_average=True):
332
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
333
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
334
+
335
+ mu1_sq = mu1.pow(2)
336
+ mu2_sq = mu2.pow(2)
337
+ mu1_mu2 = mu1 * mu2
338
+
339
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
340
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
341
+ sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
342
+
343
+ C1 = 0.01 ** 2
344
+ C2 = 0.03 ** 2
345
+
346
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
347
+
348
+ if size_average:
349
+ return ssim_map.mean()
350
+ else:
351
+ return ssim_map.mean(1)
352
+
353
+
354
+ class SSIM(torch.nn.Module):
355
+ def __init__(self, window_size=11, size_average=True):
356
+ super(SSIM, self).__init__()
357
+ self.window_size = window_size
358
+ self.size_average = size_average
359
+ self.channel = 1
360
+ self.window = create_window(window_size, self.channel)
361
+
362
+ def forward(self, img1, img2):
363
+ (_, channel, _, _) = img1.size()
364
+
365
+ if channel == self.channel and self.window.data.type() == img1.data.type():
366
+ window = self.window
367
+ else:
368
+ window = create_window(self.window_size, channel)
369
+
370
+ if img1.is_cuda:
371
+ window = window.cuda(img1.get_device())
372
+ window = window.type_as(img1)
373
+
374
+ self.window = window
375
+ self.channel = channel
376
+
377
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
378
+
379
+
380
+ window = None
381
+
382
+
383
+ def ssim(img1, img2, window_size=11, size_average=True):
384
+ (_, channel, _, _) = img1.size()
385
+ global window
386
+ if window is None:
387
+ window = create_window(window_size, channel)
388
+ if img1.is_cuda:
389
+ window = window.cuda(img1.get_device())
390
+ window = window.type_as(img1)
391
+ return _ssim(img1, img2, window, window_size, channel, size_average)
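
A minimal usage sketch of the functional ssim above, treating mel spectrograms as single-channel images; the [B, 1, T, n_mels] shape and the [0, 1] value range are assumptions for illustration (the hard-coded C1/C2 constants assume a data range of 1).

import torch
from modules.commons.ssim import ssim

mel_a = torch.rand(4, 1, 200, 80)                   # [B, C=1, T, n_mels] in [0, 1]
mel_b = (mel_a + 0.05 * torch.randn_like(mel_a)).clamp(0, 1)
score = ssim(mel_a, mel_b, window_size=11, size_average=True)
print(score.item())                                 # near 1.0 for near-identical inputs
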
modules/fastspeech/fs2.py ADDED
@@ -0,0 +1,255 @@
1
+ from modules.commons.common_layers import *
2
+ from modules.commons.common_layers import Embedding
3
+ from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
4
+ EnergyPredictor, FastspeechEncoder
5
+ from utils.cwt import cwt2f0
6
+ from utils.hparams import hparams
7
+ from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
8
+
9
+ FS_ENCODERS = {
10
+ 'fft': lambda hp: FastspeechEncoder(
11
+ hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
12
+ num_heads=hp['num_heads']),
13
+ }
14
+
15
+ FS_DECODERS = {
16
+ 'fft': lambda hp: FastspeechDecoder(
17
+ hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
18
+ }
19
+
20
+
21
+ class FastSpeech2(nn.Module):
22
+ def __init__(self, dictionary, out_dims=None):
23
+ super().__init__()
24
+ # self.dictionary = dictionary
25
+ self.padding_idx = 0
26
+ if not hparams.get('no_fs2', False):
27
+ self.enc_layers = hparams['enc_layers']
28
+ self.dec_layers = hparams['dec_layers']
29
+ self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams)
30
+ self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
31
+ self.hidden_size = hparams['hidden_size']
32
+ # self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
33
+ self.out_dims = out_dims
34
+ if out_dims is None:
35
+ self.out_dims = hparams['audio_num_mel_bins']
36
+ self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
37
+ #=========not used===========
38
+ # if hparams['use_spk_id']:
39
+ # self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
40
+ # if hparams['use_split_spk_id']:
41
+ # self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
42
+ # self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
43
+ # elif hparams['use_spk_embed']:
44
+ # self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
45
+ predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
46
+ # self.dur_predictor = DurationPredictor(
47
+ # self.hidden_size,
48
+ # n_chans=predictor_hidden,
49
+ # n_layers=hparams['dur_predictor_layers'],
50
+ # dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
51
+ # kernel_size=hparams['dur_predictor_kernel'])
52
+ # self.length_regulator = LengthRegulator()
53
+ if hparams['use_pitch_embed']:
54
+ self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
55
+ if hparams['pitch_type'] == 'cwt':
56
+ h = hparams['cwt_hidden_size']
57
+ cwt_out_dims = 10
58
+ if hparams['use_uv']:
59
+ cwt_out_dims = cwt_out_dims + 1
60
+ self.cwt_predictor = nn.Sequential(
61
+ nn.Linear(self.hidden_size, h),
62
+ PitchPredictor(
63
+ h,
64
+ n_chans=predictor_hidden,
65
+ n_layers=hparams['predictor_layers'],
66
+ dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
67
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
68
+ self.cwt_stats_layers = nn.Sequential(
69
+ nn.Linear(self.hidden_size, h), nn.ReLU(),
70
+ nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
71
+ )
72
+ else:
73
+ self.pitch_predictor = PitchPredictor(
74
+ self.hidden_size,
75
+ n_chans=predictor_hidden,
76
+ n_layers=hparams['predictor_layers'],
77
+ dropout_rate=hparams['predictor_dropout'],
78
+ odim=2 if hparams['pitch_type'] == 'frame' else 1,
79
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
80
+ if hparams['use_energy_embed']:
81
+ self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
82
+ # self.energy_predictor = EnergyPredictor(
83
+ # self.hidden_size,
84
+ # n_chans=predictor_hidden,
85
+ # n_layers=hparams['predictor_layers'],
86
+ # dropout_rate=hparams['predictor_dropout'], odim=1,
87
+ # padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
88
+
89
+ # def build_embedding(self, dictionary, embed_dim):
90
+ # num_embeddings = len(dictionary)
91
+ # emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
92
+ # return emb
93
+
94
+ def forward(self, hubert, mel2ph=None, spk_embed=None,
95
+ ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=True,
96
+ spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
97
+ ret = {}
98
+ if not hparams.get('no_fs2', False):
99
+ encoder_out = self.encoder(hubert) # [B, T, C]
100
+ else:
101
+ encoder_out = hubert
102
+ src_nonpadding = (hubert != 0).any(-1)[:, :, None]
103
+
104
+ # add ref style embed
105
+ # Not implemented
106
+ # variance encoder
107
+ var_embed = 0
108
+
109
+ # encoder_out_dur denotes encoder outputs for duration predictor
110
+ # in speech adaptation, duration predictor use old speaker embedding
111
+ if hparams['use_spk_embed']:
112
+ spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
113
+ elif hparams['use_spk_id']:
114
+ spk_embed_id = spk_embed
115
+ if spk_embed_dur_id is None:
116
+ spk_embed_dur_id = spk_embed_id
117
+ if spk_embed_f0_id is None:
118
+ spk_embed_f0_id = spk_embed_id
119
+ spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
120
+ spk_embed_dur = spk_embed_f0 = spk_embed
121
+ if hparams['use_split_spk_id']:
122
+ spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
123
+ spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
124
+ else:
125
+ spk_embed_dur = spk_embed_f0 = spk_embed = 0
126
+
127
+ # add dur
128
+ # dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
129
+
130
+ # mel2ph = self.add_dur(dur_inp, mel2ph, hubert, ret)
131
+ ret['mel2ph'] = mel2ph
132
+
133
+ decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
134
+
135
+ mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
136
+ decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
137
+
138
+ tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
139
+
140
+ # add pitch and energy embed
141
+ pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
142
+ if hparams['use_pitch_embed']:
143
+ pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
144
+ decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
145
+ if hparams['use_energy_embed']:
146
+ decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
147
+
148
+ ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
149
+ if not hparams.get('no_fs2', False):
150
+ if skip_decoder:
151
+ return ret
152
+ ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
153
+
154
+ return ret
155
+
156
+ def add_dur(self, dur_input, mel2ph, hubert, ret):
157
+ src_padding = (hubert == 0).all(-1)
158
+ dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
159
+ if mel2ph is None:
160
+ dur, xs = self.dur_predictor.inference(dur_input, src_padding)
161
+ ret['dur'] = xs
162
+ ret['dur_choice'] = dur
163
+ mel2ph = self.length_regulator(dur, src_padding).detach()
164
+ else:
165
+ ret['dur'] = self.dur_predictor(dur_input, src_padding)
166
+ ret['mel2ph'] = mel2ph
167
+ return mel2ph
168
+
169
+ def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
170
+ x = decoder_inp # [B, T, H]
171
+ x = self.decoder(x)
172
+ x = self.mel_out(x)
173
+ return x * tgt_nonpadding
174
+
175
+ def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
176
+ f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
177
+ f0 = torch.cat(
178
+ [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
179
+ f0_norm = norm_f0(f0, None, hparams)
180
+ return f0_norm
181
+
182
+ def out2mel(self, out):
183
+ return out
184
+
185
+ def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
186
+ # if hparams['pitch_type'] == 'ph':
187
+ # pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
188
+ # pitch_padding = (encoder_out.sum().abs() == 0)
189
+ # ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
190
+ # if f0 is None:
191
+ # f0 = pitch_pred[:, :, 0]
192
+ # ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
193
+ # pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt]
194
+ # pitch = F.pad(pitch, [1, 0])
195
+ # pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel]
196
+ # pitch_embedding = pitch_embed(pitch)
197
+ # return pitch_embedding
198
+
199
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
200
+
201
+ pitch_padding = (mel2ph == 0)
202
+
203
+ # if hparams['pitch_type'] == 'cwt':
204
+ # # NOTE: this part of script is *isolated* from other scripts, which means
205
+ # # it may not be compatible with the current version.
206
+ # pass
207
+ # # pitch_padding = None
208
+ # # ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
209
+ # # stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
210
+ # # mean = ret['f0_mean'] = stats_out[:, 0]
211
+ # # std = ret['f0_std'] = stats_out[:, 1]
212
+ # # cwt_spec = cwt_out[:, :, :10]
213
+ # # if f0 is None:
214
+ # # std = std * hparams['cwt_std_scale']
215
+ # # f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
216
+ # # if hparams['use_uv']:
217
+ # # assert cwt_out.shape[-1] == 11
218
+ # # uv = cwt_out[:, :, -1] > 0
219
+ # elif hparams['pitch_ar']:
220
+ # ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if is_training else None)
221
+ # if f0 is None:
222
+ # f0 = pitch_pred[:, :, 0]
223
+ # else:
224
+ # ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
225
+ # if f0 is None:
226
+ # f0 = pitch_pred[:, :, 0]
227
+ # if hparams['use_uv'] and uv is None:
228
+ # uv = pitch_pred[:, :, 1] > 0
229
+ ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
230
+ if pitch_padding is not None:
231
+ f0[pitch_padding] = 0
232
+
233
+ pitch = f0_to_coarse(f0_denorm, hparams) # start from 0
234
+ ret['pitch_pred'] = pitch.unsqueeze(-1)
235
+ # print(ret['pitch_pred'].shape)
236
+ # print(pitch.shape)
237
+ pitch_embedding = self.pitch_embed(pitch)
238
+ return pitch_embedding
239
+
240
+ def add_energy(self, decoder_inp, energy, ret):
241
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
242
+ ret['energy_pred'] = energy # energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
243
+ # if energy is None:
244
+ # energy = energy_pred
245
+ energy = torch.clamp(energy * 256 // 4, max=255).long() # energy_to_coarse
246
+ energy_embedding = self.energy_embed(energy)
247
+ return energy_embedding
248
+
249
+ @staticmethod
250
+ def mel_norm(x):
251
+ return (x + 5.5) / (6.3 / 2) - 1
252
+
253
+ @staticmethod
254
+ def mel_denorm(x):
255
+ return (x + 1) * (6.3 / 2) - 5.5
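
mel_norm and mel_denorm above are exact inverses, mapping log-mel values from roughly [-5.5, 0.8] into [-1, 1] and back. A standalone sanity check follows; the two functions are restated here only to avoid importing the full module, which pulls in hparams.

import torch

def mel_norm(x):                       # mirrors FastSpeech2.mel_norm
    return (x + 5.5) / (6.3 / 2) - 1

def mel_denorm(x):                     # mirrors FastSpeech2.mel_denorm
    return (x + 1) * (6.3 / 2) - 5.5

x = torch.linspace(-5.5, 0.8, steps=8)
assert torch.allclose(mel_denorm(mel_norm(x)), x, atol=1e-6)
print(mel_norm(torch.tensor([-5.5, 0.8])))          # tensor([-1., 1.])
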
modules/fastspeech/pe.py ADDED
@@ -0,0 +1,149 @@
1
+ from modules.commons.common_layers import *
2
+ from utils.hparams import hparams
3
+ from modules.fastspeech.tts_modules import PitchPredictor
4
+ from utils.pitch_utils import denorm_f0
5
+
6
+
7
+ class Prenet(nn.Module):
8
+ def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
9
+ super(Prenet, self).__init__()
10
+ padding = kernel // 2
11
+ self.layers = []
12
+ self.strides = strides if strides is not None else [1] * n_layers
13
+ for l in range(n_layers):
14
+ self.layers.append(nn.Sequential(
15
+ nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
16
+ nn.ReLU(),
17
+ nn.BatchNorm1d(out_dim)
18
+ ))
19
+ in_dim = out_dim
20
+ self.layers = nn.ModuleList(self.layers)
21
+ self.out_proj = nn.Linear(out_dim, out_dim)
22
+
23
+ def forward(self, x):
24
+ """
25
+
26
+ :param x: [B, T, 80]
27
+ :return: [L, B, T, H], [B, T, H]
28
+ """
29
+ # padding_mask = x.abs().sum(-1).eq(0).data # [B, T]
30
+ padding_mask = x.abs().sum(-1).eq(0).detach()
31
+ nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T]
32
+ x = x.transpose(1, 2)
33
+ hiddens = []
34
+ for i, l in enumerate(self.layers):
35
+ nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
36
+ x = l(x) * nonpadding_mask_TB
37
+ hiddens.append(x)
38
+ hiddens = torch.stack(hiddens, 0) # [L, B, H, T]
39
+ hiddens = hiddens.transpose(2, 3) # [L, B, T, H]
40
+ x = self.out_proj(x.transpose(1, 2)) # [B, T, H]
41
+ x = x * nonpadding_mask_TB.transpose(1, 2)
42
+ return hiddens, x
43
+
44
+
45
+ class ConvBlock(nn.Module):
46
+ def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
47
+ super().__init__()
48
+ self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
49
+ self.norm = norm
50
+ if self.norm == 'bn':
51
+ self.norm = nn.BatchNorm1d(n_chans)
52
+ elif self.norm == 'in':
53
+ self.norm = nn.InstanceNorm1d(n_chans, affine=True)
54
+ elif self.norm == 'gn':
55
+ self.norm = nn.GroupNorm(n_chans // 16, n_chans)
56
+ elif self.norm == 'ln':
57
+ self.norm = LayerNorm(n_chans // 16, n_chans)
58
+ elif self.norm == 'wn':
59
+ self.conv = torch.nn.utils.weight_norm(self.conv.conv)
60
+ self.dropout = nn.Dropout(dropout)
61
+ self.relu = nn.ReLU()
62
+
63
+ def forward(self, x):
64
+ """
65
+
66
+ :param x: [B, C, T]
67
+ :return: [B, C, T]
68
+ """
69
+ x = self.conv(x)
70
+ if not isinstance(self.norm, str):
71
+ if self.norm == 'none':
72
+ pass
73
+ elif self.norm == 'ln':
74
+ x = self.norm(x.transpose(1, 2)).transpose(1, 2)
75
+ else:
76
+ x = self.norm(x)
77
+ x = self.relu(x)
78
+ x = self.dropout(x)
79
+ return x
80
+
81
+
82
+ class ConvStacks(nn.Module):
83
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
84
+ dropout=0, strides=None, res=True):
85
+ super().__init__()
86
+ self.conv = torch.nn.ModuleList()
87
+ self.kernel_size = kernel_size
88
+ self.res = res
89
+ self.in_proj = Linear(idim, n_chans)
90
+ if strides is None:
91
+ strides = [1] * n_layers
92
+ else:
93
+ assert len(strides) == n_layers
94
+ for idx in range(n_layers):
95
+ self.conv.append(ConvBlock(
96
+ n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
97
+ self.out_proj = Linear(n_chans, odim)
98
+
99
+ def forward(self, x, return_hiddens=False):
100
+ """
101
+
102
+ :param x: [B, T, H]
103
+ :return: [B, T, H]
104
+ """
105
+ x = self.in_proj(x)
106
+ x = x.transpose(1, -1) # (B, idim, Tmax)
107
+ hiddens = []
108
+ for f in self.conv:
109
+ x_ = f(x)
110
+ x = x + x_ if self.res else x_ # (B, C, Tmax)
111
+ hiddens.append(x)
112
+ x = x.transpose(1, -1)
113
+ x = self.out_proj(x) # (B, Tmax, H)
114
+ if return_hiddens:
115
+ hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
116
+ return x, hiddens
117
+ return x
118
+
119
+
120
+ class PitchExtractor(nn.Module):
121
+ def __init__(self, n_mel_bins=80, conv_layers=2):
122
+ super().__init__()
123
+ self.hidden_size = hparams['hidden_size']
124
+ self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
125
+ self.conv_layers = conv_layers
126
+
127
+ self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
128
+ if self.conv_layers > 0:
129
+ self.mel_encoder = ConvStacks(
130
+ idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
131
+ self.pitch_predictor = PitchPredictor(
132
+ self.hidden_size, n_chans=self.predictor_hidden,
133
+ n_layers=5, dropout_rate=0.1, odim=2,
134
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
135
+
136
+ def forward(self, mel_input=None):
137
+ ret = {}
138
+ mel_hidden = self.mel_prenet(mel_input)[1]
139
+ if self.conv_layers > 0:
140
+ mel_hidden = self.mel_encoder(mel_hidden)
141
+
142
+ ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)
143
+
144
+ pitch_padding = mel_input.abs().sum(-1) == 0
145
+ use_uv = hparams['pitch_type'] == 'frame' # and hparams['use_uv']
146
+ ret['f0_denorm_pred'] = denorm_f0(
147
+ pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
148
+ hparams, pitch_padding=pitch_padding)
149
+ return ret
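
Prenet and ConvStacks above need no hparams, so they can be exercised in isolation. A minimal shape check follows; batch/time sizes are arbitrary, it assumes the repo root is on PYTHONPATH, and it assumes ConvNorm's default padding keeps the time length (as its use inside PitchExtractor implies).

import torch
from modules.fastspeech.pe import Prenet, ConvStacks

mels = torch.randn(2, 300, 80)                      # [B, T, 80]
prenet = Prenet(in_dim=80, out_dim=256)
hiddens, h = prenet(mels)                           # [L=3, B, T, 256], [B, T, 256]
stack = ConvStacks(idim=256, n_chans=256, odim=256)
y = stack(h)                                        # [B, T, 256]
print(hiddens.shape, h.shape, y.shape)
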
modules/fastspeech/tts_modules.py ADDED
@@ -0,0 +1,364 @@
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ from modules.commons.espnet_positional_embedding import RelPositionalEncoding
9
+ from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
10
+ from utils.hparams import hparams
11
+
12
+ DEFAULT_MAX_SOURCE_POSITIONS = 2000
13
+ DEFAULT_MAX_TARGET_POSITIONS = 2000
14
+
15
+
16
+ class TransformerEncoderLayer(nn.Module):
17
+ def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
18
+ super().__init__()
19
+ self.hidden_size = hidden_size
20
+ self.dropout = dropout
21
+ self.num_heads = num_heads
22
+ self.op = EncSALayer(
23
+ hidden_size, num_heads, dropout=dropout,
24
+ attention_dropout=0.0, relu_dropout=dropout,
25
+ kernel_size=kernel_size
26
+ if kernel_size is not None else hparams['enc_ffn_kernel_size'],
27
+ padding=hparams['ffn_padding'],
28
+ norm=norm, act=hparams['ffn_act'])
29
+
30
+ def forward(self, x, **kwargs):
31
+ return self.op(x, **kwargs)
32
+
33
+
34
+ ######################
35
+ # fastspeech modules
36
+ ######################
37
+ class LayerNorm(torch.nn.LayerNorm):
38
+ """Layer normalization module.
39
+ :param int nout: output dim size
40
+ :param int dim: dimension to be normalized
41
+ """
42
+
43
+ def __init__(self, nout, dim=-1):
44
+ """Construct an LayerNorm object."""
45
+ super(LayerNorm, self).__init__(nout, eps=1e-12)
46
+ self.dim = dim
47
+
48
+ def forward(self, x):
49
+ """Apply layer normalization.
50
+ :param torch.Tensor x: input tensor
51
+ :return: layer normalized tensor
52
+ :rtype torch.Tensor
53
+ """
54
+ if self.dim == -1:
55
+ return super(LayerNorm, self).forward(x)
56
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
57
+
58
+
59
+ class DurationPredictor(torch.nn.Module):
60
+ """Duration predictor module.
61
+ This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
62
+ The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
63
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
64
+ https://arxiv.org/pdf/1905.09263.pdf
65
+ Note:
66
+ The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
67
+ the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
68
+ """
69
+
70
+ def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
71
+ """Initilize duration predictor module.
72
+ Args:
73
+ idim (int): Input dimension.
74
+ n_layers (int, optional): Number of convolutional layers.
75
+ n_chans (int, optional): Number of channels of convolutional layers.
76
+ kernel_size (int, optional): Kernel size of convolutional layers.
77
+ dropout_rate (float, optional): Dropout rate.
78
+ offset (float, optional): Offset value to avoid nan in log domain.
79
+ """
80
+ super(DurationPredictor, self).__init__()
81
+ self.offset = offset
82
+ self.conv = torch.nn.ModuleList()
83
+ self.kernel_size = kernel_size
84
+ self.padding = padding
85
+ for idx in range(n_layers):
86
+ in_chans = idim if idx == 0 else n_chans
87
+ self.conv += [torch.nn.Sequential(
88
+ torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
89
+ if padding == 'SAME'
90
+ else (kernel_size - 1, 0), 0),
91
+ torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
92
+ torch.nn.ReLU(),
93
+ LayerNorm(n_chans, dim=1),
94
+ torch.nn.Dropout(dropout_rate)
95
+ )]
96
+ if hparams['dur_loss'] in ['mse', 'huber']:
97
+ odims = 1
98
+ elif hparams['dur_loss'] == 'mog':
99
+ odims = 15
100
+ elif hparams['dur_loss'] == 'crf':
101
+ odims = 32
102
+ from torchcrf import CRF
103
+ self.crf = CRF(odims, batch_first=True)
104
+ self.linear = torch.nn.Linear(n_chans, odims)
105
+
106
+ def _forward(self, xs, x_masks=None, is_inference=False):
107
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
108
+ for f in self.conv:
109
+ xs = f(xs) # (B, C, Tmax)
110
+ if x_masks is not None:
111
+ xs = xs * (1 - x_masks.float())[:, None, :]
112
+
113
+ xs = self.linear(xs.transpose(1, -1)) # [B, T, C]
114
+ xs = xs * (1 - x_masks.float())[:, :, None] # (B, T, C)
115
+ if is_inference:
116
+ return self.out2dur(xs), xs
117
+ else:
118
+ if hparams['dur_loss'] in ['mse']:
119
+ xs = xs.squeeze(-1) # (B, Tmax)
120
+ return xs
121
+
122
+ def out2dur(self, xs):
123
+ if hparams['dur_loss'] in ['mse']:
124
+ # NOTE: calculate in log domain
125
+ xs = xs.squeeze(-1) # (B, Tmax)
126
+ dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
127
+ elif hparams['dur_loss'] == 'mog':
128
+ return NotImplementedError
129
+ elif hparams['dur_loss'] == 'crf':
130
+ dur = torch.LongTensor(self.crf.decode(xs)).cuda()
131
+ return dur
132
+
133
+ def forward(self, xs, x_masks=None):
134
+ """Calculate forward propagation.
135
+ Args:
136
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
137
+ x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
138
+ Returns:
139
+ Tensor: Batch of predicted durations in log domain (B, Tmax).
140
+ """
141
+ return self._forward(xs, x_masks, False)
142
+
143
+ def inference(self, xs, x_masks=None):
144
+ """Inference duration.
145
+ Args:
146
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
147
+ x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
148
+ Returns:
149
+ LongTensor: Batch of predicted durations in linear domain (B, Tmax).
150
+ """
151
+ return self._forward(xs, x_masks, True)
152
+
153
+
154
+ class LengthRegulator(torch.nn.Module):
155
+ def __init__(self, pad_value=0.0):
156
+ super(LengthRegulator, self).__init__()
157
+ self.pad_value = pad_value
158
+
159
+ def forward(self, dur, dur_padding=None, alpha=1.0):
160
+ """
161
+ Example (no batch dim version):
162
+ 1. dur = [2,2,3]
163
+ 2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
164
+ 3. token_mask = [[1,1,0,0,0,0,0],
165
+ [0,0,1,1,0,0,0],
166
+ [0,0,0,0,1,1,1]]
167
+ 4. token_idx * token_mask = [[1,1,0,0,0,0,0],
168
+ [0,0,2,2,0,0,0],
169
+ [0,0,0,0,3,3,3]]
170
+ 5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
171
+
172
+ :param dur: Batch of durations of each frame (B, T_txt)
173
+ :param dur_padding: Batch of padding of each frame (B, T_txt)
174
+ :param alpha: duration rescale coefficient
175
+ :return:
176
+ mel2ph (B, T_speech)
177
+ """
178
+ assert alpha > 0
179
+ dur = torch.round(dur.float() * alpha).long()
180
+ if dur_padding is not None:
181
+ dur = dur * (1 - dur_padding.long())
182
+ token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
183
+ dur_cumsum = torch.cumsum(dur, 1)
184
+ dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
185
+
186
+ pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
187
+ token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
188
+ mel2ph = (token_idx * token_mask.long()).sum(1)
189
+ return mel2ph
190
+
191
+
192
+ class PitchPredictor(torch.nn.Module):
193
+ def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
194
+ dropout_rate=0.1, padding='SAME'):
195
+ """Initilize pitch predictor module.
196
+ Args:
197
+ idim (int): Input dimension.
198
+ n_layers (int, optional): Number of convolutional layers.
199
+ n_chans (int, optional): Number of channels of convolutional layers.
200
+ kernel_size (int, optional): Kernel size of convolutional layers.
201
+ dropout_rate (float, optional): Dropout rate.
202
+ """
203
+ super(PitchPredictor, self).__init__()
204
+ self.conv = torch.nn.ModuleList()
205
+ self.kernel_size = kernel_size
206
+ self.padding = padding
207
+ for idx in range(n_layers):
208
+ in_chans = idim if idx == 0 else n_chans
209
+ self.conv += [torch.nn.Sequential(
210
+ torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
211
+ if padding == 'SAME'
212
+ else (kernel_size - 1, 0), 0),
213
+ torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
214
+ torch.nn.ReLU(),
215
+ LayerNorm(n_chans, dim=1),
216
+ torch.nn.Dropout(dropout_rate)
217
+ )]
218
+ self.linear = torch.nn.Linear(n_chans, odim)
219
+ self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
220
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
221
+
222
+ def forward(self, xs):
223
+ """
224
+
225
+ :param xs: [B, T, H]
226
+ :return: [B, T, H]
227
+ """
228
+ positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
229
+ xs = xs + positions
230
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
231
+ for f in self.conv:
232
+ xs = f(xs) # (B, C, Tmax)
233
+ # NOTE: calculate in log domain
234
+ xs = self.linear(xs.transpose(1, -1)) # (B, Tmax, H)
235
+ return xs
236
+
237
+
238
+ class EnergyPredictor(PitchPredictor):
239
+ pass
240
+
241
+
242
+ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
243
+ B, _ = mel2ph.shape
244
+ dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
245
+ dur = dur[:, 1:]
246
+ if max_dur is not None:
247
+ dur = dur.clamp(max=max_dur)
248
+ return dur
249
+
250
+
251
+ class FFTBlocks(nn.Module):
252
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
253
+ use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
254
+ super().__init__()
255
+ self.num_layers = num_layers
256
+ embed_dim = self.hidden_size = hidden_size
257
+ self.dropout = dropout if dropout is not None else hparams['dropout']
258
+ self.use_pos_embed = use_pos_embed
259
+ self.use_last_norm = use_last_norm
260
+ if use_pos_embed:
261
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
262
+ self.padding_idx = 0
263
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
264
+ self.embed_positions = SinusoidalPositionalEmbedding(
265
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
266
+ )
267
+
268
+ self.layers = nn.ModuleList([])
269
+ self.layers.extend([
270
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
271
+ kernel_size=ffn_kernel_size, num_heads=num_heads)
272
+ for _ in range(self.num_layers)
273
+ ])
274
+ if self.use_last_norm:
275
+ if norm == 'ln':
276
+ self.layer_norm = nn.LayerNorm(embed_dim)
277
+ elif norm == 'bn':
278
+ self.layer_norm = BatchNorm1dTBC(embed_dim)
279
+ else:
280
+ self.layer_norm = None
281
+
282
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
283
+ """
284
+ :param x: [B, T, C]
285
+ :param padding_mask: [B, T]
286
+ :return: [B, T, C] or [L, B, T, C]
287
+ """
288
+ # padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
289
+ padding_mask = x.abs().sum(-1).eq(0).detach() if padding_mask is None else padding_mask
290
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
291
+ if self.use_pos_embed:
292
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
293
+ x = x + positions
294
+ x = F.dropout(x, p=self.dropout, training=self.training)
295
+ # B x T x C -> T x B x C
296
+ x = x.transpose(0, 1) * nonpadding_mask_TB
297
+ hiddens = []
298
+ for layer in self.layers:
299
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
300
+ hiddens.append(x)
301
+ if self.use_last_norm:
302
+ x = self.layer_norm(x) * nonpadding_mask_TB
303
+ if return_hiddens:
304
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
305
+ x = x.transpose(1, 2) # [L, B, T, C]
306
+ else:
307
+ x = x.transpose(0, 1) # [B, T, C]
308
+ return x
309
+
310
+
311
+ class FastspeechEncoder(FFTBlocks):
312
+ '''
313
+ compared to FFTBlocks:
314
+ - input is [B, T, H], not [B, T, C]
315
+ - supports "relative" positional encoding
316
+ '''
317
+ def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
318
+ hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
319
+ kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
320
+ num_layers = hparams['dec_layers'] if num_layers is None else num_layers
321
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
322
+ use_pos_embed=False) # use_pos_embed_alpha for compatibility
323
+ #self.embed_tokens = embed_tokens
324
+ self.embed_scale = math.sqrt(hidden_size)
325
+ self.padding_idx = 0
326
+ if hparams.get('rel_pos') is not None and hparams['rel_pos']:
327
+ self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
328
+ else:
329
+ self.embed_positions = SinusoidalPositionalEmbedding(
330
+ hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
331
+ )
332
+
333
+ def forward(self, hubert):
334
+ """
335
+
336
+ :param hubert: [B, T, H ]
337
+ :return: {
338
+ 'encoder_out': [T x B x C]
339
+ }
340
+ """
341
+ # encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
342
+ encoder_padding_mask = (hubert == 0).all(-1)
343
+ x = self.forward_embedding(hubert) # [B, T, H]
344
+ x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
345
+ return x
346
+
347
+ def forward_embedding(self, hubert):
348
+ # embed tokens and positions
349
+ x = self.embed_scale * hubert
350
+ if hparams['use_pos_embed']:
351
+ positions = self.embed_positions(hubert)
352
+ x = x + positions
353
+ x = F.dropout(x, p=self.dropout, training=self.training)
354
+ return x
355
+
356
+
357
+ class FastspeechDecoder(FFTBlocks):
358
+ def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
359
+ num_heads = hparams['num_heads'] if num_heads is None else num_heads
360
+ hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
361
+ kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
362
+ num_layers = hparams['dec_layers'] if num_layers is None else num_layers
363
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
364
+
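
The LengthRegulator docstring above already contains a worked example; here it is run end-to-end, with mel2ph_to_dur recovering the original durations (a round trip). Assumes the repo root is on PYTHONPATH.

import torch
from modules.fastspeech.tts_modules import LengthRegulator, mel2ph_to_dur

dur = torch.tensor([[2, 2, 3]])                     # [B=1, T_txt=3]
mel2ph = LengthRegulator()(dur)                     # 1-based token index per output frame
print(mel2ph)                                       # tensor([[1, 1, 2, 2, 3, 3, 3]])
print(mel2ph_to_dur(mel2ph, T_txt=3))               # tensor([[2, 2, 3]]): durations recovered
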
modules/hifigan/hifigan.py ADDED
@@ -0,0 +1,365 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+
7
+ from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
8
+ from modules.parallel_wavegan.models.source import SourceModuleHnNSF
9
+ import numpy as np
10
+
11
+ LRELU_SLOPE = 0.1
12
+
13
+
14
+ def init_weights(m, mean=0.0, std=0.01):
15
+ classname = m.__class__.__name__
16
+ if classname.find("Conv") != -1:
17
+ m.weight.data.normal_(mean, std)
18
+
19
+
20
+ def apply_weight_norm(m):
21
+ classname = m.__class__.__name__
22
+ if classname.find("Conv") != -1:
23
+ weight_norm(m)
24
+
25
+
26
+ def get_padding(kernel_size, dilation=1):
27
+ return int((kernel_size * dilation - dilation) / 2)
28
+
29
+
30
+ class ResBlock1(torch.nn.Module):
31
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
32
+ super(ResBlock1, self).__init__()
33
+ self.h = h
34
+ self.convs1 = nn.ModuleList([
35
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
36
+ padding=get_padding(kernel_size, dilation[0]))),
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
38
+ padding=get_padding(kernel_size, dilation[1]))),
39
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
40
+ padding=get_padding(kernel_size, dilation[2])))
41
+ ])
42
+ self.convs1.apply(init_weights)
43
+
44
+ self.convs2 = nn.ModuleList([
45
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46
+ padding=get_padding(kernel_size, 1))),
47
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
+ padding=get_padding(kernel_size, 1))),
49
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50
+ padding=get_padding(kernel_size, 1)))
51
+ ])
52
+ self.convs2.apply(init_weights)
53
+
54
+ def forward(self, x):
55
+ for c1, c2 in zip(self.convs1, self.convs2):
56
+ xt = F.leaky_relu(x, LRELU_SLOPE)
57
+ xt = c1(xt)
58
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
59
+ xt = c2(xt)
60
+ x = xt + x
61
+ return x
62
+
63
+ def remove_weight_norm(self):
64
+ for l in self.convs1:
65
+ remove_weight_norm(l)
66
+ for l in self.convs2:
67
+ remove_weight_norm(l)
68
+
69
+
70
+ class ResBlock2(torch.nn.Module):
71
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
72
+ super(ResBlock2, self).__init__()
73
+ self.h = h
74
+ self.convs = nn.ModuleList([
75
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
76
+ padding=get_padding(kernel_size, dilation[0]))),
77
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
78
+ padding=get_padding(kernel_size, dilation[1])))
79
+ ])
80
+ self.convs.apply(init_weights)
81
+
82
+ def forward(self, x):
83
+ for c in self.convs:
84
+ xt = F.leaky_relu(x, LRELU_SLOPE)
85
+ xt = c(xt)
86
+ x = xt + x
87
+ return x
88
+
89
+ def remove_weight_norm(self):
90
+ for l in self.convs:
91
+ remove_weight_norm(l)
92
+
93
+
94
+ class Conv1d1x1(Conv1d):
95
+ """1x1 Conv1d with customized initialization."""
96
+
97
+ def __init__(self, in_channels, out_channels, bias):
98
+ """Initialize 1x1 Conv1d module."""
99
+ super(Conv1d1x1, self).__init__(in_channels, out_channels,
100
+ kernel_size=1, padding=0,
101
+ dilation=1, bias=bias)
102
+
103
+
104
+ class HifiGanGenerator(torch.nn.Module):
105
+ def __init__(self, h, c_out=1):
106
+ super(HifiGanGenerator, self).__init__()
107
+ self.h = h
108
+ self.num_kernels = len(h['resblock_kernel_sizes'])
109
+ self.num_upsamples = len(h['upsample_rates'])
110
+
111
+ if h['use_pitch_embed']:
112
+ self.harmonic_num = 8
113
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
114
+ self.m_source = SourceModuleHnNSF(
115
+ sampling_rate=h['audio_sample_rate'],
116
+ harmonic_num=self.harmonic_num)
117
+ self.noise_convs = nn.ModuleList()
118
+ self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
119
+ resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2
120
+
121
+ self.ups = nn.ModuleList()
122
+ for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
123
+ c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
124
+ self.ups.append(weight_norm(
125
+ ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
126
+ if h['use_pitch_embed']:
127
+ if i + 1 < len(h['upsample_rates']):
128
+ stride_f0 = np.prod(h['upsample_rates'][i + 1:])
129
+ self.noise_convs.append(Conv1d(
130
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
131
+ else:
132
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
133
+
134
+ self.resblocks = nn.ModuleList()
135
+ for i in range(len(self.ups)):
136
+ ch = h['upsample_initial_channel'] // (2 ** (i + 1))
137
+ for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
138
+ self.resblocks.append(resblock(h, ch, k, d))
139
+
140
+ self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
141
+ self.ups.apply(init_weights)
142
+ self.conv_post.apply(init_weights)
143
+
144
+ def forward(self, x, f0=None):
145
+ if f0 is not None:
146
+ # harmonic-source signal, noise-source signal, uv flag
147
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
148
+ har_source, noi_source, uv = self.m_source(f0)
149
+ har_source = har_source.transpose(1, 2)
150
+
151
+ x = self.conv_pre(x)
152
+ for i in range(self.num_upsamples):
153
+ x = F.leaky_relu(x, LRELU_SLOPE)
154
+ x = self.ups[i](x)
155
+ if f0 is not None:
156
+ x_source = self.noise_convs[i](har_source)
157
+ x = x + x_source
158
+ xs = None
159
+ for j in range(self.num_kernels):
160
+ if xs is None:
161
+ xs = self.resblocks[i * self.num_kernels + j](x)
162
+ else:
163
+ xs += self.resblocks[i * self.num_kernels + j](x)
164
+ x = xs / self.num_kernels
165
+ x = F.leaky_relu(x)
166
+ x = self.conv_post(x)
167
+ x = torch.tanh(x)
168
+
169
+ return x
170
+
171
+ def remove_weight_norm(self):
172
+ print('Removing weight norm...')
173
+ for l in self.ups:
174
+ remove_weight_norm(l)
175
+ for l in self.resblocks:
176
+ l.remove_weight_norm()
177
+ remove_weight_norm(self.conv_pre)
178
+ remove_weight_norm(self.conv_post)
179
+
180
+
181
+ class DiscriminatorP(torch.nn.Module):
182
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
183
+ super(DiscriminatorP, self).__init__()
184
+ self.use_cond = use_cond
185
+ if use_cond:
186
+ from utils.hparams import hparams
187
+ t = hparams['hop_size']
188
+ self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
189
+ c_in = 2
190
+
191
+ self.period = period
192
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
193
+ self.convs = nn.ModuleList([
194
+ norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
195
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
196
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
197
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
198
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
199
+ ])
200
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
201
+
202
+ def forward(self, x, mel):
203
+ fmap = []
204
+ if self.use_cond:
205
+ x_mel = self.cond_net(mel)
206
+ x = torch.cat([x_mel, x], 1)
207
+ # 1d to 2d
208
+ b, c, t = x.shape
209
+ if t % self.period != 0: # pad first
210
+ n_pad = self.period - (t % self.period)
211
+ x = F.pad(x, (0, n_pad), "reflect")
212
+ t = t + n_pad
213
+ x = x.view(b, c, t // self.period, self.period)
214
+
215
+ for l in self.convs:
216
+ x = l(x)
217
+ x = F.leaky_relu(x, LRELU_SLOPE)
218
+ fmap.append(x)
219
+ x = self.conv_post(x)
220
+ fmap.append(x)
221
+ x = torch.flatten(x, 1, -1)
222
+
223
+ return x, fmap
224
+
225
+
226
+ class MultiPeriodDiscriminator(torch.nn.Module):
227
+ def __init__(self, use_cond=False, c_in=1):
228
+ super(MultiPeriodDiscriminator, self).__init__()
229
+ self.discriminators = nn.ModuleList([
230
+ DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
231
+ DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
232
+ DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
233
+ DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
234
+ DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
235
+ ])
236
+
237
+ def forward(self, y, y_hat, mel=None):
238
+ y_d_rs = []
239
+ y_d_gs = []
240
+ fmap_rs = []
241
+ fmap_gs = []
242
+ for i, d in enumerate(self.discriminators):
243
+ y_d_r, fmap_r = d(y, mel)
244
+ y_d_g, fmap_g = d(y_hat, mel)
245
+ y_d_rs.append(y_d_r)
246
+ fmap_rs.append(fmap_r)
247
+ y_d_gs.append(y_d_g)
248
+ fmap_gs.append(fmap_g)
249
+
250
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
251
+
252
+
253
+ class DiscriminatorS(torch.nn.Module):
254
+ def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
255
+ super(DiscriminatorS, self).__init__()
256
+ self.use_cond = use_cond
257
+ if use_cond:
258
+ t = np.prod(upsample_rates)
259
+ self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
260
+ c_in = 2
261
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
262
+ self.convs = nn.ModuleList([
263
+ norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
264
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
265
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
266
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
267
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
268
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
269
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
270
+ ])
271
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
272
+
273
+ def forward(self, x, mel):
274
+ if self.use_cond:
275
+ x_mel = self.cond_net(mel)
276
+ x = torch.cat([x_mel, x], 1)
277
+ fmap = []
278
+ for l in self.convs:
279
+ x = l(x)
280
+ x = F.leaky_relu(x, LRELU_SLOPE)
281
+ fmap.append(x)
282
+ x = self.conv_post(x)
283
+ fmap.append(x)
284
+ x = torch.flatten(x, 1, -1)
285
+
286
+ return x, fmap
287
+
288
+
289
+ class MultiScaleDiscriminator(torch.nn.Module):
290
+ def __init__(self, use_cond=False, c_in=1):
291
+ super(MultiScaleDiscriminator, self).__init__()
292
+ from utils.hparams import hparams
293
+ self.discriminators = nn.ModuleList([
294
+ DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
295
+ upsample_rates=[4, 4, hparams['hop_size'] // 16],
296
+ c_in=c_in),
297
+ DiscriminatorS(use_cond=use_cond,
298
+ upsample_rates=[4, 4, hparams['hop_size'] // 32],
299
+ c_in=c_in),
300
+ DiscriminatorS(use_cond=use_cond,
301
+ upsample_rates=[4, 4, hparams['hop_size'] // 64],
302
+ c_in=c_in),
303
+ ])
304
+ self.meanpools = nn.ModuleList([
305
+ AvgPool1d(4, 2, padding=1),
306
+ AvgPool1d(4, 2, padding=1)
307
+ ])
308
+
309
+ def forward(self, y, y_hat, mel=None):
310
+ y_d_rs = []
311
+ y_d_gs = []
312
+ fmap_rs = []
313
+ fmap_gs = []
314
+ for i, d in enumerate(self.discriminators):
315
+ if i != 0:
316
+ y = self.meanpools[i - 1](y)
317
+ y_hat = self.meanpools[i - 1](y_hat)
318
+ y_d_r, fmap_r = d(y, mel)
319
+ y_d_g, fmap_g = d(y_hat, mel)
320
+ y_d_rs.append(y_d_r)
321
+ fmap_rs.append(fmap_r)
322
+ y_d_gs.append(y_d_g)
323
+ fmap_gs.append(fmap_g)
324
+
325
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
326
+
327
+
328
+ def feature_loss(fmap_r, fmap_g):
329
+ loss = 0
330
+ for dr, dg in zip(fmap_r, fmap_g):
331
+ for rl, gl in zip(dr, dg):
332
+ loss += torch.mean(torch.abs(rl - gl))
333
+
334
+ return loss * 2
335
+
336
+
337
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
338
+ r_losses = 0
339
+ g_losses = 0
340
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
341
+ r_loss = torch.mean((1 - dr) ** 2)
342
+ g_loss = torch.mean(dg ** 2)
343
+ r_losses += r_loss
344
+ g_losses += g_loss
345
+ r_losses = r_losses / len(disc_real_outputs)
346
+ g_losses = g_losses / len(disc_real_outputs)
347
+ return r_losses, g_losses
348
+
349
+
350
+ def cond_discriminator_loss(outputs):
351
+ loss = 0
352
+ for dg in outputs:
353
+ g_loss = torch.mean(dg ** 2)
354
+ loss += g_loss
355
+ loss = loss / len(outputs)
356
+ return loss
357
+
358
+
359
+ def generator_loss(disc_outputs):
360
+ loss = 0
361
+ for dg in disc_outputs:
362
+ l = torch.mean((1 - dg) ** 2)
363
+ loss += l
364
+ loss = loss / len(disc_outputs)
365
+ return loss
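For orientation, here is a minimal, hypothetical sketch of how the discriminators and losses above are typically combined in a GAN training step; the import path, batch shapes, and waveform length are assumptions, not taken from this commit:

import torch
from modules.hifigan.hifigan import (MultiPeriodDiscriminator, discriminator_loss,
                                     generator_loss, feature_loss)

mpd = MultiPeriodDiscriminator()
y = torch.randn(2, 1, 8192)       # real waveform batch (B, 1, T); shape assumed
y_hat = torch.randn(2, 1, 8192)   # stand-in for the generator output

# discriminator step: push real outputs toward 1, fake outputs toward 0
r_loss, g_loss = discriminator_loss(*mpd(y, y_hat.detach())[:2])
loss_d = r_loss + g_loss

# generator step: adversarial term plus feature matching on the fmaps
y_d_r, y_d_g, fmap_r, fmap_g = mpd(y, y_hat)
loss_g = generator_loss(y_d_g) + feature_loss(fmap_r, fmap_g)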
modules/hifigan/mel_utils.py ADDED
@@ -0,0 +1,80 @@
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, hparams, center=False, complex=False):
+     # hop_size: 512   # for 22050 Hz, 275 ≈ 12.5 ms (0.0125 * sample_rate)
+     # win_size: 2048  # for 22050 Hz, 1100 ≈ 50 ms (if None, win_size = fft_size) (0.05 * sample_rate)
+     # fmin: 55        # set to 55 for male speakers; 95 helps remove noise for female speakers
+     #                 # (tune per dataset; pitch ranges: male ~[65, 260], female ~[100, 525])
+     # fmax: 10000     # to be increased/reduced depending on data
+     # fft_size: 2048  # the window is zero-padded to match this parameter
+     n_fft = hparams['fft_size']
+     num_mels = hparams['audio_num_mel_bins']
+     sampling_rate = hparams['audio_sample_rate']
+     hop_size = hparams['hop_size']
+     win_size = hparams['win_size']
+     fmin = hparams['fmin']
+     fmax = hparams['fmax']
+     y = y.clamp(min=-1., max=1.)
+     global mel_basis, hann_window
+     key = str(fmax) + '_' + str(y.device)
+     if key not in mel_basis:  # the original checked the bare `fmax`, which never matches the composed key
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+                                 mode='reflect')
+     y = y.squeeze(1)
+
+     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                       center=center, pad_mode='reflect', normalized=False, onesided=True)
+
+     if not complex:
+         spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+         spec = torch.matmul(mel_basis[key], spec)
+         spec = spectral_normalize_torch(spec)
+     else:
+         B, C, T, _ = spec.shape
+         spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
+     return spec
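A minimal usage sketch for mel_spectrogram (the hparams values are assumptions chosen to be self-consistent; note that torch.stft is called without return_complex above, which presumes a pre-2.0 PyTorch that still returns the real-valued view):

import torch

hparams = {'fft_size': 2048, 'audio_num_mel_bins': 80, 'audio_sample_rate': 44100,
           'hop_size': 512, 'win_size': 2048, 'fmin': 55, 'fmax': 10000}
wav = torch.randn(1, 44100).clamp(-1., 1.)   # (B, T) waveform in [-1, 1]
mel = mel_spectrogram(wav, hparams)          # (B, 80, ~T // hop_size) log-mel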
modules/nsf_hifigan/env.py ADDED
@@ -0,0 +1,15 @@
+ import os
+ import shutil
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, os.path.join(path, config_name))
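AttrDict aliases the dict's __dict__ to the dict itself, so config values can be read as attributes as well as keys; a tiny illustrative example (the values are assumptions):

h = AttrDict({'num_mels': 128, 'upsample_rates': [8, 8, 2, 2]})
assert h.num_mels == h['num_mels'] == 128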
modules/nsf_hifigan/models.py ADDED
@@ -0,0 +1,549 @@
+ import os
+ import json
+ from .env import AttrDict
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+ from .utils import init_weights, get_padding
+
+ LRELU_SLOPE = 0.1
+
+
+ def load_model(model_path, device='cuda'):
+     config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
+     with open(config_file) as f:
+         data = f.read()
+
+     global h
+     json_config = json.loads(data)
+     h = AttrDict(json_config)
+
+     generator = Generator(h).to(device)
+
+     cp_dict = torch.load(model_path)
+     generator.load_state_dict(cp_dict['generator'])
+     generator.eval()
+     generator.remove_weight_norm()
+     del cp_dict
+     return generator, h
+
+
+ class ResBlock1(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlock1, self).__init__()
+         self.h = h
+         self.convs1 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                                padding=get_padding(kernel_size, dilation[2])))
+         ])
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1)))
+         ])
+         self.convs2.apply(init_weights)
+
+     def forward(self, x):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, LRELU_SLOPE)
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class ResBlock2(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+         super(ResBlock2, self).__init__()
+         self.h = h
+         self.convs = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1])))
+         ])
+         self.convs.apply(init_weights)
+
+     def forward(self, x):
+         for c in self.convs:
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+
+
+ class Generator(torch.nn.Module):
+     # Plain HiFi-GAN generator. Note: the NSF variant defined later in this
+     # file reuses the name `Generator` and shadows this class, so load_model()
+     # instantiates the NSF one.
+     def __init__(self, h):
+         super(Generator, self).__init__()
+         self.h = h
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+         self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+         resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(weight_norm(
+                 ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
+                                 k, u, padding=(k - u) // 2)))
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel // (2 ** (i + 1))
+             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                 self.resblocks.append(resblock(h, ch, k, d))
+
+         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+         self.ups.apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+     def forward(self, x):
+         x = self.conv_pre(x)
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             x = self.ups[i](x)
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         print('Removing weight norm...')
+         for l in self.ups:
+             remove_weight_norm(l)
+         for l in self.resblocks:
+             l.remove_weight_norm()
+         remove_weight_norm(self.conv_pre)
+         remove_weight_norm(self.conv_post)
+
+
+ class SineGen(torch.nn.Module):
+     """Definition of the sine generator
+     SineGen(samp_rate, harmonic_num=0,
+             sine_amp=0.1, noise_std=0.003,
+             voiced_threshold=0,
+             flag_for_pulse=False)
+     samp_rate: sampling rate in Hz
+     harmonic_num: number of harmonic overtones (default 0)
+     sine_amp: amplitude of the sine waveform (default 0.1)
+     noise_std: std of Gaussian noise (default 0.003)
+     voiced_threshold: F0 threshold for U/V classification (default 0)
+     flag_for_pulse: whether this SineGen is used inside PulseGen (default False)
+     Note: when flag_for_pulse is True, the first time step of a voiced
+     segment is always sin(np.pi) or cos(0)
+     """
+
+     def __init__(self, samp_rate, harmonic_num=0,
+                  sine_amp=0.1, noise_std=0.003,
+                  voiced_threshold=0,
+                  flag_for_pulse=False):
+         super(SineGen, self).__init__()
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.harmonic_num = harmonic_num
+         self.dim = self.harmonic_num + 1
+         self.sampling_rate = samp_rate
+         self.voiced_threshold = voiced_threshold
+         self.flag_for_pulse = flag_for_pulse
+
+     def _f02uv(self, f0):
+         # generate the uv signal
+         uv = torch.ones_like(f0)
+         uv = uv * (f0 > self.voiced_threshold)
+         return uv
+
+     def _f02sine(self, f0_values):
+         """f0_values: (batchsize, length, dim)
+         where dim indicates fundamental tone and overtones
+         """
+         # convert to F0 in rad. The integer part n can be ignored
+         # because 2 * np.pi * n doesn't affect phase
+         rad_values = (f0_values / self.sampling_rate) % 1
+
+         # initial phase noise (no noise for fundamental component)
+         rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
+                               device=f0_values.device)
+         rand_ini[:, 0] = 0
+         rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+         # instantaneous phase: sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
+         if not self.flag_for_pulse:
+             # normal case
+
+             # To prevent torch.cumsum numerical overflow,
+             # it is necessary to add -1 whenever \sum_{k=1}^{n} rad_value_k > 1.
+             # Buffer tmp_over_one_idx indicates the time step to add -1.
+             # This does not change the F0 of the sine, because (x-1) * 2*pi = x * 2*pi.
+             tmp_over_one = torch.cumsum(rad_values, 1) % 1
+             tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
+                                 tmp_over_one[:, :-1, :]) < 0
+             cumsum_shift = torch.zeros_like(rad_values)
+             cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+             sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
+                               * 2 * np.pi)
+         else:
+             # If necessary, make sure that the first time step of every
+             # voiced segment is sin(pi) or cos(0).
+             # This is used for pulse-train generation.
+
+             # identify the last time step in unvoiced segments
+             uv = self._f02uv(f0_values)
+             uv_1 = torch.roll(uv, shifts=-1, dims=1)
+             uv_1[:, -1, :] = 1
+             u_loc = (uv < 1) * (uv_1 > 0)
+
+             # get the instantaneous phase
+             tmp_cumsum = torch.cumsum(rad_values, dim=1)
+             # each batch item needs to be processed separately
+             for idx in range(f0_values.shape[0]):
+                 temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                 temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                 # stores the accumulation of instantaneous phase within
+                 # each voiced segment
+                 tmp_cumsum[idx, :, :] = 0
+                 tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+             # rad_values - tmp_cumsum: remove the accumulation of instantaneous
+             # phase within the previous voiced segment
+             i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+             # get the sines
+             sines = torch.cos(i_phase * 2 * np.pi)
+         return sines
+
+     def forward(self, f0):
+         """sine_tensor, uv = forward(f0)
+         input F0: tensor(batchsize=1, length, dim=1)
+         f0 for unvoiced steps should be 0
+         output sine_tensor: tensor(batchsize=1, length, dim)
+         output uv: tensor(batchsize=1, length, 1)
+         """
+         with torch.no_grad():
+             f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+                                  device=f0.device)
+             # fundamental component
+             f0_buf[:, :, 0] = f0[:, :, 0]
+             for idx in np.arange(self.harmonic_num):
+                 # idx + 2: the (idx+1)-th overtone, i.e. the (idx+2)-th harmonic
+                 f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+             # generate sine waveforms
+             sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+             # generate the uv signal
+             uv = self._f02uv(f0)
+
+             # noise: for unvoiced parts it should be similar to sine_amp
+             #        (std = self.sine_amp/3 -> max value ~ self.sine_amp);
+             #        for voiced regions it is self.noise_std
+             noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+             noise = noise_amp * torch.randn_like(sine_waves)
+
+             # first: set the unvoiced part to 0 via uv
+             # then: add the noise
+             sine_waves = sine_waves * uv + noise
+         return sine_waves, uv, noise
+
+
+ class SourceModuleHnNSF(torch.nn.Module):
+     """SourceModule for hn-nsf
+     SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshold=0)
+     sampling_rate: sampling rate in Hz
+     harmonic_num: number of harmonics above F0 (default: 0)
+     sine_amp: amplitude of the sine source signal (default: 0.1)
+     add_noise_std: std of additive Gaussian noise (default: 0.003);
+         note that the amplitude of noise in unvoiced parts is decided by sine_amp
+     voiced_threshold: threshold to set U/V given F0 (default: 0)
+     Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+     F0_sampled (batchsize, length, 1)
+     Sine_source (batchsize, length, 1)
+     noise_source (batchsize, length, 1)
+     uv (batchsize, length, 1)
+     """
+
+     def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshold=0):
+         super(SourceModuleHnNSF, self).__init__()
+
+         self.sine_amp = sine_amp
+         self.noise_std = add_noise_std
+
+         # to produce sine waveforms
+         self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+                                  sine_amp, add_noise_std, voiced_threshold)
+
+         # to merge source harmonics into a single excitation
+         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+         self.l_tanh = torch.nn.Tanh()
+
+     def forward(self, x):
+         """
+         Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+         F0_sampled (batchsize, length, 1)
+         Sine_source (batchsize, length, 1)
+         noise_source (batchsize, length, 1)
+         """
+         # source for the harmonic branch
+         sine_wavs, uv, _ = self.l_sin_gen(x)
+         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+         # source for the noise branch, in the same shape as uv
+         noise = torch.randn_like(uv) * self.sine_amp / 3
+         return sine_merge, noise, uv
+
+
+ class Generator(torch.nn.Module):
+     # NSF variant: conditions on F0 through a harmonic-plus-noise source
+     # module. This definition shadows the plain Generator above.
+     def __init__(self, h):
+         super(Generator, self).__init__()
+         self.h = h
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+         self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h.upsample_rates))
+         self.m_source = SourceModuleHnNSF(
+             sampling_rate=h.sampling_rate,
+             harmonic_num=8)
+         self.noise_convs = nn.ModuleList()
+         self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+         resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             c_cur = h.upsample_initial_channel // (2 ** (i + 1))
+             self.ups.append(weight_norm(
+                 ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
+                                 k, u, padding=(k - u) // 2)))
+             if i + 1 < len(h.upsample_rates):
+                 stride_f0 = np.prod(h.upsample_rates[i + 1:])
+                 self.noise_convs.append(Conv1d(
+                     1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
+             else:
+                 self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel // (2 ** (i + 1))
+             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                 self.resblocks.append(resblock(h, ch, k, d))
+
+         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+         self.ups.apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+     def forward(self, x, f0):
+         f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # (B, T) -> (B, T * prod(upsample_rates), 1)
+         har_source, noi_source, uv = self.m_source(f0)
+         har_source = har_source.transpose(1, 2)
+         x = self.conv_pre(x)
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             x = self.ups[i](x)
+             x_source = self.noise_convs[i](har_source)
+             x = x + x_source
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         print('Removing weight norm...')
+         for l in self.ups:
+             remove_weight_norm(l)
+         for l in self.resblocks:
+             l.remove_weight_norm()
+         remove_weight_norm(self.conv_pre)
+         remove_weight_norm(self.conv_post)
+
+
+ class DiscriminatorP(torch.nn.Module):
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+         super(DiscriminatorP, self).__init__()
+         self.period = period
+         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+         ])
+         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+     def forward(self, x):
+         fmap = []
+
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(torch.nn.Module):
+     def __init__(self, periods=None):
+         super(MultiPeriodDiscriminator, self).__init__()
+         self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
+         self.discriminators = nn.ModuleList()
+         for period in self.periods:
+             self.discriminators.append(DiscriminatorP(period))
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorS(torch.nn.Module):
+     def __init__(self, use_spectral_norm=False):
+         super(DiscriminatorS, self).__init__()
+         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+             norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+             norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+             norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+         ])
+         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+     def forward(self, x):
+         fmap = []
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiScaleDiscriminator(torch.nn.Module):
+     def __init__(self):
+         super(MultiScaleDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             DiscriminatorS(use_spectral_norm=True),
+             DiscriminatorS(),
+             DiscriminatorS(),
+         ])
+         self.meanpools = nn.ModuleList([
+             AvgPool1d(4, 2, padding=2),
+             AvgPool1d(4, 2, padding=2)
+         ])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             if i != 0:
+                 y = self.meanpools[i - 1](y)
+                 y_hat = self.meanpools[i - 1](y_hat)
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+
+     return loss * 2
+
+
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+     r_losses = []
+     g_losses = []
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1 - dr) ** 2)
+         g_loss = torch.mean(dg ** 2)
+         loss += (r_loss + g_loss)
+         r_losses.append(r_loss.item())
+         g_losses.append(g_loss.item())
+
+     return loss, r_losses, g_losses
+
+
+ def generator_loss(disc_outputs):
+     loss = 0
+     gen_losses = []
+     for dg in disc_outputs:
+         l = torch.mean((1 - dg) ** 2)
+         gen_losses.append(l)
+         loss += l
+
+     return loss, gen_losses
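A hypothetical inference sketch for the NSF generator above (the checkpoint path, mel length, and F0 value are assumptions; config.json must sit next to the checkpoint and provide the h.* keys read above):

import torch
from modules.nsf_hifigan.models import load_model

generator, h = load_model('checkpoints/nsf_hifigan/model', device='cpu')
mel = torch.randn(1, h.num_mels, 100)    # (B, num_mels, frames)
f0 = torch.full((1, 100), 220.0)         # (B, frames) in Hz; 0 marks unvoiced frames
with torch.no_grad():
    wav = generator(mel, f0)             # (B, 1, frames * prod(h.upsample_rates))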
modules/nsf_hifigan/nvSTFT.py ADDED
@@ -0,0 +1,111 @@
+ import math
+ import os
+ os.environ["LRU_CACHE_CAPACITY"] = "3"
+ import random
+ import torch
+ import torch.utils.data
+ import numpy as np
+ import librosa
+ from librosa.util import normalize
+ from librosa.filters import mel as librosa_mel_fn
+ from scipy.io.wavfile import read
+ import soundfile as sf
+
+
+ def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
+     sampling_rate = None
+     try:
+         data, sampling_rate = sf.read(full_path, always_2d=True)
+     except Exception as ex:
+         print(f"'{full_path}' failed to load.\nException:")
+         print(ex)
+         if return_empty_on_exception:
+             return [], sampling_rate or target_sr or 48000
+         else:
+             raise Exception(ex)
+
+     if len(data.shape) > 1:
+         data = data[:, 0]
+     # guard against audio files of <= 2 samples (otherwise the slice above hits the wrong dimension)
+     assert len(data) > 2
+
+     if np.issubdtype(data.dtype, np.integer):  # integer audio data
+         max_mag = -np.iinfo(data.dtype).min  # maximum magnitude = min possible value of intXX
+     else:  # float32 audio data
+         max_mag = max(np.amax(data), -np.amin(data))
+         # data should be either 16-bit int, 32-bit int or [-1, 1] float32
+         max_mag = (2 ** 31) + 1 if max_mag > (2 ** 15) else ((2 ** 15) + 1 if max_mag > 1.01 else 1.0)
+
+     data = torch.FloatTensor(data.astype(np.float32)) / max_mag
+
+     # resample crashes on inf/NaN inputs; optionally return an empty array instead of raising
+     if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:
+         return [], sampling_rate or target_sr or 48000
+     if target_sr is not None and sampling_rate != target_sr:
+         data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
+         sampling_rate = target_sr
+
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ class STFT():
+     def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
+         self.target_sr = sr
+
+         self.n_mels = n_mels
+         self.n_fft = n_fft
+         self.win_size = win_size
+         self.hop_length = hop_length
+         self.fmin = fmin
+         self.fmax = fmax
+         self.clip_val = clip_val
+         self.mel_basis = {}
+         self.hann_window = {}
+
+     def get_mel(self, y, center=False):
+         sampling_rate = self.target_sr
+         n_mels = self.n_mels
+         n_fft = self.n_fft
+         win_size = self.win_size
+         hop_length = self.hop_length
+         fmin = self.fmin
+         fmax = self.fmax
+         clip_val = self.clip_val
+
+         if torch.min(y) < -1.:
+             print('min value is ', torch.min(y))
+         if torch.max(y) > 1.:
+             print('max value is ', torch.max(y))
+
+         key = str(fmax) + '_' + str(y.device)
+         if key not in self.mel_basis:  # the original checked the bare `fmax`, which never matches the composed key
+             mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
+             self.mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
+             self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device)
+
+         y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+                                     mode='reflect')
+         y = y.squeeze(1)
+
+         spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)],
+                           center=center, pad_mode='reflect', normalized=False, onesided=True)
+         spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+         spec = torch.matmul(self.mel_basis[key], spec)
+         spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
+         return spec
+
+     def __call__(self, audiopath):
+         audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
+         spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
+         return spect
+
+
+ stft = STFT()
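A usage sketch for this STFT front end (the wav path and the 44.1 kHz parameter values are assumptions):

stft_44k = STFT(sr=44100, n_mels=128, n_fft=2048, win_size=2048,
                hop_length=512, fmin=40, fmax=16000)
mel = stft_44k('example.wav')   # log-compressed mel spectrogram, shape (n_mels, frames)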
modules/nsf_hifigan/utils.py ADDED
@@ -0,0 +1,67 @@
+ import glob
+ import os
+ import matplotlib
+ import torch
+ from torch.nn.utils import weight_norm
+ matplotlib.use("Agg")
+ import matplotlib.pylab as plt
+
+
+ def plot_spectrogram(spectrogram):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                    interpolation='none')
+     plt.colorbar(im, ax=ax)
+
+     fig.canvas.draw()
+     plt.close()
+
+     return fig
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print("Loading '{}'".format(filepath))
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ def save_checkpoint(filepath, obj):
+     print("Saving checkpoint to {}".format(filepath))
+     torch.save(obj, filepath)
+     print("Complete.")
+
+
+ def del_old_checkpoints(cp_dir, prefix, n_models=2):
+     pattern = os.path.join(cp_dir, prefix + '????????')
+     cp_list = glob.glob(pattern)  # get checkpoint paths
+     cp_list = sorted(cp_list)  # sort by iteration
+     if len(cp_list) > n_models:  # if more than n_models checkpoints are found
+         for cp in cp_list[:-n_models]:  # delete all but the latest n_models
+             open(cp, 'w').close()  # empty the file contents
+             os.unlink(cp)  # delete the file (moves to trash when using Colab)
+
+
+ def scan_checkpoint(cp_dir, prefix):
+     pattern = os.path.join(cp_dir, prefix + '????????')
+     cp_list = glob.glob(pattern)
+     if len(cp_list) == 0:
+         return None
+     return sorted(cp_list)[-1]
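A small sketch of the checkpoint helpers (the directory and prefix are assumptions; scan_checkpoint matches prefix plus eight characters, e.g. a hypothetical g_00100000):

ckpt_path = scan_checkpoint('checkpoints/nsf_hifigan', 'g_')
if ckpt_path is not None:
    state = load_checkpoint(ckpt_path, 'cpu')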
modules/parallel_wavegan/__init__.py ADDED
File without changes
modules/parallel_wavegan/layers/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .causal_conv import *  # NOQA
+ from .pqmf import *  # NOQA
+ from .residual_block import *  # NOQA
+ from modules.parallel_wavegan.layers.residual_stack import *  # NOQA
+ from .upsample import *  # NOQA
modules/parallel_wavegan/layers/causal_conv.py ADDED
@@ -0,0 +1,56 @@
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2020 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ """Causal convolution layer modules."""
+
+
+ import torch
+
+
+ class CausalConv1d(torch.nn.Module):
+     """CausalConv1d module with customized initialization."""
+
+     def __init__(self, in_channels, out_channels, kernel_size,
+                  dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}):
+         """Initialize CausalConv1d module."""
+         super(CausalConv1d, self).__init__()
+         self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
+         self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
+                                     dilation=dilation, bias=bias)
+
+     def forward(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input tensor (B, in_channels, T).
+
+         Returns:
+             Tensor: Output tensor (B, out_channels, T).
+
+         """
+         return self.conv(self.pad(x))[:, :, :x.size(2)]
+
+
+ class CausalConvTranspose1d(torch.nn.Module):
+     """CausalConvTranspose1d module with customized initialization."""
+
+     def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
+         """Initialize CausalConvTranspose1d module."""
+         super(CausalConvTranspose1d, self).__init__()
+         self.deconv = torch.nn.ConvTranspose1d(
+             in_channels, out_channels, kernel_size, stride, bias=bias)
+         self.stride = stride
+
+     def forward(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input tensor (B, in_channels, T_in).
+
+         Returns:
+             Tensor: Output tensor (B, out_channels, T_out).
+
+         """
+         return self.deconv(x)[:, :, :-self.stride]
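A quick sanity sketch: because CausalConv1d pads by (kernel_size - 1) * dilation and then trims to the input length, the output length equals the input length and no output sample depends on future inputs (shapes assumed):

import torch

conv = CausalConv1d(1, 1, kernel_size=3, dilation=2)
x = torch.randn(1, 1, 16)
y = conv(x)
assert y.shape == x.shape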
modules/parallel_wavegan/layers/pqmf.py ADDED
@@ -0,0 +1,129 @@
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2020 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ """Pseudo QMF modules."""
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ # `kaiser` moved out of scipy.signal in newer SciPy releases
+ from scipy.signal.windows import kaiser
+
+
+ def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
+     """Design prototype filter for PQMF.
+
+     This method is based on `A Kaiser window approach for the design of prototype
+     filters of cosine modulated filterbanks`_.
+
+     Args:
+         taps (int): The number of filter taps.
+         cutoff_ratio (float): Cut-off frequency ratio.
+         beta (float): Beta coefficient for the Kaiser window.
+
+     Returns:
+         ndarray: Impulse response of the prototype filter (taps + 1,).
+
+     .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
+         https://ieeexplore.ieee.org/abstract/document/681427
+
+     """
+     # check that the arguments are valid
+     assert taps % 2 == 0, "The number of taps must be an even number."
+     assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
+
+     # make the initial filter
+     omega_c = np.pi * cutoff_ratio
+     with np.errstate(invalid='ignore'):
+         h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
+             / (np.pi * (np.arange(taps + 1) - 0.5 * taps))
+     h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form
+
+     # apply the Kaiser window
+     w = kaiser(taps + 1, beta)
+     h = h_i * w
+
+     return h
+
+
+ class PQMF(torch.nn.Module):
+     """PQMF module.
+
+     This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
+
+     .. _`Near-perfect-reconstruction pseudo-QMF banks`:
+         https://ieeexplore.ieee.org/document/258122
+
+     """
+
+     def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
+         """Initialize PQMF module.
+
+         Args:
+             subbands (int): The number of subbands.
+             taps (int): The number of filter taps.
+             cutoff_ratio (float): Cut-off frequency ratio.
+             beta (float): Beta coefficient for the Kaiser window.
+
+         """
+         super(PQMF, self).__init__()
+
+         # define filter coefficients
+         h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
+         h_analysis = np.zeros((subbands, len(h_proto)))
+         h_synthesis = np.zeros((subbands, len(h_proto)))
+         for k in range(subbands):
+             h_analysis[k] = 2 * h_proto * np.cos(
+                 (2 * k + 1) * (np.pi / (2 * subbands)) *
+                 (np.arange(taps + 1) - ((taps - 1) / 2)) +
+                 (-1) ** k * np.pi / 4)
+             h_synthesis[k] = 2 * h_proto * np.cos(
+                 (2 * k + 1) * (np.pi / (2 * subbands)) *
+                 (np.arange(taps + 1) - ((taps - 1) / 2)) -
+                 (-1) ** k * np.pi / 4)
+
+         # convert to tensors
+         analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
+         synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
+
+         # register coefficients as buffers
+         self.register_buffer("analysis_filter", analysis_filter)
+         self.register_buffer("synthesis_filter", synthesis_filter)
+
+         # filter for downsampling & upsampling
+         updown_filter = torch.zeros((subbands, subbands, subbands)).float()
+         for k in range(subbands):
+             updown_filter[k, k, 0] = 1.0
+         self.register_buffer("updown_filter", updown_filter)
+         self.subbands = subbands
+
+         # keep padding info
+         self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
+
+     def analysis(self, x):
+         """Analysis with PQMF.
+
+         Args:
+             x (Tensor): Input tensor (B, 1, T).
+
+         Returns:
+             Tensor: Output tensor (B, subbands, T // subbands).
+
+         """
+         x = F.conv1d(self.pad_fn(x), self.analysis_filter)
+         return F.conv1d(x, self.updown_filter, stride=self.subbands)
+
+     def synthesis(self, x):
+         """Synthesis with PQMF.
+
+         Args:
+             x (Tensor): Input tensor (B, subbands, T // subbands).
+
+         Returns:
+             Tensor: Output tensor (B, 1, T).
+
+         """
+         x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
+         return F.conv1d(self.pad_fn(x), self.synthesis_filter)
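A round-trip sketch for PQMF (the batch shape is an assumption): analysis splits a waveform into subbands at 1/subbands of the rate, and synthesis approximately inverts it:

import torch

pqmf = PQMF(subbands=4)
x = torch.randn(1, 1, 16000)
bands = pqmf.analysis(x)       # (1, 4, 4000)
x_hat = pqmf.synthesis(bands)  # (1, 1, 16000), near-perfect reconstruction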
modules/parallel_wavegan/layers/residual_block.py ADDED
@@ -0,0 +1,129 @@
+ # -*- coding: utf-8 -*-
+
+ """Residual block module in WaveNet.
+
+ This code is modified from https://github.com/r9y9/wavenet_vocoder.
+
+ """
+
+ import math
+
+ import torch
+ import torch.nn.functional as F
+
+
+ class Conv1d(torch.nn.Conv1d):
+     """Conv1d module with customized initialization."""
+
+     def __init__(self, *args, **kwargs):
+         """Initialize Conv1d module."""
+         super(Conv1d, self).__init__(*args, **kwargs)
+
+     def reset_parameters(self):
+         """Reset parameters."""
+         torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+         if self.bias is not None:
+             torch.nn.init.constant_(self.bias, 0.0)
+
+
+ class Conv1d1x1(Conv1d):
+     """1x1 Conv1d with customized initialization."""
+
+     def __init__(self, in_channels, out_channels, bias):
+         """Initialize 1x1 Conv1d module."""
+         super(Conv1d1x1, self).__init__(in_channels, out_channels,
+                                         kernel_size=1, padding=0,
+                                         dilation=1, bias=bias)
+
+
+ class ResidualBlock(torch.nn.Module):
+     """Residual block module in WaveNet."""
+
+     def __init__(self,
+                  kernel_size=3,
+                  residual_channels=64,
+                  gate_channels=128,
+                  skip_channels=64,
+                  aux_channels=80,
+                  dropout=0.0,
+                  dilation=1,
+                  bias=True,
+                  use_causal_conv=False
+                  ):
+         """Initialize ResidualBlock module.
+
+         Args:
+             kernel_size (int): Kernel size of the dilated convolution layer.
+             residual_channels (int): Number of channels for the residual connection.
+             gate_channels (int): Number of channels for the gated activation.
+             skip_channels (int): Number of channels for the skip connection.
+             aux_channels (int): Local conditioning channels, i.e. auxiliary input dimension.
+             dropout (float): Dropout probability.
+             dilation (int): Dilation factor.
+             bias (bool): Whether to add a bias parameter in convolution layers.
+             use_causal_conv (bool): Whether to use causal (rather than non-causal) convolution.
+
+         """
+         super(ResidualBlock, self).__init__()
+         self.dropout = dropout
+         # no future time stamps available
+         if use_causal_conv:
+             padding = (kernel_size - 1) * dilation
+         else:
+             assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+             padding = (kernel_size - 1) // 2 * dilation
+         self.use_causal_conv = use_causal_conv
+
+         # dilated conv
+         self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
+                            padding=padding, dilation=dilation, bias=bias)
+
+         # local conditioning
+         if aux_channels > 0:
+             self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
+         else:
+             self.conv1x1_aux = None
+
+         # conv output is split into two groups
+         gate_out_channels = gate_channels // 2
+         self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
+         self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)
+
+     def forward(self, x, c):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input tensor (B, residual_channels, T).
+             c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).
+
+         Returns:
+             Tensor: Output tensor for residual connection (B, residual_channels, T).
+             Tensor: Output tensor for skip connection (B, skip_channels, T).
+
+         """
+         residual = x
+         x = F.dropout(x, p=self.dropout, training=self.training)
+         x = self.conv(x)
+
+         # remove future time steps if using causal convolution
+         x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x
+
+         # split into two parts for the gated activation
+         splitdim = 1
+         xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+         # local conditioning
+         if c is not None:
+             assert self.conv1x1_aux is not None
+             c = self.conv1x1_aux(c)
+             ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+             xa, xb = xa + ca, xb + cb
+
+         x = torch.tanh(xa) * torch.sigmoid(xb)
+
+         # for the skip connection
+         s = self.conv1x1_skip(x)
+
+         # for the residual connection
+         x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)
+
+         return x, s
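A forward-shape sketch for ResidualBlock (shapes assumed; c is the local conditioning, e.g. an upsampled mel):

import torch

block = ResidualBlock(residual_channels=64, gate_channels=128,
                      skip_channels=64, aux_channels=80)
x = torch.randn(1, 64, 100)
c = torch.randn(1, 80, 100)
res, skip = block(x, c)   # res: (1, 64, 100), skip: (1, 64, 100)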
modules/parallel_wavegan/layers/residual_stack.py ADDED
@@ -0,0 +1,75 @@
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2020 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ """Residual stack module in MelGAN."""
+
+ import torch
+
+ from . import CausalConv1d
+
+
+ class ResidualStack(torch.nn.Module):
+     """Residual stack module introduced in MelGAN."""
+
+     def __init__(self,
+                  kernel_size=3,
+                  channels=32,
+                  dilation=1,
+                  bias=True,
+                  nonlinear_activation="LeakyReLU",
+                  nonlinear_activation_params={"negative_slope": 0.2},
+                  pad="ReflectionPad1d",
+                  pad_params={},
+                  use_causal_conv=False,
+                  ):
+         """Initialize ResidualStack module.
+
+         Args:
+             kernel_size (int): Kernel size of the dilated convolution layer.
+             channels (int): Number of channels of convolution layers.
+             dilation (int): Dilation factor.
+             bias (bool): Whether to add a bias parameter in convolution layers.
+             nonlinear_activation (str): Activation function module name.
+             nonlinear_activation_params (dict): Hyperparameters for the activation function.
+             pad (str): Padding function module name before the dilated convolution layer.
+             pad_params (dict): Hyperparameters for the padding function.
+             use_causal_conv (bool): Whether to use causal convolution.
+
+         """
+         super(ResidualStack, self).__init__()
+
+         # define the residual stack part
+         if not use_causal_conv:
+             assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
+             self.stack = torch.nn.Sequential(
+                 getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                 getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
+                 torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
+                 getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                 torch.nn.Conv1d(channels, channels, 1, bias=bias),
+             )
+         else:
+             self.stack = torch.nn.Sequential(
+                 getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                 CausalConv1d(channels, channels, kernel_size, dilation=dilation,
+                              bias=bias, pad=pad, pad_params=pad_params),
+                 getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                 torch.nn.Conv1d(channels, channels, 1, bias=bias),
+             )
+
+         # define the extra layer for the skip connection
+         self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
+
+     def forward(self, c):
+         """Calculate forward propagation.
+
+         Args:
+             c (Tensor): Input tensor (B, channels, T).
+
+         Returns:
+             Tensor: Output tensor (B, channels, T).
+
+         """
+         return self.stack(c) + self.skip_layer(c)
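A forward-shape sketch for ResidualStack (shapes assumed):

import torch

stack = ResidualStack(kernel_size=3, channels=32, dilation=2)
c = torch.randn(1, 32, 100)
out = stack(c)   # (1, 32, 100): the dilated branch plus the 1x1 skip branch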
modules/parallel_wavegan/layers/tf_layers.py ADDED
@@ -0,0 +1,129 @@
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2020 MINH ANH (@dathudeptrai)
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ """Tensorflow layer modules compatible with PyTorch."""
+
+ import tensorflow as tf
+
+
+ class TFReflectionPad1d(tf.keras.layers.Layer):
+     """Tensorflow ReflectionPad1d module."""
+
+     def __init__(self, padding_size):
+         """Initialize TFReflectionPad1d module.
+
+         Args:
+             padding_size (int): Padding size.
+
+         """
+         super(TFReflectionPad1d, self).__init__()
+         self.padding_size = padding_size
+
+     @tf.function
+     def call(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input tensor (B, T, 1, C).
+
+         Returns:
+             Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).
+
+         """
+         return tf.pad(x, [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]], "REFLECT")
+
+
+ class TFConvTranspose1d(tf.keras.layers.Layer):
+     """Tensorflow ConvTranspose1d module."""
+
+     def __init__(self, channels, kernel_size, stride, padding):
+         """Initialize TFConvTranspose1d module.
+
+         Args:
+             channels (int): Number of channels.
+             kernel_size (int): Kernel size.
+             stride (int): Stride width.
+             padding (str): Padding type ("same" or "valid").
+
+         """
+         super(TFConvTranspose1d, self).__init__()
+         self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
+             filters=channels,
+             kernel_size=(kernel_size, 1),
+             strides=(stride, 1),
+             padding=padding,
+         )
+
+     @tf.function
+     def call(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input tensor (B, T, 1, C).
+
+         Returns:
+             Tensor: Output tensor (B, T', 1, C').
+
+         """
+         x = self.conv1d_transpose(x)
+         return x
+
+
+ class TFResidualStack(tf.keras.layers.Layer):
+     """Tensorflow ResidualStack module."""
+
+     def __init__(self,
+                  kernel_size,
+                  channels,
+                  dilation,
+                  bias,
+                  nonlinear_activation,
+                  nonlinear_activation_params,
+                  padding,
+                  ):
+         """Initialize TFResidualStack module.
+
+         Args:
+             kernel_size (int): Kernel size.
+             channels (int): Number of channels.
+             dilation (int): Dilation rate.
+             bias (bool): Whether to add a bias parameter in convolution layers.
+             nonlinear_activation (str): Activation function module name.
+             nonlinear_activation_params (dict): Hyperparameters for the activation function.
+             padding (str): Padding type ("same" or "valid").
+
+         """
+         super(TFResidualStack, self).__init__()
+         self.block = [
+             getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
+             TFReflectionPad1d(dilation),
+             tf.keras.layers.Conv2D(
+                 filters=channels,
+                 kernel_size=(kernel_size, 1),
+                 dilation_rate=(dilation, 1),
+                 use_bias=bias,
+                 padding="valid",
+             ),
+             getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
+             tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
+         ]
+         self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
+
+     @tf.function
+     def call(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input tensor (B, T, 1, C).
+
+         Returns:
+             Tensor: Output tensor (B, T, 1, C).
+
+         """
+         _x = tf.identity(x)
+         for layer in self.block:
+             _x = layer(_x)
+         shortcut = self.shortcut(x)
+         return shortcut + _x
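A shape sketch for the (B, T, 1, C) layout these TF layers share (assumes tensorflow is installed; values are arbitrary):

import tensorflow as tf

pad = TFReflectionPad1d(padding_size=3)
x = tf.random.normal([2, 100, 1, 32])
y = pad(x)   # (2, 106, 1, 32)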
modules/parallel_wavegan/layers/upsample.py ADDED
@@ -0,0 +1,183 @@
+ # -*- coding: utf-8 -*-
+
+ """Upsampling module.
+
+ This code is modified from https://github.com/r9y9/wavenet_vocoder.
+
+ """
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from . import Conv1d
+
+
+ class Stretch2d(torch.nn.Module):
+     """Stretch2d module."""
+
+     def __init__(self, x_scale, y_scale, mode="nearest"):
+         """Initialize Stretch2d module.
+
+         Args:
+             x_scale (int): X scaling factor (time axis in spectrogram).
+             y_scale (int): Y scaling factor (frequency axis in spectrogram).
+             mode (str): Interpolation mode.
+
+         """
+         super(Stretch2d, self).__init__()
+         self.x_scale = x_scale
+         self.y_scale = y_scale
+         self.mode = mode
+
+     def forward(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input tensor (B, C, F, T).
+
+         Returns:
+             Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
+
+         """
+         return F.interpolate(
+             x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)
+
+
+ class Conv2d(torch.nn.Conv2d):
+     """Conv2d module with customized initialization."""
+
+     def __init__(self, *args, **kwargs):
+         """Initialize Conv2d module."""
+         super(Conv2d, self).__init__(*args, **kwargs)
+
+     def reset_parameters(self):
+         """Reset parameters."""
+         self.weight.data.fill_(1. / np.prod(self.kernel_size))
+         if self.bias is not None:
+             torch.nn.init.constant_(self.bias, 0.0)
+
+
+ class UpsampleNetwork(torch.nn.Module):
+     """Upsampling network module."""
+
+     def __init__(self,
+                  upsample_scales,
+                  nonlinear_activation=None,
+                  nonlinear_activation_params={},
+                  interpolate_mode="nearest",
+                  freq_axis_kernel_size=1,
+                  use_causal_conv=False,
+                  ):
+         """Initialize upsampling network module.
+
+         Args:
+             upsample_scales (list): List of upsampling scales.
+             nonlinear_activation (str): Activation function name.
+             nonlinear_activation_params (dict): Arguments for the specified activation function.
+             interpolate_mode (str): Interpolation mode.
+             freq_axis_kernel_size (int): Kernel size in the direction of the frequency axis.
+             use_causal_conv (bool): Whether to use causal structure.
+
+         """
+         super(UpsampleNetwork, self).__init__()
+         self.use_causal_conv = use_causal_conv
+         self.up_layers = torch.nn.ModuleList()
+         for scale in upsample_scales:
+             # interpolation layer
+             stretch = Stretch2d(scale, 1, interpolate_mode)
+             self.up_layers += [stretch]
+
+             # conv layer
+             assert (freq_axis_kernel_size - 1) % 2 == 0, "Even freq-axis kernel sizes are not supported."
+             freq_axis_padding = (freq_axis_kernel_size - 1) // 2
+             kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
+             if use_causal_conv:
+                 padding = (freq_axis_padding, scale * 2)
+             else:
+                 padding = (freq_axis_padding, scale)
+             conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+             self.up_layers += [conv]
+
+             # nonlinearity
+             if nonlinear_activation is not None:
+                 nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
+                 self.up_layers += [nonlinear]
+
+     def forward(self, c):
+         """Calculate forward propagation.
+
+         Args:
+             c (Tensor): Input tensor (B, C, T).
+
+         Returns:
+             Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
+
+         """
+         c = c.unsqueeze(1)  # (B, 1, C, T)
+         for f in self.up_layers:
+             if self.use_causal_conv and isinstance(f, Conv2d):
+                 c = f(c)[..., :c.size(-1)]
+             else:
+                 c = f(c)
+         return c.squeeze(1)  # (B, C, T')
+
+
+ class ConvInUpsampleNetwork(torch.nn.Module):
+     """Convolution + upsampling network module."""
+
+     def __init__(self,
+                  upsample_scales,
+                  nonlinear_activation=None,
+                  nonlinear_activation_params={},
+                  interpolate_mode="nearest",
+                  freq_axis_kernel_size=1,
+                  aux_channels=80,
+                  aux_context_window=0,
+                  use_causal_conv=False
+                  ):
+         """Initialize convolution + upsampling network module.
+
+         Args:
+             upsample_scales (list): List of upsampling scales.
+             nonlinear_activation (str): Activation function name.
+             nonlinear_activation_params (dict): Arguments for the specified activation function.
+             interpolate_mode (str): Interpolation mode.
+             freq_axis_kernel_size (int): Kernel size in the direction of the frequency axis.
+             aux_channels (int): Number of channels of the pre-convolutional layer.
+             aux_context_window (int): Context window size of the pre-convolutional layer.
+             use_causal_conv (bool): Whether to use causal structure.
+
+         """
+         super(ConvInUpsampleNetwork, self).__init__()
+         self.aux_context_window = aux_context_window
+         self.use_causal_conv = use_causal_conv and aux_context_window > 0
+         # To capture wide-context information in conditional features
+         kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
+ kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
156
+ # NOTE(kan-bayashi): Here do not use padding because the input is already padded
157
+ self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
158
+ self.upsample = UpsampleNetwork(
159
+ upsample_scales=upsample_scales,
160
+ nonlinear_activation=nonlinear_activation,
161
+ nonlinear_activation_params=nonlinear_activation_params,
162
+ interpolate_mode=interpolate_mode,
163
+ freq_axis_kernel_size=freq_axis_kernel_size,
164
+ use_causal_conv=use_causal_conv,
165
+ )
166
+
167
+ def forward(self, c):
168
+ """Calculate forward propagation.
169
+
170
+ Args:
171
+ c : Input tensor (B, C, T').
172
+
173
+ Returns:
174
+ Tensor: Upsampled tensor (B, C, T),
175
+ where T = (T' - aux_context_window * 2) * prod(upsample_scales).
176
+
177
+ Note:
178
+ The length of inputs considers the context window size.
179
+
180
+ """
181
+ c_ = self.conv_in(c)
182
+ c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
183
+ return self.upsample(c)
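Since each Stretch2d multiplies only the time axis by its scale and every Conv2d is padded to preserve length, the whole stack stretches T by prod(upsample_scales); for a vocoder this product must match the hop size of the mel extraction. A shape-only sketch, with illustrative scales:

import numpy as np
import torch

from modules.parallel_wavegan.layers.upsample import UpsampleNetwork

scales = [4, 4, 4, 4]  # illustrative; prod(scales) = 256 should equal hop_size
net = UpsampleNetwork(upsample_scales=scales)
c = torch.randn(1, 80, 50)  # (B, C, T): 50 mel frames
out = net(c)
assert out.shape == (1, 80, 50 * int(np.prod(scales)))  # (B, C, T * 256)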
modules/parallel_wavegan/losses/__init__.py ADDED
@@ -0,0 +1 @@
+ from .stft_loss import *  # NOQA
modules/parallel_wavegan/losses/stft_loss.py ADDED
@@ -0,0 +1,153 @@
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2019 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ """STFT-based Loss modules."""
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def stft(x, fft_size, hop_size, win_length, window):
+     """Perform STFT and convert to magnitude spectrogram.
+
+     Args:
+         x (Tensor): Input signal tensor (B, T).
+         fft_size (int): FFT size.
+         hop_size (int): Hop size.
+         win_length (int): Window length.
+         window (str): Window function type.
+
+     Returns:
+         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+
+     """
+     x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
+     real = x_stft[..., 0]
+     imag = x_stft[..., 1]
+
+     # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+     return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
+
+
+ class SpectralConvergengeLoss(torch.nn.Module):
+     """Spectral convergence loss module."""
+
+     def __init__(self):
+         """Initialize spectral convergence loss module."""
+         super(SpectralConvergengeLoss, self).__init__()
+
+     def forward(self, x_mag, y_mag):
+         """Calculate forward propagation.
+
+         Args:
+             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+             y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+
+         Returns:
+             Tensor: Spectral convergence loss value.
+
+         """
+         return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+
+ class LogSTFTMagnitudeLoss(torch.nn.Module):
+     """Log STFT magnitude loss module."""
+
+     def __init__(self):
+         """Initialize log STFT magnitude loss module."""
+         super(LogSTFTMagnitudeLoss, self).__init__()
+
+     def forward(self, x_mag, y_mag):
+         """Calculate forward propagation.
+
+         Args:
+             x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+             y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+
+         Returns:
+             Tensor: Log STFT magnitude loss value.
+
+         """
+         return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
+
+
+ class STFTLoss(torch.nn.Module):
+     """STFT loss module."""
+
+     def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
+         """Initialize STFT loss module."""
+         super(STFTLoss, self).__init__()
+         self.fft_size = fft_size
+         self.shift_size = shift_size
+         self.win_length = win_length
+         self.window = getattr(torch, window)(win_length)
+         self.spectral_convergenge_loss = SpectralConvergengeLoss()
+         self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+     def forward(self, x, y):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Predicted signal (B, T).
+             y (Tensor): Groundtruth signal (B, T).
+
+         Returns:
+             Tensor: Spectral convergence loss value.
+             Tensor: Log STFT magnitude loss value.
+
+         """
+         x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+         y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+         sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+         mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+         return sc_loss, mag_loss
+
+
+ class MultiResolutionSTFTLoss(torch.nn.Module):
+     """Multi resolution STFT loss module."""
+
+     def __init__(self,
+                  fft_sizes=[1024, 2048, 512],
+                  hop_sizes=[120, 240, 50],
+                  win_lengths=[600, 1200, 240],
+                  window="hann_window"):
+         """Initialize Multi resolution STFT loss module.
+
+         Args:
+             fft_sizes (list): List of FFT sizes.
+             hop_sizes (list): List of hop sizes.
+             win_lengths (list): List of window lengths.
+             window (str): Window function type.
+
+         """
+         super(MultiResolutionSTFTLoss, self).__init__()
+         assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+         self.stft_losses = torch.nn.ModuleList()
+         for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
+             self.stft_losses += [STFTLoss(fs, ss, wl, window)]
+
+     def forward(self, x, y):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Predicted signal (B, T).
+             y (Tensor): Groundtruth signal (B, T).
+
+         Returns:
+             Tensor: Multi resolution spectral convergence loss value.
+             Tensor: Multi resolution log STFT magnitude loss value.
+
+         """
+         sc_loss = 0.0
+         mag_loss = 0.0
+         for f in self.stft_losses:
+             sc_l, mag_l = f(x, y)
+             sc_loss += sc_l
+             mag_loss += mag_l
+         sc_loss /= len(self.stft_losses)
+         mag_loss /= len(self.stft_losses)
+
+         return sc_loss, mag_loss
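In training, the two returned terms are usually just summed (optionally weighted) into the auxiliary loss. A minimal sketch of the call pattern; note that the stft() helper above uses the positional torch.stft signature, so this assumes a PyTorch version where torch.stft still returns stacked real/imaginary parts rather than requiring return_complex:

import torch

from modules.parallel_wavegan.losses.stft_loss import MultiResolutionSTFTLoss

criterion = MultiResolutionSTFTLoss()  # defaults: FFT sizes 1024/2048/512
y_hat = torch.randn(2, 16000, requires_grad=True)  # predicted waveform (B, T)
y = torch.randn(2, 16000)                          # ground-truth waveform (B, T)
sc_loss, mag_loss = criterion(y_hat, y)            # each averaged over 3 resolutions
loss = sc_loss + mag_loss
loss.backward()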
modules/parallel_wavegan/models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .melgan import *  # NOQA
+ from .parallel_wavegan import *  # NOQA
modules/parallel_wavegan/models/melgan.py ADDED
@@ -0,0 +1,427 @@
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2020 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ """MelGAN Modules."""
+
+ import logging
+
+ import numpy as np
+ import torch
+
+ from modules.parallel_wavegan.layers import CausalConv1d
+ from modules.parallel_wavegan.layers import CausalConvTranspose1d
+ from modules.parallel_wavegan.layers import ResidualStack
+
+
+ class MelGANGenerator(torch.nn.Module):
+     """MelGAN generator module."""
+
+     def __init__(self,
+                  in_channels=80,
+                  out_channels=1,
+                  kernel_size=7,
+                  channels=512,
+                  bias=True,
+                  upsample_scales=[8, 8, 2, 2],
+                  stack_kernel_size=3,
+                  stacks=3,
+                  nonlinear_activation="LeakyReLU",
+                  nonlinear_activation_params={"negative_slope": 0.2},
+                  pad="ReflectionPad1d",
+                  pad_params={},
+                  use_final_nonlinear_activation=True,
+                  use_weight_norm=True,
+                  use_causal_conv=False,
+                  ):
+         """Initialize MelGANGenerator module.
+
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             kernel_size (int): Kernel size of initial and final conv layer.
+             channels (int): Initial number of channels for conv layer.
+             bias (bool): Whether to add bias parameter in convolution layers.
+             upsample_scales (list): List of upsampling scales.
+             stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
+             stacks (int): Number of stacks in a single residual stack.
+             nonlinear_activation (str): Activation function module name.
+             nonlinear_activation_params (dict): Hyperparameters for activation function.
+             pad (str): Padding function module name before dilated convolution layer.
+             pad_params (dict): Hyperparameters for padding function.
+             use_final_nonlinear_activation (bool): Whether to apply a final Tanh activation.
+             use_weight_norm (bool): Whether to use weight norm.
+                 If set to true, it will be applied to all of the conv layers.
+             use_causal_conv (bool): Whether to use causal convolution.
+
+         """
+         super(MelGANGenerator, self).__init__()
+
+         # check hyper parameters are valid
+         assert channels >= np.prod(upsample_scales)
+         assert channels % (2 ** len(upsample_scales)) == 0
+         if not use_causal_conv:
+             assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
+
+         # add initial layer
+         layers = []
+         if not use_causal_conv:
+             layers += [
+                 getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
+                 torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias),
+             ]
+         else:
+             layers += [
+                 CausalConv1d(in_channels, channels, kernel_size,
+                              bias=bias, pad=pad, pad_params=pad_params),
+             ]
+
+         for i, upsample_scale in enumerate(upsample_scales):
+             # add upsampling layer
+             layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
+             if not use_causal_conv:
+                 layers += [
+                     torch.nn.ConvTranspose1d(
+                         channels // (2 ** i),
+                         channels // (2 ** (i + 1)),
+                         upsample_scale * 2,
+                         stride=upsample_scale,
+                         padding=upsample_scale // 2 + upsample_scale % 2,
+                         output_padding=upsample_scale % 2,
+                         bias=bias,
+                     )
+                 ]
+             else:
+                 layers += [
+                     CausalConvTranspose1d(
+                         channels // (2 ** i),
+                         channels // (2 ** (i + 1)),
+                         upsample_scale * 2,
+                         stride=upsample_scale,
+                         bias=bias,
+                     )
+                 ]
+
+             # add residual stack
+             for j in range(stacks):
+                 layers += [
+                     ResidualStack(
+                         kernel_size=stack_kernel_size,
+                         channels=channels // (2 ** (i + 1)),
+                         dilation=stack_kernel_size ** j,
+                         bias=bias,
+                         nonlinear_activation=nonlinear_activation,
+                         nonlinear_activation_params=nonlinear_activation_params,
+                         pad=pad,
+                         pad_params=pad_params,
+                         use_causal_conv=use_causal_conv,
+                     )
+                 ]
+
+         # add final layer
+         layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
+         if not use_causal_conv:
+             layers += [
+                 getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
+                 torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias),
+             ]
+         else:
+             layers += [
+                 CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size,
+                              bias=bias, pad=pad, pad_params=pad_params),
+             ]
+         if use_final_nonlinear_activation:
+             layers += [torch.nn.Tanh()]
+
+         # define the model as a single function
+         self.melgan = torch.nn.Sequential(*layers)
+
+         # apply weight norm
+         if use_weight_norm:
+             self.apply_weight_norm()
+
+         # reset parameters
+         self.reset_parameters()
+
+     def forward(self, c):
+         """Calculate forward propagation.
+
+         Args:
+             c (Tensor): Input tensor (B, channels, T).
+
+         Returns:
+             Tensor: Output tensor (B, 1, T * prod(upsample_scales)).
+
+         """
+         return self.melgan(c)
+
+     def remove_weight_norm(self):
+         """Remove weight normalization module from all of the layers."""
+         def _remove_weight_norm(m):
+             try:
+                 logging.debug(f"Weight norm is removed from {m}.")
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
+
+     def apply_weight_norm(self):
+         """Apply weight normalization module to all of the layers."""
+         def _apply_weight_norm(m):
+             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                 torch.nn.utils.weight_norm(m)
+                 logging.debug(f"Weight norm is applied to {m}.")
+
+         self.apply(_apply_weight_norm)
+
+     def reset_parameters(self):
+         """Reset parameters.
+
+         This initialization follows official implementation manner.
+         https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py
+
+         """
+         def _reset_parameters(m):
+             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                 m.weight.data.normal_(0.0, 0.02)
+                 logging.debug(f"Reset parameters in {m}.")
+
+         self.apply(_reset_parameters)
+
+
+ class MelGANDiscriminator(torch.nn.Module):
+     """MelGAN discriminator module."""
+
+     def __init__(self,
+                  in_channels=1,
+                  out_channels=1,
+                  kernel_sizes=[5, 3],
+                  channels=16,
+                  max_downsample_channels=1024,
+                  bias=True,
+                  downsample_scales=[4, 4, 4, 4],
+                  nonlinear_activation="LeakyReLU",
+                  nonlinear_activation_params={"negative_slope": 0.2},
+                  pad="ReflectionPad1d",
+                  pad_params={},
+                  ):
+         """Initialize MelGAN discriminator module.
+
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer,
+                 and the first and the second kernel sizes will be used for the last two layers.
+                 For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15,
+                 the last two layers' kernel size will be 5 and 3, respectively.
+             channels (int): Initial number of channels for conv layer.
+             max_downsample_channels (int): Maximum number of channels for downsampling layers.
+             bias (bool): Whether to add bias parameter in convolution layers.
+             downsample_scales (list): List of downsampling scales.
+             nonlinear_activation (str): Activation function module name.
+             nonlinear_activation_params (dict): Hyperparameters for activation function.
+             pad (str): Padding function module name before dilated convolution layer.
+             pad_params (dict): Hyperparameters for padding function.
+
+         """
+         super(MelGANDiscriminator, self).__init__()
+         self.layers = torch.nn.ModuleList()
+
+         # check kernel size is valid
+         assert len(kernel_sizes) == 2
+         assert kernel_sizes[0] % 2 == 1
+         assert kernel_sizes[1] % 2 == 1
+
+         # add first layer
+         self.layers += [
+             torch.nn.Sequential(
+                 getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
+                 torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias),
+                 getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+             )
+         ]
+
+         # add downsample layers
+         in_chs = channels
+         for downsample_scale in downsample_scales:
+             out_chs = min(in_chs * downsample_scale, max_downsample_channels)
+             self.layers += [
+                 torch.nn.Sequential(
+                     torch.nn.Conv1d(
+                         in_chs, out_chs,
+                         kernel_size=downsample_scale * 10 + 1,
+                         stride=downsample_scale,
+                         padding=downsample_scale * 5,
+                         groups=in_chs // 4,
+                         bias=bias,
+                     ),
+                     getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                 )
+             ]
+             in_chs = out_chs
+
+         # add final layers
+         out_chs = min(in_chs * 2, max_downsample_channels)
+         self.layers += [
+             torch.nn.Sequential(
+                 torch.nn.Conv1d(
+                     in_chs, out_chs, kernel_sizes[0],
+                     padding=(kernel_sizes[0] - 1) // 2,
+                     bias=bias,
+                 ),
+                 getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+             )
+         ]
+         self.layers += [
+             torch.nn.Conv1d(
+                 out_chs, out_channels, kernel_sizes[1],
+                 padding=(kernel_sizes[1] - 1) // 2,
+                 bias=bias,
+             ),
+         ]
+
+     def forward(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input noise signal (B, 1, T).
+
+         Returns:
+             List: List of output tensors of each layer.
+
+         """
+         outs = []
+         for f in self.layers:
+             x = f(x)
+             outs += [x]
+
+         return outs
+
+
+ class MelGANMultiScaleDiscriminator(torch.nn.Module):
+     """MelGAN multi-scale discriminator module."""
+
+     def __init__(self,
+                  in_channels=1,
+                  out_channels=1,
+                  scales=3,
+                  downsample_pooling="AvgPool1d",
+                  # follow the official implementation setting
+                  downsample_pooling_params={
+                      "kernel_size": 4,
+                      "stride": 2,
+                      "padding": 1,
+                      "count_include_pad": False,
+                  },
+                  kernel_sizes=[5, 3],
+                  channels=16,
+                  max_downsample_channels=1024,
+                  bias=True,
+                  downsample_scales=[4, 4, 4, 4],
+                  nonlinear_activation="LeakyReLU",
+                  nonlinear_activation_params={"negative_slope": 0.2},
+                  pad="ReflectionPad1d",
+                  pad_params={},
+                  use_weight_norm=True,
+                  ):
+         """Initialize MelGAN multi-scale discriminator module.
+
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             scales (int): Number of discriminators.
+             downsample_pooling (str): Pooling module name for downsampling of the inputs.
+             downsample_pooling_params (dict): Parameters for the above pooling module.
+             kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer,
+                 and the first and the second kernel sizes will be used for the last two layers.
+             channels (int): Initial number of channels for conv layer.
+             max_downsample_channels (int): Maximum number of channels for downsampling layers.
+             bias (bool): Whether to add bias parameter in convolution layers.
+             downsample_scales (list): List of downsampling scales.
+             nonlinear_activation (str): Activation function module name.
+             nonlinear_activation_params (dict): Hyperparameters for activation function.
+             pad (str): Padding function module name before dilated convolution layer.
+             pad_params (dict): Hyperparameters for padding function.
+             use_weight_norm (bool): Whether to use weight norm.
+
+         """
+         super(MelGANMultiScaleDiscriminator, self).__init__()
+         self.discriminators = torch.nn.ModuleList()
+
+         # add discriminators
+         for _ in range(scales):
+             self.discriminators += [
+                 MelGANDiscriminator(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     kernel_sizes=kernel_sizes,
+                     channels=channels,
+                     max_downsample_channels=max_downsample_channels,
+                     bias=bias,
+                     downsample_scales=downsample_scales,
+                     nonlinear_activation=nonlinear_activation,
+                     nonlinear_activation_params=nonlinear_activation_params,
+                     pad=pad,
+                     pad_params=pad_params,
+                 )
+             ]
+         self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params)
+
+         # apply weight norm
+         if use_weight_norm:
+             self.apply_weight_norm()
+
+         # reset parameters
+         self.reset_parameters()
+
+     def forward(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input noise signal (B, 1, T).
+
+         Returns:
+             List: List of list of each discriminator outputs, which consists of each layer output tensors.
+
+         """
+         outs = []
+         for f in self.discriminators:
+             outs += [f(x)]
+             x = self.pooling(x)
+
+         return outs
+
+     def remove_weight_norm(self):
+         """Remove weight normalization module from all of the layers."""
+         def _remove_weight_norm(m):
+             try:
+                 logging.debug(f"Weight norm is removed from {m}.")
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
+
+     def apply_weight_norm(self):
+         """Apply weight normalization module to all of the layers."""
+         def _apply_weight_norm(m):
+             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                 torch.nn.utils.weight_norm(m)
+                 logging.debug(f"Weight norm is applied to {m}.")
+
+         self.apply(_apply_weight_norm)
+
+     def reset_parameters(self):
+         """Reset parameters.
+
+         This initialization follows official implementation manner.
+         https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py
+
+         """
+         def _reset_parameters(m):
+             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
+                 m.weight.data.normal_(0.0, 0.02)
+                 logging.debug(f"Reset parameters in {m}.")
+
+         self.apply(_reset_parameters)
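Each transposed convolution in the generator multiplies the time axis by its upsample_scale while halving the channel count, so the default scales [8, 8, 2, 2] turn every input frame into 8 * 8 * 2 * 2 = 256 samples. A quick shape check with the default configuration:

import torch

from modules.parallel_wavegan.models.melgan import MelGANGenerator

gen = MelGANGenerator()     # defaults: 80 mel bins in, scales [8, 8, 2, 2]
c = torch.randn(1, 80, 40)  # (B, in_channels, T): 40 mel frames
wav = gen(c)                # Tanh output in [-1, 1]
assert wav.shape == (1, 1, 40 * 256)  # (B, out_channels, T * prod(scales))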
modules/parallel_wavegan/models/parallel_wavegan.py ADDED
@@ -0,0 +1,434 @@
+ # -*- coding: utf-8 -*-
+
+ # Copyright 2019 Tomoki Hayashi
+ # MIT License (https://opensource.org/licenses/MIT)
+
+ """Parallel WaveGAN Modules."""
+
+ import logging
+ import math
+
+ import torch
+ from torch import nn
+
+ from modules.parallel_wavegan.layers import Conv1d
+ from modules.parallel_wavegan.layers import Conv1d1x1
+ from modules.parallel_wavegan.layers import ResidualBlock
+ from modules.parallel_wavegan.layers import upsample
+ from modules.parallel_wavegan import models
+
+
+ class ParallelWaveGANGenerator(torch.nn.Module):
+     """Parallel WaveGAN Generator module."""
+
+     def __init__(self,
+                  in_channels=1,
+                  out_channels=1,
+                  kernel_size=3,
+                  layers=30,
+                  stacks=3,
+                  residual_channels=64,
+                  gate_channels=128,
+                  skip_channels=64,
+                  aux_channels=80,
+                  aux_context_window=2,
+                  dropout=0.0,
+                  bias=True,
+                  use_weight_norm=True,
+                  use_causal_conv=False,
+                  upsample_conditional_features=True,
+                  upsample_net="ConvInUpsampleNetwork",
+                  upsample_params={"upsample_scales": [4, 4, 4, 4]},
+                  use_pitch_embed=False,
+                  ):
+         """Initialize Parallel WaveGAN Generator module.
+
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             kernel_size (int): Kernel size of dilated convolution.
+             layers (int): Number of residual block layers.
+             stacks (int): Number of stacks i.e., dilation cycles.
+             residual_channels (int): Number of channels in residual conv.
+             gate_channels (int): Number of channels in gated conv.
+             skip_channels (int): Number of channels in skip conv.
+             aux_channels (int): Number of channels for auxiliary feature conv.
+             aux_context_window (int): Context window size for auxiliary feature.
+             dropout (float): Dropout rate. 0.0 means no dropout applied.
+             bias (bool): Whether to use bias parameter in conv layer.
+             use_weight_norm (bool): Whether to use weight norm.
+                 If set to true, it will be applied to all of the conv layers.
+             use_causal_conv (bool): Whether to use causal structure.
+             upsample_conditional_features (bool): Whether to use upsampling network.
+             upsample_net (str): Upsampling network architecture.
+             upsample_params (dict): Upsampling network parameters.
+
+         """
+         super(ParallelWaveGANGenerator, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.aux_channels = aux_channels
+         self.layers = layers
+         self.stacks = stacks
+         self.kernel_size = kernel_size
+
+         # check the number of layers and stacks
+         assert layers % stacks == 0
+         layers_per_stack = layers // stacks
+
+         # define first convolution
+         self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
+
+         # define conv + upsampling network
+         if upsample_conditional_features:
+             upsample_params.update({
+                 "use_causal_conv": use_causal_conv,
+             })
+             if upsample_net == "MelGANGenerator":
+                 assert aux_context_window == 0
+                 upsample_params.update({
+                     "use_weight_norm": False,  # not to apply twice
+                     "use_final_nonlinear_activation": False,
+                 })
+                 self.upsample_net = getattr(models, upsample_net)(**upsample_params)
+             else:
+                 if upsample_net == "ConvInUpsampleNetwork":
+                     upsample_params.update({
+                         "aux_channels": aux_channels,
+                         "aux_context_window": aux_context_window,
+                     })
+                 self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
+         else:
+             self.upsample_net = None
+
+         # define residual blocks
+         self.conv_layers = torch.nn.ModuleList()
+         for layer in range(layers):
+             dilation = 2 ** (layer % layers_per_stack)
+             conv = ResidualBlock(
+                 kernel_size=kernel_size,
+                 residual_channels=residual_channels,
+                 gate_channels=gate_channels,
+                 skip_channels=skip_channels,
+                 aux_channels=aux_channels,
+                 dilation=dilation,
+                 dropout=dropout,
+                 bias=bias,
+                 use_causal_conv=use_causal_conv,
+             )
+             self.conv_layers += [conv]
+
+         # define output layers
+         self.last_conv_layers = torch.nn.ModuleList([
+             torch.nn.ReLU(inplace=True),
+             Conv1d1x1(skip_channels, skip_channels, bias=True),
+             torch.nn.ReLU(inplace=True),
+             Conv1d1x1(skip_channels, out_channels, bias=True),
+         ])
+
+         self.use_pitch_embed = use_pitch_embed
+         if use_pitch_embed:
+             self.pitch_embed = nn.Embedding(300, aux_channels, 0)
+             self.c_proj = nn.Linear(2 * aux_channels, aux_channels)
+
+         # apply weight norm
+         if use_weight_norm:
+             self.apply_weight_norm()
+
+     def forward(self, x, c=None, pitch=None, **kwargs):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input noise signal (B, C_in, T).
+             c (Tensor): Local conditioning auxiliary features (B, C, T').
+             pitch (Tensor): Local conditioning pitch (B, T').
+
+         Returns:
+             Tensor: Output tensor (B, C_out, T)
+
+         """
+         # perform upsampling
+         if c is not None and self.upsample_net is not None:
+             if self.use_pitch_embed:
+                 p = self.pitch_embed(pitch)
+                 c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
+             c = self.upsample_net(c)
+             assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
+
+         # encode to hidden representation
+         x = self.first_conv(x)
+         skips = 0
+         for f in self.conv_layers:
+             x, h = f(x, c)
+             skips += h
+         skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+         # apply final layers
+         x = skips
+         for f in self.last_conv_layers:
+             x = f(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         """Remove weight normalization module from all of the layers."""
+         def _remove_weight_norm(m):
+             try:
+                 logging.debug(f"Weight norm is removed from {m}.")
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
+
+     def apply_weight_norm(self):
+         """Apply weight normalization module to all of the layers."""
+         def _apply_weight_norm(m):
+             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                 torch.nn.utils.weight_norm(m)
+                 logging.debug(f"Weight norm is applied to {m}.")
+
+         self.apply(_apply_weight_norm)
+
+     @staticmethod
+     def _get_receptive_field_size(layers, stacks, kernel_size,
+                                   dilation=lambda x: 2 ** x):
+         assert layers % stacks == 0
+         layers_per_cycle = layers // stacks
+         dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
+         return (kernel_size - 1) * sum(dilations) + 1
+
+     @property
+     def receptive_field_size(self):
+         """Return receptive field size."""
+         return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
+
+
+ class ParallelWaveGANDiscriminator(torch.nn.Module):
+     """Parallel WaveGAN Discriminator module."""
+
+     def __init__(self,
+                  in_channels=1,
+                  out_channels=1,
+                  kernel_size=3,
+                  layers=10,
+                  conv_channels=64,
+                  dilation_factor=1,
+                  nonlinear_activation="LeakyReLU",
+                  nonlinear_activation_params={"negative_slope": 0.2},
+                  bias=True,
+                  use_weight_norm=True,
+                  ):
+         """Initialize Parallel WaveGAN Discriminator module.
+
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             kernel_size (int): Kernel size of conv layers.
+             layers (int): Number of conv layers.
+             conv_channels (int): Number of channels in conv layers.
+             dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
+                 the dilation will be 2, 4, 8, ..., and so on.
+             nonlinear_activation (str): Nonlinear function after each conv.
+             nonlinear_activation_params (dict): Nonlinear function parameters
+             bias (bool): Whether to use bias parameter in conv.
+             use_weight_norm (bool): Whether to use weight norm.
+                 If set to true, it will be applied to all of the conv layers.
+
+         """
+         super(ParallelWaveGANDiscriminator, self).__init__()
+         assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
+         assert dilation_factor > 0, "Dilation factor must be > 0."
+         self.conv_layers = torch.nn.ModuleList()
+         conv_in_channels = in_channels
+         for i in range(layers - 1):
+             if i == 0:
+                 dilation = 1
+             else:
+                 dilation = i if dilation_factor == 1 else dilation_factor ** i
+                 conv_in_channels = conv_channels
+             padding = (kernel_size - 1) // 2 * dilation
+             conv_layer = [
+                 Conv1d(conv_in_channels, conv_channels,
+                        kernel_size=kernel_size, padding=padding,
+                        dilation=dilation, bias=bias),
+                 getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
+             ]
+             self.conv_layers += conv_layer
+         padding = (kernel_size - 1) // 2
+         last_conv_layer = Conv1d(
+             conv_in_channels, out_channels,
+             kernel_size=kernel_size, padding=padding, bias=bias)
+         self.conv_layers += [last_conv_layer]
+
+         # apply weight norm
+         if use_weight_norm:
+             self.apply_weight_norm()
+
+     def forward(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input noise signal (B, 1, T).
+
+         Returns:
+             Tensor: Output tensor (B, 1, T)
+
+         """
+         for f in self.conv_layers:
+             x = f(x)
+         return x
+
+     def apply_weight_norm(self):
+         """Apply weight normalization module to all of the layers."""
+         def _apply_weight_norm(m):
+             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                 torch.nn.utils.weight_norm(m)
+                 logging.debug(f"Weight norm is applied to {m}.")
+
+         self.apply(_apply_weight_norm)
+
+     def remove_weight_norm(self):
+         """Remove weight normalization module from all of the layers."""
+         def _remove_weight_norm(m):
+             try:
+                 logging.debug(f"Weight norm is removed from {m}.")
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
+
+
+ class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
+     """Parallel WaveGAN Discriminator module."""
+
+     def __init__(self,
+                  in_channels=1,
+                  out_channels=1,
+                  kernel_size=3,
+                  layers=30,
+                  stacks=3,
+                  residual_channels=64,
+                  gate_channels=128,
+                  skip_channels=64,
+                  dropout=0.0,
+                  bias=True,
+                  use_weight_norm=True,
+                  use_causal_conv=False,
+                  nonlinear_activation="LeakyReLU",
+                  nonlinear_activation_params={"negative_slope": 0.2},
+                  ):
+         """Initialize Parallel WaveGAN Discriminator module.
+
+         Args:
+             in_channels (int): Number of input channels.
+             out_channels (int): Number of output channels.
+             kernel_size (int): Kernel size of dilated convolution.
+             layers (int): Number of residual block layers.
+             stacks (int): Number of stacks i.e., dilation cycles.
+             residual_channels (int): Number of channels in residual conv.
+             gate_channels (int): Number of channels in gated conv.
+             skip_channels (int): Number of channels in skip conv.
+             dropout (float): Dropout rate. 0.0 means no dropout applied.
+             bias (bool): Whether to use bias parameter in conv.
+             use_weight_norm (bool): Whether to use weight norm.
+                 If set to true, it will be applied to all of the conv layers.
+             use_causal_conv (bool): Whether to use causal structure.
+             nonlinear_activation (str): Nonlinear function after each conv.
+             nonlinear_activation_params (dict): Nonlinear function parameters
+
+         """
+         super(ResidualParallelWaveGANDiscriminator, self).__init__()
+         assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.layers = layers
+         self.stacks = stacks
+         self.kernel_size = kernel_size
+
+         # check the number of layers and stacks
+         assert layers % stacks == 0
+         layers_per_stack = layers // stacks
+
+         # define first convolution
+         self.first_conv = torch.nn.Sequential(
+             Conv1d1x1(in_channels, residual_channels, bias=True),
+             getattr(torch.nn, nonlinear_activation)(
+                 inplace=True, **nonlinear_activation_params),
+         )
+
+         # define residual blocks
+         self.conv_layers = torch.nn.ModuleList()
+         for layer in range(layers):
+             dilation = 2 ** (layer % layers_per_stack)
+             conv = ResidualBlock(
+                 kernel_size=kernel_size,
+                 residual_channels=residual_channels,
+                 gate_channels=gate_channels,
+                 skip_channels=skip_channels,
+                 aux_channels=-1,
+                 dilation=dilation,
+                 dropout=dropout,
+                 bias=bias,
+                 use_causal_conv=use_causal_conv,
+             )
+             self.conv_layers += [conv]
+
+         # define output layers
+         self.last_conv_layers = torch.nn.ModuleList([
+             getattr(torch.nn, nonlinear_activation)(
+                 inplace=True, **nonlinear_activation_params),
+             Conv1d1x1(skip_channels, skip_channels, bias=True),
+             getattr(torch.nn, nonlinear_activation)(
+                 inplace=True, **nonlinear_activation_params),
+             Conv1d1x1(skip_channels, out_channels, bias=True),
+         ])
+
+         # apply weight norm
+         if use_weight_norm:
+             self.apply_weight_norm()
+
+     def forward(self, x):
+         """Calculate forward propagation.
+
+         Args:
+             x (Tensor): Input noise signal (B, 1, T).
+
+         Returns:
+             Tensor: Output tensor (B, 1, T)
+
+         """
+         x = self.first_conv(x)
+
+         skips = 0
+         for f in self.conv_layers:
+             x, h = f(x, None)
+             skips += h
+         skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+         # apply final layers
+         x = skips
+         for f in self.last_conv_layers:
+             x = f(x)
+         return x
+
+     def apply_weight_norm(self):
+         """Apply weight normalization module to all of the layers."""
+         def _apply_weight_norm(m):
+             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                 torch.nn.utils.weight_norm(m)
+                 logging.debug(f"Weight norm is applied to {m}.")
+
+         self.apply(_apply_weight_norm)
+
+     def remove_weight_norm(self):
+         """Remove weight normalization module from all of the layers."""
+         def _remove_weight_norm(m):
+             try:
+                 logging.debug(f"Weight norm is removed from {m}.")
+                 torch.nn.utils.remove_weight_norm(m)
+             except ValueError:  # this module didn't have weight norm
+                 return
+
+         self.apply(_remove_weight_norm)
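With the generator defaults (layers=30, stacks=3, kernel_size=3), the dilations run 1, 2, 4, ..., 512 three times over, so sum(dilations) = 3 * 1023 = 3069 and the receptive field is (3 - 1) * 3069 + 1 = 6139 samples. This can be read off the property directly:

from modules.parallel_wavegan.models.parallel_wavegan import ParallelWaveGANGenerator

g = ParallelWaveGANGenerator()  # layers=30, stacks=3, kernel_size=3
print(g.receptive_field_size)   # 6139 = (3 - 1) * 3 * (2**10 - 1) + 1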
modules/parallel_wavegan/models/source.py ADDED
@@ -0,0 +1,538 @@
+ import torch
+ import numpy as np
+ import sys
+ import torch.nn.functional as torch_nn_func
+
+
+ class SineGen(torch.nn.Module):
+     """ Definition of sine generator
+     SineGen(samp_rate, harmonic_num = 0,
+             sine_amp = 0.1, noise_std = 0.003,
+             voiced_threshold = 0,
+             flag_for_pulse=False)
+
+     samp_rate: sampling rate in Hz
+     harmonic_num: number of harmonic overtones (default 0)
+     sine_amp: amplitude of sine-waveform (default 0.1)
+     noise_std: std of Gaussian noise (default 0.003)
+     voiced_threshold: F0 threshold for U/V classification (default 0)
+     flag_for_pulse: this SineGen is used inside PulseGen (default False)
+
+     Note: when flag_for_pulse is True, the first time step of a voiced
+     segment is always sin(np.pi) or cos(0)
+     """
+
+     def __init__(self, samp_rate, harmonic_num=0,
+                  sine_amp=0.1, noise_std=0.003,
+                  voiced_threshold=0,
+                  flag_for_pulse=False):
+         super(SineGen, self).__init__()
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.harmonic_num = harmonic_num
+         self.dim = self.harmonic_num + 1
+         self.sampling_rate = samp_rate
+         self.voiced_threshold = voiced_threshold
+         self.flag_for_pulse = flag_for_pulse
+
+     def _f02uv(self, f0):
+         # generate uv signal
+         uv = torch.ones_like(f0)
+         uv = uv * (f0 > self.voiced_threshold)
+         return uv
+
+     def _f02sine(self, f0_values):
+         """ f0_values: (batchsize, length, dim)
+         where dim indicates fundamental tone and overtones
+         """
+         # convert to F0 in rad. The integer part n can be ignored
+         # because 2 * np.pi * n doesn't affect phase
+         rad_values = (f0_values / self.sampling_rate) % 1
+
+         # initial phase noise (no noise for fundamental component)
+         rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
+                               device=f0_values.device)
+         rand_ini[:, 0] = 0
+         rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+         # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+         if not self.flag_for_pulse:
+             # for normal case
+
+             # To prevent torch.cumsum numerical overflow,
+             # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+             # Buffer tmp_over_one_idx indicates the time step to add -1.
+             # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+             tmp_over_one = torch.cumsum(rad_values, 1) % 1
+             tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
+                                 tmp_over_one[:, :-1, :]) < 0
+             cumsum_shift = torch.zeros_like(rad_values)
+             cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+             sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
+                               * 2 * np.pi)
+         else:
+             # If necessary, make sure that the first time step of every
+             # voiced segment is sin(pi) or cos(0)
+             # This is used for pulse-train generation
+
+             # identify the last time step in unvoiced segments
+             uv = self._f02uv(f0_values)
+             uv_1 = torch.roll(uv, shifts=-1, dims=1)
+             uv_1[:, -1, :] = 1
+             u_loc = (uv < 1) * (uv_1 > 0)
+
+             # get the instantaneous phase
+             tmp_cumsum = torch.cumsum(rad_values, dim=1)
+             # different batch needs to be processed differently
+             for idx in range(f0_values.shape[0]):
+                 temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                 temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                 # stores the accumulation of i.phase within
+                 # each voiced segment
+                 tmp_cumsum[idx, :, :] = 0
+                 tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+             # rad_values - tmp_cumsum: remove the accumulation of i.phase
+             # within the previous voiced segment.
+             i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+             # get the sines
+             sines = torch.cos(i_phase * 2 * np.pi)
+         return sines
+
+     def forward(self, f0):
+         """ sine_tensor, uv = forward(f0)
+         input F0: tensor(batchsize=1, length, dim=1)
+             f0 for unvoiced steps should be 0
+         output sine_tensor: tensor(batchsize=1, length, dim)
+         output uv: tensor(batchsize=1, length, 1)
+         """
+         with torch.no_grad():
+             f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+                                  device=f0.device)
+             # fundamental component
+             f0_buf[:, :, 0] = f0[:, :, 0]
+             for idx in np.arange(self.harmonic_num):
+                 # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
+                 f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+             # generate sine waveforms
+             sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+             # generate uv signal
+             # uv = torch.ones(f0.shape)
+             # uv = uv * (f0 > self.voiced_threshold)
+             uv = self._f02uv(f0)
+
+             # noise: for unvoiced should be similar to sine_amp
+             #        std = self.sine_amp/3 -> max value ~ self.sine_amp
+             #        for voiced regions is self.noise_std
+             noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+             noise = noise_amp * torch.randn_like(sine_waves)
+
+             # first: set the unvoiced part to 0 by uv
+             # then: additive noise
+             sine_waves = sine_waves * uv + noise
+         return sine_waves, uv, noise
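Everything in SineGen runs under torch.no_grad(), so it can be probed directly with a synthetic F0 contour; unvoiced frames (f0 at or below voiced_threshold) come out as pure noise with std sine_amp / 3. A sketch at an assumed 24 kHz sampling rate (note this file defines SineGen a second time further down; the later, functionally identical definition is the one that survives import):

import torch

from modules.parallel_wavegan.models.source import SineGen

sine_gen = SineGen(samp_rate=24000, harmonic_num=8)
f0 = torch.zeros(1, 24000, 1)        # one second of F0, unvoiced by default
f0[:, 8000:16000, 0] = 220.0         # a voiced 220 Hz segment in the middle
sine, uv, noise = sine_gen(f0)
# sine: (1, 24000, 9) fundamental + 8 overtones; uv: (1, 24000, 1) voiced mask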
138
+
139
+
140
+ class PulseGen(torch.nn.Module):
141
+ """ Definition of Pulse train generator
142
+
143
+ There are many ways to implement pulse generator.
144
+ Here, PulseGen is based on SinGen. For a perfect
145
+ """
146
+ def __init__(self, samp_rate, pulse_amp = 0.1,
147
+ noise_std = 0.003, voiced_threshold = 0):
148
+ super(PulseGen, self).__init__()
149
+ self.pulse_amp = pulse_amp
150
+ self.sampling_rate = samp_rate
151
+ self.voiced_threshold = voiced_threshold
152
+ self.noise_std = noise_std
153
+ self.l_sinegen = SineGen(self.sampling_rate, harmonic_num=0, \
154
+ sine_amp=self.pulse_amp, noise_std=0, \
155
+ voiced_threshold=self.voiced_threshold, \
156
+ flag_for_pulse=True)
157
+
158
+ def forward(self, f0):
159
+ """ Pulse train generator
160
+ pulse_train, uv = forward(f0)
161
+ input F0: tensor(batchsize=1, length, dim=1)
162
+ f0 for unvoiced steps should be 0
163
+ output pulse_train: tensor(batchsize=1, length, dim)
164
+ output uv: tensor(batchsize=1, length, 1)
165
+
166
+ Note: self.l_sine doesn't make sure that the initial phase of
167
+ a voiced segment is np.pi, the first pulse in a voiced segment
168
+ may not be at the first time step within a voiced segment
169
+ """
170
+ with torch.no_grad():
171
+ sine_wav, uv, noise = self.l_sinegen(f0)
172
+
173
+ # sine without additive noise
174
+ pure_sine = sine_wav - noise
175
+
176
+ # step t corresponds to a pulse if
177
+ # sine[t] > sine[t+1] & sine[t] > sine[t-1]
178
+ # & sine[t-1], sine[t+1], and sine[t] are voiced
179
+ # or
180
+ # sine[t] is voiced, sine[t-1] is unvoiced
181
+ # we use torch.roll to simulate sine[t+1] and sine[t-1]
182
+ sine_1 = torch.roll(pure_sine, shifts=1, dims=1)
183
+ uv_1 = torch.roll(uv, shifts=1, dims=1)
184
+ uv_1[:, 0, :] = 0
185
+ sine_2 = torch.roll(pure_sine, shifts=-1, dims=1)
186
+ uv_2 = torch.roll(uv, shifts=-1, dims=1)
187
+ uv_2[:, -1, :] = 0
188
+
189
+ loc = (pure_sine > sine_1) * (pure_sine > sine_2) \
190
+ * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \
191
+ + (uv_1 < 1) * (uv > 0)
192
+
193
+ # pulse train without noise
194
+ pulse_train = pure_sine * loc
195
+
196
+ # additive noise to pulse train
197
+ # note that noise from sinegen is zero in voiced regions
198
+ pulse_noise = torch.randn_like(pure_sine) * self.noise_std
199
+
200
+ # with additive noise on pulse, and unvoiced regions
201
+ pulse_train += pulse_noise * loc + pulse_noise * (1 - uv)
202
+ return pulse_train, sine_wav, uv, pulse_noise
203
+
204
+
205
+ class SignalsConv1d(torch.nn.Module):
206
+ """ Filtering input signal with time invariant filter
207
+ Note: FIRFilter conducted filtering given fixed FIR weight
208
+ SignalsConv1d convolves two signals
209
+ Note: this is based on torch.nn.functional.conv1d
210
+
211
+ """
212
+
213
+ def __init__(self):
214
+ super(SignalsConv1d, self).__init__()
215
+
216
+ def forward(self, signal, system_ir):
217
+ """ output = forward(signal, system_ir)
218
+
219
+ signal: (batchsize, length1, dim)
220
+ system_ir: (length2, dim)
221
+
222
+ output: (batchsize, length1, dim)
223
+ """
224
+ if signal.shape[-1] != system_ir.shape[-1]:
225
+ print("Error: SignalsConv1d expects shape:")
226
+ print("signal (batchsize, length1, dim)")
227
+ print("system_id (batchsize, length2, dim)")
228
+ print("But received signal: {:s}".format(str(signal.shape)))
229
+ print(" system_ir: {:s}".format(str(system_ir.shape)))
230
+ sys.exit(1)
231
+ padding_length = system_ir.shape[0] - 1
232
+ groups = signal.shape[-1]
233
+
234
+ # pad signal on the left
235
+ signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), \
236
+ (padding_length, 0))
237
+ # prepare system impulse response as (dim, 1, length2)
238
+ # also flip the impulse response
239
+ ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), \
240
+ dims=[2])
241
+ # convolute
242
+ output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)
243
+ return output.permute(0, 2, 1)
244
+
245
+
246
+ class CyclicNoiseGen_v1(torch.nn.Module):
247
+ """ CyclicnoiseGen_v1
248
+ Cyclic noise with a single parameter of beta.
249
+ Pytorch v1 implementation assumes f_t is also fixed
250
+ """
251
+
252
+ def __init__(self, samp_rate,
253
+ noise_std=0.003, voiced_threshold=0):
254
+ super(CyclicNoiseGen_v1, self).__init__()
255
+ self.samp_rate = samp_rate
256
+ self.noise_std = noise_std
257
+ self.voiced_threshold = voiced_threshold
258
+
259
+ self.l_pulse = PulseGen(samp_rate, pulse_amp=1.0,
260
+ noise_std=noise_std,
261
+ voiced_threshold=voiced_threshold)
262
+ self.l_conv = SignalsConv1d()
263
+
264
+ def noise_decay(self, beta, f0mean):
265
+ """ decayed_noise = noise_decay(beta, f0mean)
266
+ decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate)
267
+
268
+ beta: (dim=1) or (batchsize=1, 1, dim=1)
269
+ f0mean (batchsize=1, 1, dim=1)
270
+
271
+ decayed_noise (batchsize=1, length, dim=1)
272
+ """
273
+ with torch.no_grad():
274
+ # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T
275
+ # truncate the noise when decayed by -40 dB
276
+ length = 4.6 * self.samp_rate / f0mean
277
+ length = length.int()
278
+ time_idx = torch.arange(0, length, device=beta.device)
279
+ time_idx = time_idx.unsqueeze(0).unsqueeze(2)
280
+ time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])
281
+
282
+ noise = torch.randn(time_idx.shape, device=beta.device)
283
+
284
+ # due to Pytorch implementation, use f0_mean as the f0 factor
285
+ decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate)
286
+ return noise * self.noise_std * decay
287
+
288
+ def forward(self, f0s, beta):
289
+ """ Producde cyclic-noise
290
+ """
291
+ # pulse train
292
+ pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)
293
+ pure_pulse = pulse_train - noise
294
+
295
+ # decayed_noise (length, dim=1)
296
+ if (uv < 1).all():
297
+ # all unvoiced
298
+ cyc_noise = torch.zeros_like(sine_wav)
299
+ else:
300
+ f0mean = f0s[uv > 0].mean()
301
+
302
+ decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]
303
+ # convolute
304
+ cyc_noise = self.l_conv(pure_pulse, decayed_noise)
305
+
306
+ # add noise in invoiced segments
307
+ cyc_noise = cyc_noise + noise * (1.0 - uv)
308
+ return cyc_noise, pulse_train, sine_wav, uv, noise
309
+
310
+
311
+ class SineGen(torch.nn.Module):
312
+ """ Definition of sine generator
313
+ SineGen(samp_rate, harmonic_num = 0,
314
+ sine_amp = 0.1, noise_std = 0.003,
315
+ voiced_threshold = 0,
316
+ flag_for_pulse=False)
317
+
318
+ samp_rate: sampling rate in Hz
319
+ harmonic_num: number of harmonic overtones (default 0)
320
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
321
+ noise_std: std of Gaussian noise (default 0.003)
322
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
323
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
324
+
325
+ Note: when flag_for_pulse is True, the first time step of a voiced
326
+ segment is always sin(np.pi) or cos(0)
327
+ """
328
+
329
+ def __init__(self, samp_rate, harmonic_num=0,
330
+ sine_amp=0.1, noise_std=0.003,
331
+ voiced_threshold=0,
332
+ flag_for_pulse=False):
333
+ super(SineGen, self).__init__()
334
+ self.sine_amp = sine_amp
335
+ self.noise_std = noise_std
336
+ self.harmonic_num = harmonic_num
337
+ self.dim = self.harmonic_num + 1
338
+ self.sampling_rate = samp_rate
339
+ self.voiced_threshold = voiced_threshold
340
+ self.flag_for_pulse = flag_for_pulse
341
+
342
+ def _f02uv(self, f0):
343
+ # generate uv signal
344
+ uv = torch.ones_like(f0)
345
+ uv = uv * (f0 > self.voiced_threshold)
346
+ return uv
347
+
348
+ def _f02sine(self, f0_values):
349
+ """ f0_values: (batchsize, length, dim)
350
+ where dim indicates fundamental tone and overtones
351
+ """
352
+ # convert to F0 in rad. The interger part n can be ignored
353
+ # because 2 * np.pi * n doesn't affect phase
354
+ rad_values = (f0_values / self.sampling_rate) % 1
355
+
356
+ # initial phase noise (no noise for fundamental component)
357
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
358
+ device=f0_values.device)
359
+ rand_ini[:, 0] = 0
360
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
361
+
362
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
363
+ if not self.flag_for_pulse:
364
+ # for normal case
365
+
366
+ # To prevent torch.cumsum numerical overflow,
367
+ # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
368
+ # Buffer tmp_over_one_idx indicates the time step to add -1.
369
+ # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
370
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
371
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
372
+ tmp_over_one[:, :-1, :]) < 0
373
+ cumsum_shift = torch.zeros_like(rad_values)
374
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
375
+
376
+ sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
377
+ * 2 * np.pi)
378
+ else:
379
+ # If necessary, make sure that the first time step of every
380
+ # voiced segments is sin(pi) or cos(0)
381
+ # This is used for pulse-train generation
382
+
383
+ # identify the last time step in unvoiced segments
384
+ uv = self._f02uv(f0_values)
385
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
386
+ uv_1[:, -1, :] = 1
387
+ u_loc = (uv < 1) * (uv_1 > 0)
388
+
389
+ # get the instantanouse phase
390
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
391
+ # different batch needs to be processed differently
392
+ for idx in range(f0_values.shape[0]):
393
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
394
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
395
+ # stores the accumulation of i.phase within
396
+ # each voiced segments
397
+ tmp_cumsum[idx, :, :] = 0
398
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
399
+
400
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
401
+ # within the previous voiced segment.
402
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
403
+
404
+ # get the sines
405
+ sines = torch.cos(i_phase * 2 * np.pi)
406
+ return sines
407
+
408
+ def forward(self, f0):
409
+ """ sine_tensor, uv = forward(f0)
410
+ input F0: tensor(batchsize=1, length, dim=1)
411
+ f0 for unvoiced steps should be 0
412
+ output sine_tensor: tensor(batchsize=1, length, dim)
413
+ output uv: tensor(batchsize=1, length, 1)
414
+ """
415
+ with torch.no_grad():
416
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, \
417
+ device=f0.device)
418
+ # fundamental component
419
+ f0_buf[:, :, 0] = f0[:, :, 0]
420
+ for idx in np.arange(self.harmonic_num):
421
+ # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
422
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
423
+
424
+ # generate sine waveforms
425
+ sine_waves = self._f02sine(f0_buf) * self.sine_amp
426
+
427
+ # generate uv signal
428
+ # uv = torch.ones(f0.shape)
429
+ # uv = uv * (f0 > self.voiced_threshold)
430
+ uv = self._f02uv(f0)
431
+
432
+ # noise: for unvoiced should be similar to sine_amp
433
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
434
+ # for voiced regions the std is self.noise_std
435
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
436
+ noise = noise_amp * torch.randn_like(sine_waves)
437
+
438
+ # first: set the unvoiced part to 0 by uv
439
+ # then: additive noise
440
+ sine_waves = sine_waves * uv + noise
441
+ return sine_waves, uv, noise
442
+
443
+
444
+ class SourceModuleCycNoise_v1(torch.nn.Module):
445
+ """ SourceModuleCycNoise_v1
446
+ SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)
447
+ sampling_rate: sampling_rate in Hz
448
+
449
+ noise_std: std of Gaussian noise (default: 0.003)
450
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
451
+
452
+ cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)
453
+ F0_upsampled (batchsize, length, 1)
454
+ beta (1)
455
+ cyc (batchsize, length, 1)
456
+ noise (batchsize, length, 1)
457
+ uv (batchsize, length, 1)
458
+ """
459
+
460
+ def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):
461
+ super(SourceModuleCycNoise_v1, self).__init__()
462
+ self.sampling_rate = sampling_rate
463
+ self.noise_std = noise_std
464
+ self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std,
465
+ voiced_threshod)
466
+
467
+ def forward(self, f0_upsamped, beta):
468
+ """
469
+ cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)
470
+ F0_upsampled (batchsize, length, 1)
471
+ beta (1)
472
+ cyc (batchsize, length, 1)
473
+ noise (batchsize, length, 1)
474
+ uv (batchsize, length, 1)
475
+ """
476
+ # source for harmonic branch
477
+ cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)
478
+
479
+ # source for noise branch, in the same shape as uv
480
+ noise = torch.randn_like(uv) * self.noise_std / 3
481
+ return cyc, noise, uv
482
+
483
+
484
+ class SourceModuleHnNSF(torch.nn.Module):
485
+ """ SourceModule for hn-nsf
486
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
487
+ add_noise_std=0.003, voiced_threshod=0)
488
+ sampling_rate: sampling_rate in Hz
489
+ harmonic_num: number of harmonic above F0 (default: 0)
490
+ sine_amp: amplitude of sine source signal (default: 0.1)
491
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
492
+ note that amplitude of noise in unvoiced is decided
493
+ by sine_amp
494
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
495
+
496
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
497
+ F0_sampled (batchsize, length, 1)
498
+ Sine_source (batchsize, length, 1)
499
+ noise_source (batchsize, length 1)
500
+ uv (batchsize, length, 1)
501
+ """
502
+
503
+ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
504
+ add_noise_std=0.003, voiced_threshod=0):
505
+ super(SourceModuleHnNSF, self).__init__()
506
+
507
+ self.sine_amp = sine_amp
508
+ self.noise_std = add_noise_std
509
+
510
+ # to produce sine waveforms
511
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
512
+ sine_amp, add_noise_std, voiced_threshod)
513
+
514
+ # to merge source harmonics into a single excitation
515
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
516
+ self.l_tanh = torch.nn.Tanh()
517
+
518
+ def forward(self, x):
519
+ """
520
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
521
+ F0_sampled (batchsize, length, 1)
522
+ Sine_source (batchsize, length, 1)
523
+ noise_source (batchsize, length 1)
524
+ """
525
+ # source for harmonic branch
526
+ sine_wavs, uv, _ = self.l_sin_gen(x)
527
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
528
+
529
+ # source for noise branch, in the same shape as uv
530
+ noise = torch.randn_like(uv) * self.sine_amp / 3
531
+ return sine_merge, noise, uv
532
+
533
+
534
+ if __name__ == '__main__':
535
+ source = SourceModuleCycNoise_v1(24000)
536
+ x = torch.randn(16, 25600, 1)
537
+
538
+
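The `__main__` stub above constructs a cyclic-noise source but never runs it. For orientation, here is a minimal sketch (not part of the commit) of how the harmonic source module is driven; the shapes follow the docstrings above, and the 44.1 kHz rate and constant 220 Hz contour are illustrative assumptions:

```python
# Hedged sketch: drive SourceModuleHnNSF with a synthetic upsampled F0 contour.
import torch

source = SourceModuleHnNSF(sampling_rate=44100, harmonic_num=8)
f0 = torch.full((1, 44100, 1), 220.0)  # 1 s of a steady 220 Hz tone, (B, T, 1)
sine_merge, noise, uv = source(f0)     # each output: (1, 44100, 1)
print(sine_merge.shape, uv.mean())     # uv is all ones for a fully voiced contour
```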
modules/parallel_wavegan/optimizers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from torch.optim import * # NOQA
2
+ from .radam import * # NOQA
modules/parallel_wavegan/optimizers/radam.py ADDED
@@ -0,0 +1,91 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """RAdam optimizer.
4
+
5
+ This code is derived from https://github.com/LiyuanLucasLiu/RAdam.
6
+ """
7
+
8
+ import math
9
+ import torch
10
+
11
+ from torch.optim.optimizer import Optimizer
12
+
13
+
14
+ class RAdam(Optimizer):
15
+ """Rectified Adam optimizer."""
16
+
17
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
18
+ """Initilize RAdam optimizer."""
19
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
20
+ self.buffer = [[None, None, None] for ind in range(10)]
21
+ super(RAdam, self).__init__(params, defaults)
22
+
23
+ def __setstate__(self, state):
24
+ """Set state."""
25
+ super(RAdam, self).__setstate__(state)
26
+
27
+ def step(self, closure=None):
28
+ """Run one step."""
29
+ loss = None
30
+ if closure is not None:
31
+ loss = closure()
32
+
33
+ for group in self.param_groups:
34
+
35
+ for p in group['params']:
36
+ if p.grad is None:
37
+ continue
38
+ grad = p.grad.data.float()
39
+ if grad.is_sparse:
40
+ raise RuntimeError('RAdam does not support sparse gradients')
41
+
42
+ p_data_fp32 = p.data.float()
43
+
44
+ state = self.state[p]
45
+
46
+ if len(state) == 0:
47
+ state['step'] = 0
48
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
49
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
50
+ else:
51
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
52
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
53
+
54
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
55
+ beta1, beta2 = group['betas']
56
+
57
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
58
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
59
+
60
+ state['step'] += 1
61
+ buffered = self.buffer[int(state['step'] % 10)]
62
+ if state['step'] == buffered[0]:
63
+ N_sma, step_size = buffered[1], buffered[2]
64
+ else:
65
+ buffered[0] = state['step']
66
+ beta2_t = beta2 ** state['step']
67
+ N_sma_max = 2 / (1 - beta2) - 1
68
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
69
+ buffered[1] = N_sma
70
+
71
+ # more conservative since it's an approximated value
72
+ if N_sma >= 5:
73
+ step_size = math.sqrt(
74
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) # NOQA
75
+ else:
76
+ step_size = 1.0 / (1 - beta1 ** state['step'])
77
+ buffered[2] = step_size
78
+
79
+ if group['weight_decay'] != 0:
80
+ p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
81
+
82
+ # more conservative since it's an approximated value
83
+ if N_sma >= 5:
84
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
85
+ p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
86
+ else:
87
+ p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
88
+
89
+ p.data.copy_(p_data_fp32)
90
+
91
+ return loss
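For orientation (not part of the commit), RAdam drops into the standard `torch.optim` training loop like any other optimizer; the toy regression below is an illustrative assumption:

```python
# Hedged sketch: RAdam behaves like any torch.optim optimizer.
import torch

model = torch.nn.Linear(10, 1)
opt = RAdam(model.parameters(), lr=1e-3)
x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(100):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
```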
modules/parallel_wavegan/stft_loss.py ADDED
@@ -0,0 +1,100 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """STFT-based Loss modules."""
7
+ import librosa
8
+ import torch
9
+
10
+ from modules.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft
11
+
12
+
13
+ class STFTLoss(torch.nn.Module):
14
+ """STFT loss module."""
15
+
16
+ def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
17
+ use_mel_loss=False):
18
+ """Initialize STFT loss module."""
19
+ super(STFTLoss, self).__init__()
20
+ self.fft_size = fft_size
21
+ self.shift_size = shift_size
22
+ self.win_length = win_length
23
+ self.window = getattr(torch, window)(win_length)
24
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
25
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
26
+ self.use_mel_loss = use_mel_loss
27
+ self.mel_basis = None
28
+
29
+ def forward(self, x, y):
30
+ """Calculate forward propagation.
31
+
32
+ Args:
33
+ x (Tensor): Predicted signal (B, T).
34
+ y (Tensor): Groundtruth signal (B, T).
35
+
36
+ Returns:
37
+ Tensor: Spectral convergence loss value.
38
+ Tensor: Log STFT magnitude loss value.
39
+
40
+ """
41
+ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
42
+ y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
43
+ if self.use_mel_loss:
44
+ if self.mel_basis is None:
45
+ self.mel_basis = torch.from_numpy(librosa.filters.mel(sr=22050, n_fft=self.fft_size, n_mels=80)).cuda().T
46
+ x_mag = x_mag @ self.mel_basis
47
+ y_mag = y_mag @ self.mel_basis
48
+
49
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
50
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
51
+
52
+ return sc_loss, mag_loss
53
+
54
+
55
+ class MultiResolutionSTFTLoss(torch.nn.Module):
56
+ """Multi resolution STFT loss module."""
57
+
58
+ def __init__(self,
59
+ fft_sizes=[1024, 2048, 512],
60
+ hop_sizes=[120, 240, 50],
61
+ win_lengths=[600, 1200, 240],
62
+ window="hann_window",
63
+ use_mel_loss=False):
64
+ """Initialize Multi resolution STFT loss module.
65
+
66
+ Args:
67
+ fft_sizes (list): List of FFT sizes.
68
+ hop_sizes (list): List of hop sizes.
69
+ win_lengths (list): List of window lengths.
70
+ window (str): Window function type.
71
+
72
+ """
73
+ super(MultiResolutionSTFTLoss, self).__init__()
74
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
75
+ self.stft_losses = torch.nn.ModuleList()
76
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
77
+ self.stft_losses += [STFTLoss(fs, ss, wl, window, use_mel_loss)]
78
+
79
+ def forward(self, x, y):
80
+ """Calculate forward propagation.
81
+
82
+ Args:
83
+ x (Tensor): Predicted signal (B, T).
84
+ y (Tensor): Groundtruth signal (B, T).
85
+
86
+ Returns:
87
+ Tensor: Multi resolution spectral convergence loss value.
88
+ Tensor: Multi resolution log STFT magnitude loss value.
89
+
90
+ """
91
+ sc_loss = 0.0
92
+ mag_loss = 0.0
93
+ for f in self.stft_losses:
94
+ sc_l, mag_l = f(x, y)
95
+ sc_loss += sc_l
96
+ mag_loss += mag_l
97
+ sc_loss /= len(self.stft_losses)
98
+ mag_loss /= len(self.stft_losses)
99
+
100
+ return sc_loss, mag_loss
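A minimal sketch of how this criterion is used as an auxiliary waveform loss (not part of the commit; the batch size and signal length are illustrative assumptions):

```python
# Hedged sketch: multi-resolution STFT loss over predicted vs. reference audio.
import torch

criterion = MultiResolutionSTFTLoss()              # default 1024/2048/512 resolutions
y_hat = torch.randn(4, 16000, requires_grad=True)  # predicted waveform (B, T)
y = torch.randn(4, 16000)                          # ground-truth waveform (B, T)
sc_loss, mag_loss = criterion(y_hat, y)
(sc_loss + mag_loss).backward()
```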
modules/parallel_wavegan/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .utils import * # NOQA
modules/parallel_wavegan/utils/utils.py ADDED
@@ -0,0 +1,169 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Utility functions."""
7
+
8
+ import fnmatch
9
+ import logging
10
+ import os
11
+ import sys
12
+
13
+ import h5py
14
+ import numpy as np
15
+
16
+
17
+ def find_files(root_dir, query="*.wav", include_root_dir=True):
18
+ """Find files recursively.
19
+
20
+ Args:
21
+ root_dir (str): Root root_dir to find.
22
+ query (str): Query to find.
23
+ include_root_dir (bool): If False, root_dir name is not included.
24
+
25
+ Returns:
26
+ list: List of found filenames.
27
+
28
+ """
29
+ files = []
30
+ for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
31
+ for filename in fnmatch.filter(filenames, query):
32
+ files.append(os.path.join(root, filename))
33
+ if not include_root_dir:
34
+ files = [file_.replace(root_dir + "/", "") for file_ in files]
35
+
36
+ return files
37
+
38
+
39
+ def read_hdf5(hdf5_name, hdf5_path):
40
+ """Read hdf5 dataset.
41
+
42
+ Args:
43
+ hdf5_name (str): Filename of hdf5 file.
44
+ hdf5_path (str): Dataset name in hdf5 file.
45
+
46
+ Return:
47
+ any: Dataset values.
48
+
49
+ """
50
+ if not os.path.exists(hdf5_name):
51
+ logging.error(f"There is no such a hdf5 file ({hdf5_name}).")
52
+ sys.exit(1)
53
+
54
+ hdf5_file = h5py.File(hdf5_name, "r")
55
+
56
+ if hdf5_path not in hdf5_file:
57
+ logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})")
58
+ sys.exit(1)
59
+
60
+ hdf5_data = hdf5_file[hdf5_path][()]
61
+ hdf5_file.close()
62
+
63
+ return hdf5_data
64
+
65
+
66
+ def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
67
+ """Write dataset to hdf5.
68
+
69
+ Args:
70
+ hdf5_name (str): Hdf5 dataset filename.
71
+ hdf5_path (str): Dataset path in hdf5.
72
+ write_data (ndarray): Data to write.
73
+ is_overwrite (bool): Whether to overwrite dataset.
74
+
75
+ """
76
+ # convert to numpy array
77
+ write_data = np.array(write_data)
78
+
79
+ # check folder existence
80
+ folder_name, _ = os.path.split(hdf5_name)
81
+ if not os.path.exists(folder_name) and len(folder_name) != 0:
82
+ os.makedirs(folder_name)
83
+
84
+ # check hdf5 existence
85
+ if os.path.exists(hdf5_name):
86
+ # if already exists, open with r+ mode
87
+ hdf5_file = h5py.File(hdf5_name, "r+")
88
+ # check dataset existence
89
+ if hdf5_path in hdf5_file:
90
+ if is_overwrite:
91
+ logging.warning("Dataset in hdf5 file already exists. "
92
+ "recreate dataset in hdf5.")
93
+ hdf5_file.__delitem__(hdf5_path)
94
+ else:
95
+ logging.error("Dataset in hdf5 file already exists. "
96
+ "if you want to overwrite, please set is_overwrite = True.")
97
+ hdf5_file.close()
98
+ sys.exit(1)
99
+ else:
100
+ # if not exists, open with w mode
101
+ hdf5_file = h5py.File(hdf5_name, "w")
102
+
103
+ # write data to hdf5
104
+ hdf5_file.create_dataset(hdf5_path, data=write_data)
105
+ hdf5_file.flush()
106
+ hdf5_file.close()
107
+
108
+
109
+ class HDF5ScpLoader(object):
110
+ """Loader class for a fests.scp file of hdf5 file.
111
+
112
+ Examples:
113
+ key1 /some/path/a.h5:feats
114
+ key2 /some/path/b.h5:feats
115
+ key3 /some/path/c.h5:feats
116
+ key4 /some/path/d.h5:feats
117
+ ...
118
+ >>> loader = HDF5ScpLoader("hdf5.scp")
119
+ >>> array = loader["key1"]
120
+
121
+ key1 /some/path/a.h5
122
+ key2 /some/path/b.h5
123
+ key3 /some/path/c.h5
124
+ key4 /some/path/d.h5
125
+ ...
126
+ >>> loader = HDF5ScpLoader("hdf5.scp", "feats")
127
+ >>> array = loader["key1"]
128
+
129
+ """
130
+
131
+ def __init__(self, feats_scp, default_hdf5_path="feats"):
132
+ """Initialize HDF5 scp loader.
133
+
134
+ Args:
135
+ feats_scp (str): Kaldi-style feats.scp file with hdf5 format.
136
+ default_hdf5_path (str): Path in hdf5 file. Not used if the scp already contains the path.
137
+
138
+ """
139
+ self.default_hdf5_path = default_hdf5_path
140
+ with open(feats_scp, encoding='utf-8') as f:
141
+ lines = [line.replace("\n", "") for line in f.readlines()]
142
+ self.data = {}
143
+ for line in lines:
144
+ key, value = line.split()
145
+ self.data[key] = value
146
+
147
+ def get_path(self, key):
148
+ """Get hdf5 file path for a given key."""
149
+ return self.data[key]
150
+
151
+ def __getitem__(self, key):
152
+ """Get ndarray for a given key."""
153
+ p = self.data[key]
154
+ if ":" in p:
155
+ return read_hdf5(*p.split(":"))
156
+ else:
157
+ return read_hdf5(p, self.default_hdf5_path)
158
+
159
+ def __len__(self):
160
+ """Return the length of the scp file."""
161
+ return len(self.data)
162
+
163
+ def __iter__(self):
164
+ """Return the iterator of the scp file."""
165
+ return iter(self.data)
166
+
167
+ def keys(self):
168
+ """Return the keys of the scp file."""
169
+ return self.data.keys()
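A minimal round-trip sketch for the hdf5 helpers above (not part of the commit; the file name and dataset path are illustrative assumptions):

```python
# Hedged sketch: write a feature array with write_hdf5, read it back with read_hdf5.
import numpy as np

feats = np.random.randn(100, 80).astype(np.float32)  # e.g. 100 frames of 80-bin mels
write_hdf5("dump/utt1.h5", "feats", feats)
assert np.allclose(read_hdf5("dump/utt1.h5", "feats"), feats)
```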
network/diff/candidate_decoder.py ADDED
@@ -0,0 +1,98 @@
1
+ from modules.fastspeech.tts_modules import FastspeechDecoder
2
+ # from modules.fastspeech.fast_tacotron import DecoderRNN
3
+ # from modules.fastspeech.speedy_speech.speedy_speech import ConvBlocks
4
+ # from modules.fastspeech.conformer.conformer import ConformerDecoder
5
+ import torch
6
+ from torch.nn import functional as F
7
+ import torch.nn as nn
8
+ import math
9
+ from utils.hparams import hparams
10
+ from modules.commons.common_layers import Mish
11
+ Linear = nn.Linear
12
+
13
+ class SinusoidalPosEmb(nn.Module):
14
+ def __init__(self, dim):
15
+ super().__init__()
16
+ self.dim = dim
17
+
18
+ def forward(self, x):
19
+ device = x.device
20
+ half_dim = self.dim // 2
21
+ emb = math.log(10000) / (half_dim - 1)
22
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
23
+ emb = x[:, None] * emb[None, :]
24
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
25
+ return emb
26
+
27
+
28
+ def Conv1d(*args, **kwargs):
29
+ layer = nn.Conv1d(*args, **kwargs)
30
+ nn.init.kaiming_normal_(layer.weight)
31
+ return layer
32
+
33
+
34
+ class FFT(FastspeechDecoder): # unused, because DiffSinger only uses FastspeechEncoder
35
+ # NOTE: this part of script is *isolated* from other scripts, which means
36
+ # it may not be compatible with the current version.
37
+
38
+ def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
39
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
40
+ dim = hparams['residual_channels']
41
+ self.input_projection = Conv1d(hparams['audio_num_mel_bins'], dim, 1)
42
+ self.diffusion_embedding = SinusoidalPosEmb(dim)
43
+ self.mlp = nn.Sequential(
44
+ nn.Linear(dim, dim * 4),
45
+ Mish(),
46
+ nn.Linear(dim * 4, dim)
47
+ )
48
+ self.get_mel_out = Linear(hparams['hidden_size'], 80, bias=True)
49
+ self.get_decode_inp = Linear(hparams['hidden_size'] + dim + dim,
50
+ hparams['hidden_size']) # hs + dim + 80 -> hs
51
+
52
+ def forward(self, spec, diffusion_step, cond, padding_mask=None, attn_mask=None, return_hiddens=False):
53
+ """
54
+ :param spec: [B, 1, 80, T]
55
+ :param diffusion_step: [B, 1]
56
+ :param cond: [B, M, T]
57
+ :return:
58
+ """
59
+ x = spec[:, 0]
60
+ x = self.input_projection(x).permute([0, 2, 1]) # [B, T, residual_channel]
61
+ diffusion_step = self.diffusion_embedding(diffusion_step)
62
+ diffusion_step = self.mlp(diffusion_step) # [B, dim]
63
+ cond = cond.permute([0, 2, 1]) # [B, T, M]
64
+
65
+ seq_len = cond.shape[1] # [T_mel]
66
+ time_embed = diffusion_step[:, None, :] # [B, 1, dim]
67
+ time_embed = time_embed.repeat([1, seq_len, 1]) # # [B, T, dim]
68
+
69
+ decoder_inp = torch.cat([x, cond, time_embed], dim=-1) # [B, T, dim + H + dim]
70
+ decoder_inp = self.get_decode_inp(decoder_inp) # [B, T, H]
71
+ x = decoder_inp
72
+
73
+ '''
74
+ Required x: [B, T, C]
75
+ :return: [B, T, C] or [L, B, T, C]
76
+ '''
77
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
78
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
79
+ if self.use_pos_embed:
80
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
81
+ x = x + positions
82
+ x = F.dropout(x, p=self.dropout, training=self.training)
83
+ # B x T x C -> T x B x C
84
+ x = x.transpose(0, 1) * nonpadding_mask_TB
85
+ hiddens = []
86
+ for layer in self.layers:
87
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
88
+ hiddens.append(x)
89
+ if self.use_last_norm:
90
+ x = self.layer_norm(x) * nonpadding_mask_TB
91
+ if return_hiddens:
92
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
93
+ x = x.transpose(1, 2) # [L, B, T, C]
94
+ else:
95
+ x = x.transpose(0, 1) # [B, T, C]
96
+
97
+ x = self.get_mel_out(x).permute([0, 2, 1]) # [B, 80, T]
98
+ return x[:, None, :, :]
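A quick sketch of the `SinusoidalPosEmb` module defined above (not part of the commit; the embedding dimension of 256 is an illustrative choice):

```python
# Hedged sketch: SinusoidalPosEmb maps diffusion-step indices to dense vectors.
import torch

emb = SinusoidalPosEmb(dim=256)
t = torch.tensor([0.0, 10.0, 999.0])  # diffusion step indices, shape [B]
print(emb(t).shape)                   # torch.Size([3, 256]): sin half then cos half
```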
network/diff/diffusion.py ADDED
@@ -0,0 +1,332 @@
1
+ from collections import deque
2
+ from functools import partial
3
+ from inspect import isfunction
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from tqdm import tqdm
9
+
10
+ from modules.fastspeech.fs2 import FastSpeech2
11
+ # from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
12
+ from utils.hparams import hparams
13
+ from training.train_pipeline import Batch2Loss
14
+
15
+
16
+ def exists(x):
17
+ return x is not None
18
+
19
+
20
+ def default(val, d):
21
+ if exists(val):
22
+ return val
23
+ return d() if isfunction(d) else d
24
+
25
+
26
+ # gaussian diffusion trainer class
27
+
28
+ def extract(a, t, x_shape):
29
+ b, *_ = t.shape
30
+ out = a.gather(-1, t)
31
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
32
+
33
+
34
+ def noise_like(shape, device, repeat=False):
35
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
36
+ noise = lambda: torch.randn(shape, device=device)
37
+ return repeat_noise() if repeat else noise()
38
+
39
+
40
+ def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
41
+ """
42
+ linear schedule
43
+ """
44
+ betas = np.linspace(1e-4, max_beta, timesteps)
45
+ return betas
46
+
47
+
48
+ def cosine_beta_schedule(timesteps, s=0.008):
49
+ """
50
+ cosine schedule
51
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
52
+ """
53
+ steps = timesteps + 1
54
+ x = np.linspace(0, steps, steps)
55
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
56
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
57
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
58
+ return np.clip(betas, a_min=0, a_max=0.999)
59
+
60
+
61
+ beta_schedule = {
62
+ "cosine": cosine_beta_schedule,
63
+ "linear": linear_beta_schedule,
64
+ }
65
+
66
+
67
+ class GaussianDiffusion(nn.Module):
68
+ def __init__(self, phone_encoder, out_dims, denoise_fn,
69
+ timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None,
70
+ spec_max=None):
71
+ super().__init__()
72
+ self.denoise_fn = denoise_fn
73
+ # if hparams.get('use_midi') is not None and hparams['use_midi']:
74
+ # self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
75
+ # else:
76
+ self.fs2 = FastSpeech2(phone_encoder, out_dims)
77
+ self.mel_bins = out_dims
78
+
79
+ if exists(betas):
80
+ betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
81
+ else:
82
+ if 'schedule_type' in hparams.keys():
83
+ betas = beta_schedule[hparams['schedule_type']](timesteps)
84
+ else:
85
+ betas = cosine_beta_schedule(timesteps)
86
+
87
+ alphas = 1. - betas
88
+ alphas_cumprod = np.cumprod(alphas, axis=0)
89
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
90
+
91
+ timesteps, = betas.shape
92
+ self.num_timesteps = int(timesteps)
93
+ self.K_step = K_step
94
+ self.loss_type = loss_type
95
+
96
+ self.noise_list = deque(maxlen=4)
97
+
98
+ to_torch = partial(torch.tensor, dtype=torch.float32)
99
+
100
+ self.register_buffer('betas', to_torch(betas))
101
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
102
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
103
+
104
+ # calculations for diffusion q(x_t | x_{t-1}) and others
105
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
106
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
107
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
108
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
109
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
110
+
111
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
112
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
113
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
114
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
115
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
116
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
117
+ self.register_buffer('posterior_mean_coef1', to_torch(
118
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
119
+ self.register_buffer('posterior_mean_coef2', to_torch(
120
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
121
+
122
+ self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
123
+ self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
124
+
125
+ def q_mean_variance(self, x_start, t):
126
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
127
+ variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
128
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
129
+ return mean, variance, log_variance
130
+
131
+ def predict_start_from_noise(self, x_t, t, noise):
132
+ return (
133
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
134
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
135
+ )
136
+
137
+ def q_posterior(self, x_start, x_t, t):
138
+ posterior_mean = (
139
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
140
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
141
+ )
142
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
143
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
144
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
145
+
146
+ def p_mean_variance(self, x, t, cond, clip_denoised: bool):
147
+ noise_pred = self.denoise_fn(x, t, cond=cond)
148
+ x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
149
+
150
+ if clip_denoised:
151
+ x_recon.clamp_(-1., 1.)
152
+
153
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
154
+ return model_mean, posterior_variance, posterior_log_variance
155
+
156
+ @torch.no_grad()
157
+ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
158
+ b, *_, device = *x.shape, x.device
159
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
160
+ noise = noise_like(x.shape, device, repeat_noise)
161
+ # no noise when t == 0
162
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
163
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
164
+
165
+ @torch.no_grad()
166
+ def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
167
+ """
168
+ Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
169
+ """
170
+
171
+ def get_x_pred(x, noise_t, t):
172
+ a_t = extract(self.alphas_cumprod, t, x.shape)
173
+ a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
174
+ a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
175
+
176
+ x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
177
+ x_pred = x + x_delta
178
+
179
+ return x_pred
180
+
181
+ noise_list = self.noise_list
182
+ noise_pred = self.denoise_fn(x, t, cond=cond)
183
+
184
+ if len(noise_list) == 0:
185
+ x_pred = get_x_pred(x, noise_pred, t)
186
+ noise_pred_prev = self.denoise_fn(x_pred, max(t-interval, 0), cond=cond)
187
+ noise_pred_prime = (noise_pred + noise_pred_prev) / 2
188
+ elif len(noise_list) == 1:
189
+ noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
190
+ elif len(noise_list) == 2:
191
+ noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
192
+ elif len(noise_list) >= 3:
193
+ noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
194
+
195
+ x_prev = get_x_pred(x, noise_pred_prime, t)
196
+ noise_list.append(noise_pred)
197
+
198
+ return x_prev
199
+
200
+ def q_sample(self, x_start, t, noise=None):
201
+ noise = default(noise, lambda: torch.randn_like(x_start))
202
+ return (
203
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
204
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
205
+ )
206
+
207
+ def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
208
+ noise = default(noise, lambda: torch.randn_like(x_start))
209
+
210
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
211
+ x_recon = self.denoise_fn(x_noisy, t, cond)
212
+
213
+ if self.loss_type == 'l1':
214
+ if nonpadding is not None:
215
+ loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
216
+ else:
217
+ # print('are you sure w/o nonpadding?')
218
+ loss = (noise - x_recon).abs().mean()
219
+
220
+ elif self.loss_type == 'l2':
221
+ loss = F.mse_loss(noise, x_recon)
222
+ else:
223
+ raise NotImplementedError()
224
+
225
+ return loss
226
+
227
+ def forward(self, hubert, mel2ph=None, spk_embed=None,
228
+ ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
229
+ '''
230
+ conditioning diffusion, use fastspeech2 encoder output as the condition
231
+ '''
232
+ ret = self.fs2(hubert, mel2ph, spk_embed, None, f0, uv, energy,
233
+ skip_decoder=True, infer=infer, **kwargs)
234
+ cond = ret['decoder_inp'].transpose(1, 2)
235
+ b, *_, device = *hubert.shape, hubert.device
236
+
237
+ if not infer:
238
+ Batch2Loss.module4(
239
+ self.p_losses,
240
+ self.norm_spec(ref_mels), cond, ret, self.K_step, b, device
241
+ )
242
+ else:
243
+ '''
244
+ ret['fs2_mel'] = ret['mel_out']
245
+ fs2_mels = ret['mel_out']
246
+ t = self.K_step
247
+ fs2_mels = self.norm_spec(fs2_mels)
248
+ fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
249
+ x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
250
+ if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
251
+ print('===> gaussian start.')
252
+ shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
253
+ x = torch.randn(shape, device=device)
254
+ '''
255
+ if 'use_gt_mel' in kwargs.keys() and kwargs['use_gt_mel']:
256
+ t = kwargs['add_noise_step']
257
+ print('===> using ground-truth mel as start; please make sure the parameter "key==0"!')
258
+ fs2_mels = ref_mels
259
+ fs2_mels = self.norm_spec(fs2_mels)
260
+ fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
261
+ x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
262
+ # for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
263
+ # x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
264
+ else:
265
+ t = self.K_step
266
+ # print('===> gaussian start.')
267
+ shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
268
+ x = torch.randn(shape, device=device)
269
+ if hparams.get('pndm_speedup') and hparams['pndm_speedup'] > 1:
270
+ self.noise_list = deque(maxlen=4)
271
+ iteration_interval = hparams['pndm_speedup']
272
+ for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
273
+ total=t // iteration_interval):
274
+ x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval,
275
+ cond)
276
+ else:
277
+ for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
278
+ x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
279
+ x = x[:, 0].transpose(1, 2)
280
+ if mel2ph is not None: # for singing
281
+ ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
282
+ else:
283
+ ret['mel_out'] = self.denorm_spec(x)
284
+ return ret
285
+
286
+ def norm_spec(self, x):
287
+ return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
288
+
289
+ def denorm_spec(self, x):
290
+ return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
291
+
292
+ def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
293
+ return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
294
+
295
+ def out2mel(self, x):
296
+ return x
297
+
298
+
299
+ class OfflineGaussianDiffusion(GaussianDiffusion):
300
+ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
301
+ ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
302
+ b, *_, device = *txt_tokens.shape, txt_tokens.device
303
+
304
+ ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
305
+ skip_decoder=True, infer=True, **kwargs)
306
+ cond = ret['decoder_inp'].transpose(1, 2)
307
+ fs2_mels = ref_mels[1]
308
+ ref_mels = ref_mels[0]
309
+
310
+ if not infer:
311
+ t = torch.randint(0, self.K_step, (b,), device=device).long()
312
+ x = ref_mels
313
+ x = self.norm_spec(x)
314
+ x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
315
+ ret['diff_loss'] = self.p_losses(x, t, cond)
316
+ else:
317
+ t = self.K_step
318
+ fs2_mels = self.norm_spec(fs2_mels)
319
+ fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
320
+
321
+ x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
322
+
323
+ if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
324
+ print('===> gaussian start.')
325
+ shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
326
+ x = torch.randn(shape, device=device)
327
+ for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
328
+ x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
329
+ x = x[:, 0].transpose(1, 2)
330
+ ret['mel_out'] = self.denorm_spec(x)
331
+
332
+ return ret
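The beta schedules defined near the top of this file determine how quickly the forward process destroys the signal. A minimal sketch (not part of the commit) inspecting the cosine schedule and the closed-form noising q(x_t | x_0) = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps it induces:

```python
# Hedged sketch: pure numpy, no hparams needed for the cosine schedule itself.
import numpy as np

betas = cosine_beta_schedule(timesteps=1000)
alphas_cumprod = np.cumprod(1.0 - betas)
# The signal coefficient sqrt(alphas_cumprod[t]) is near 1 at t=0 and decays
# toward 0 by the final step, where x_t is almost pure Gaussian noise.
print(betas.shape, alphas_cumprod[0], alphas_cumprod[-1])
```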
network/diff/net.py ADDED
@@ -0,0 +1,135 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from math import sqrt
8
+
9
+ from utils.hparams import hparams
10
+ from modules.commons.common_layers import Mish
11
+
12
+ Linear = nn.Linear
13
+ ConvTranspose2d = nn.ConvTranspose2d
14
+
15
+
16
+ class AttrDict(dict):
17
+ def __init__(self, *args, **kwargs):
18
+ super(AttrDict, self).__init__(*args, **kwargs)
19
+ self.__dict__ = self
20
+
21
+ def override(self, attrs):
22
+ if isinstance(attrs, dict):
23
+ self.__dict__.update(**attrs)
24
+ elif isinstance(attrs, (list, tuple, set)):
25
+ for attr in attrs:
26
+ self.override(attr)
27
+ elif attrs is not None:
28
+ raise NotImplementedError
29
+ return self
30
+
31
+
32
+ class SinusoidalPosEmb(nn.Module):
33
+ def __init__(self, dim):
34
+ super().__init__()
35
+ self.dim = dim
36
+
37
+ def forward(self, x):
38
+ device = x.device
39
+ half_dim = self.dim // 2
40
+ emb = math.log(10000) / (half_dim - 1)
41
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
42
+ emb = x[:, None] * emb[None, :]
43
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
44
+ return emb
45
+
46
+
47
+ def Conv1d(*args, **kwargs):
48
+ layer = nn.Conv1d(*args, **kwargs)
49
+ nn.init.kaiming_normal_(layer.weight)
50
+ return layer
51
+
52
+
53
+ @torch.jit.script
54
+ def silu(x):
55
+ return x * torch.sigmoid(x)
56
+
57
+
58
+ class ResidualBlock(nn.Module):
59
+ def __init__(self, encoder_hidden, residual_channels, dilation):
60
+ super().__init__()
61
+ self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
62
+ self.diffusion_projection = Linear(residual_channels, residual_channels)
63
+ self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
64
+ self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
65
+
66
+ def forward(self, x, conditioner, diffusion_step):
67
+ diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
68
+ conditioner = self.conditioner_projection(conditioner)
69
+ y = x + diffusion_step
70
+
71
+ y = self.dilated_conv(y) + conditioner
72
+
73
+ gate, filter = torch.chunk(y, 2, dim=1)
74
+ # Using torch.split instead of torch.chunk to avoid using onnx::Slice
75
+ # gate, filter = torch.split(y, torch.div(y.shape[1], 2), dim=1)
76
+
77
+ y = torch.sigmoid(gate) * torch.tanh(filter)
78
+
79
+ y = self.output_projection(y)
80
+ residual, skip = torch.chunk(y, 2, dim=1)
81
+ # Using torch.split instead of torch.chunk to avoid using onnx::Slice
82
+ # residual, skip = torch.split(y, torch.div(y.shape[1], 2), dim=1)
83
+
84
+ return (x + residual) / sqrt(2.0), skip
85
+
86
+ class DiffNet(nn.Module):
87
+ def __init__(self, in_dims=80):
88
+ super().__init__()
89
+ self.params = params = AttrDict(
90
+ # Model params
91
+ encoder_hidden=hparams['hidden_size'],
92
+ residual_layers=hparams['residual_layers'],
93
+ residual_channels=hparams['residual_channels'],
94
+ dilation_cycle_length=hparams['dilation_cycle_length'],
95
+ )
96
+ self.input_projection = Conv1d(in_dims, params.residual_channels, 1)
97
+ self.diffusion_embedding = SinusoidalPosEmb(params.residual_channels)
98
+ dim = params.residual_channels
99
+ self.mlp = nn.Sequential(
100
+ nn.Linear(dim, dim * 4),
101
+ Mish(),
102
+ nn.Linear(dim * 4, dim)
103
+ )
104
+ self.residual_layers = nn.ModuleList([
105
+ ResidualBlock(params.encoder_hidden, params.residual_channels, 2 ** (i % params.dilation_cycle_length))
106
+ for i in range(params.residual_layers)
107
+ ])
108
+ self.skip_projection = Conv1d(params.residual_channels, params.residual_channels, 1)
109
+ self.output_projection = Conv1d(params.residual_channels, in_dims, 1)
110
+ nn.init.zeros_(self.output_projection.weight)
111
+
112
+ def forward(self, spec, diffusion_step, cond):
113
+ """
114
+
115
+ :param spec: [B, 1, M, T]
116
+ :param diffusion_step: [B, 1]
117
+ :param cond: [B, M, T]
118
+ :return:
119
+ """
120
+ x = spec[:, 0]
121
+ x = self.input_projection(x) # x [B, residual_channel, T]
122
+
123
+ x = F.relu(x)
124
+ diffusion_step = self.diffusion_embedding(diffusion_step)
125
+ diffusion_step = self.mlp(diffusion_step)
126
+ skip = []
127
+ for layer_id, layer in enumerate(self.residual_layers):
128
+ x, skip_connection = layer(x, cond, diffusion_step)
129
+ skip.append(skip_connection)
130
+
131
+ x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
132
+ x = self.skip_projection(x)
133
+ x = F.relu(x)
134
+ x = self.output_projection(x) # [B, 80, T]
135
+ return x[:, None, :, :]
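The core nonlinearity inside `ResidualBlock` above is the WaveNet-style gated activation. A standalone sketch of that one step (not part of the commit; channel and length values are illustrative):

```python
# Hedged sketch: chunk the doubled channels, then sigmoid(gate) * tanh(filter).
import torch

y = torch.randn(2, 2 * 256, 100)        # [B, 2*residual_channels, T]
gate, filt = torch.chunk(y, 2, dim=1)   # two [B, 256, T] halves
out = torch.sigmoid(gate) * torch.tanh(filt)
print(out.shape)                        # torch.Size([2, 256, 100])
```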
network/hubert/hubert_model.py ADDED
@@ -0,0 +1,276 @@
1
+ import copy
2
+ import os
3
+ import random
4
+ from typing import Optional, Tuple
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as t_func
11
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
12
+
13
+ from utils.hparams import hparams
14
+
15
+
16
+ class Hubert(nn.Module):
17
+ def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
18
+ super().__init__()
19
+ self._mask = mask
20
+ self.feature_extractor = FeatureExtractor()
21
+ self.feature_projection = FeatureProjection()
22
+ self.positional_embedding = PositionalConvEmbedding()
23
+ self.norm = nn.LayerNorm(768)
24
+ self.dropout = nn.Dropout(0.1)
25
+ self.encoder = TransformerEncoder(
26
+ nn.TransformerEncoderLayer(
27
+ 768, 12, 3072, activation="gelu", batch_first=True
28
+ ),
29
+ 12,
30
+ )
31
+ self.proj = nn.Linear(768, 256)
32
+
33
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
34
+ self.label_embedding = nn.Embedding(num_label_embeddings, 256)
35
+
36
+ def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
37
+ mask = None
38
+ if self.training and self._mask:
39
+ mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
40
+ x[mask] = self.masked_spec_embed.to(x.dtype)
41
+ return x, mask
42
+
43
+ def encode(
44
+ self, x: torch.Tensor, layer: Optional[int] = None
45
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
46
+ x = self.feature_extractor(x)
47
+ x = self.feature_projection(x.transpose(1, 2))
48
+ x, mask = self.mask(x)
49
+ x = x + self.positional_embedding(x)
50
+ x = self.dropout(self.norm(x))
51
+ x = self.encoder(x, output_layer=layer)
52
+ return x, mask
53
+
54
+ def logits(self, x: torch.Tensor) -> torch.Tensor:
55
+ logits = torch.cosine_similarity(
56
+ x.unsqueeze(2),
57
+ self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
58
+ dim=-1,
59
+ )
60
+ return logits / 0.1
61
+
62
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
63
+ x, mask = self.encode(x)
64
+ x = self.proj(x)
65
+ logits = self.logits(x)
66
+ return logits, mask
67
+
68
+
69
+ class HubertSoft(Hubert):
70
+ def __init__(self):
71
+ super().__init__()
72
+
73
+ # @torch.inference_mode()
74
+ def units(self, wav: torch.Tensor) -> torch.Tensor:
75
+ wav = torch.nn.functional.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
76
+ x, _ = self.encode(wav)
77
+ return self.proj(x)
78
+
79
+ def forward(self, wav: torch.Tensor):
80
+ return self.units(wav)
81
+
82
+
83
+ class FeatureExtractor(nn.Module):
84
+ def __init__(self):
85
+ super().__init__()
86
+ self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
87
+ self.norm0 = nn.GroupNorm(512, 512)
88
+ self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
89
+ self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
90
+ self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
91
+ self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
92
+ self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
93
+ self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
94
+
95
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
96
+ x = t_func.gelu(self.norm0(self.conv0(x)))
97
+ x = t_func.gelu(self.conv1(x))
98
+ x = t_func.gelu(self.conv2(x))
99
+ x = t_func.gelu(self.conv3(x))
100
+ x = t_func.gelu(self.conv4(x))
101
+ x = t_func.gelu(self.conv5(x))
102
+ x = t_func.gelu(self.conv6(x))
103
+ return x
104
+
105
+
106
+ class FeatureProjection(nn.Module):
107
+ def __init__(self):
108
+ super().__init__()
109
+ self.norm = nn.LayerNorm(512)
110
+ self.projection = nn.Linear(512, 768)
111
+ self.dropout = nn.Dropout(0.1)
112
+
113
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
114
+ x = self.norm(x)
115
+ x = self.projection(x)
116
+ x = self.dropout(x)
117
+ return x
118
+
119
+
120
+ class PositionalConvEmbedding(nn.Module):
121
+ def __init__(self):
122
+ super().__init__()
123
+ self.conv = nn.Conv1d(
124
+ 768,
125
+ 768,
126
+ kernel_size=128,
127
+ padding=128 // 2,
128
+ groups=16,
129
+ )
130
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
131
+
132
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
133
+ x = self.conv(x.transpose(1, 2))
134
+ x = t_func.gelu(x[:, :, :-1])
135
+ return x.transpose(1, 2)
136
+
137
+
138
+ class TransformerEncoder(nn.Module):
139
+ def __init__(
140
+ self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
141
+ ) -> None:
142
+ super(TransformerEncoder, self).__init__()
143
+ self.layers = nn.ModuleList(
144
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
145
+ )
146
+ self.num_layers = num_layers
147
+
148
+ def forward(
149
+ self,
150
+ src: torch.Tensor,
151
+ mask: torch.Tensor = None,
152
+ src_key_padding_mask: torch.Tensor = None,
153
+ output_layer: Optional[int] = None,
154
+ ) -> torch.Tensor:
155
+ output = src
156
+ for layer in self.layers[:output_layer]:
157
+ output = layer(
158
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
159
+ )
160
+ return output
161
+
162
+
163
+ def _compute_mask(
164
+ shape: Tuple[int, int],
165
+ mask_prob: float,
166
+ mask_length: int,
167
+ device: torch.device,
168
+ min_masks: int = 0,
169
+ ) -> torch.Tensor:
170
+ batch_size, sequence_length = shape
171
+
172
+ if mask_length < 1:
173
+ raise ValueError("`mask_length` has to be bigger than 0.")
174
+
175
+ if mask_length > sequence_length:
176
+ raise ValueError(
177
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
178
+ )
179
+
180
+ # compute number of masked spans in batch
181
+ num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
182
+ num_masked_spans = max(num_masked_spans, min_masks)
183
+
184
+ # make sure num masked indices <= sequence_length
185
+ if num_masked_spans * mask_length > sequence_length:
186
+ num_masked_spans = sequence_length // mask_length
187
+
188
+ # SpecAugment mask to fill
189
+ mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
190
+
191
+ # uniform distribution to sample from, make sure that offset samples are < sequence_length
192
+ uniform_dist = torch.ones(
193
+ (batch_size, sequence_length - (mask_length - 1)), device=device
194
+ )
195
+
196
+ # get random indices to mask
197
+ mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
198
+
199
+ # expand masked indices to masked spans
200
+ mask_indices = (
201
+ mask_indices.unsqueeze(dim=-1)
202
+ .expand((batch_size, num_masked_spans, mask_length))
203
+ .reshape(batch_size, num_masked_spans * mask_length)
204
+ )
205
+ offsets = (
206
+ torch.arange(mask_length, device=device)[None, None, :]
207
+ .expand((batch_size, num_masked_spans, mask_length))
208
+ .reshape(batch_size, num_masked_spans * mask_length)
209
+ )
210
+ mask_idxs = mask_indices + offsets
211
+
212
+ # scatter indices to mask
213
+ mask = mask.scatter(1, mask_idxs, True)
214
+
215
+ return mask
216
+
217
+
218
+ def hubert_soft(
219
+ path: str
220
+ ) -> HubertSoft:
221
+ r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
222
+ Args:
223
+ path (str): path of a pretrained model
224
+ """
225
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
226
+ hubert = HubertSoft()
227
+ checkpoint = torch.load(path)
228
+ consume_prefix_in_state_dict_if_present(checkpoint, "module.")
229
+ hubert.load_state_dict(checkpoint)
230
+ hubert.eval().to(dev)
231
+ return hubert
232
+
233
+
234
+ def get_units(hbt_soft, raw_wav_path, dev=torch.device('cuda')):
235
+ wav, sr = librosa.load(raw_wav_path, sr=None)
236
+ assert (sr >= 16000)
237
+ if len(wav.shape) > 1:
238
+ wav = librosa.to_mono(wav)
239
+ if sr != 16000:
240
+ wav16 = librosa.resample(wav, sr, 16000)
241
+ else:
242
+ wav16 = wav
243
+ dev = torch.device("cuda" if (dev == torch.device('cuda') and torch.cuda.is_available()) else "cpu")
244
+ torch.cuda.is_available() and torch.cuda.empty_cache()
245
+ with torch.inference_mode():
246
+ units = hbt_soft.units(torch.FloatTensor(wav16.astype(float)).unsqueeze(0).unsqueeze(0).to(dev))
247
+ return units
248
+
249
+
250
+ def get_end_file(dir_path, end):
251
+ file_list = []
252
+ for root, dirs, files in os.walk(dir_path):
253
+ files = [f for f in files if f[0] != '.']
254
+ dirs[:] = [d for d in dirs if d[0] != '.']
255
+ for f_file in files:
256
+ if f_file.endswith(end):
257
+ file_list.append(os.path.join(root, f_file).replace("\\", "/"))
258
+ return file_list
259
+
260
+
261
+ if __name__ == '__main__':
262
+ from pathlib import Path
263
+
264
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
265
+ # hubert的模型路径
266
+ hbt_model = hubert_soft(str(list(Path(hparams['hubert_path']).home().rglob('*.pt'))[0]))
267
+ # 这个不用改,自动在根目录下所有wav的同文件夹生成其对应的npy
268
+ file_lists = list(Path(hparams['raw_data_dir']).rglob('*.wav'))
269
+ nums = len(file_lists)
270
+ count = 0
271
+ for wav_path in file_lists:
272
+ npy_path = wav_path.with_suffix(".npy")
273
+ npy_content = get_units(hbt_model, wav_path).cpu().numpy()[0]
274
+ np.save(str(npy_path), npy_content)
275
+ count += 1
276
+ print(f"hubert process:{round(count * 100 / nums, 2)}%")
network/hubert/vec_model.py ADDED
@@ -0,0 +1,60 @@
1
+ from pathlib import Path
2
+
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+
9
+ def load_model(vec_path):
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ print("load model(s) from {}".format(vec_path))
12
+ from fairseq import checkpoint_utils
13
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
14
+ [vec_path],
15
+ suffix="",
16
+ )
17
+ model = models[0]
18
+ model = model.to(device)
19
+ model.eval()
20
+ return model
21
+
22
+
23
+ def get_vec_units(con_model, audio_path, dev):
24
+ audio, sampling_rate = librosa.load(audio_path)
25
+ if len(audio.shape) > 1:
26
+ audio = librosa.to_mono(audio.transpose(1, 0))
27
+ if sampling_rate != 16000:
28
+ audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
29
+
30
+ feats = torch.from_numpy(audio).float()
31
+ if feats.dim() == 2: # double channels
32
+ feats = feats.mean(-1)
33
+ assert feats.dim() == 1, feats.dim()
34
+ feats = feats.view(1, -1)
35
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
36
+ inputs = {
37
+ "source": feats.to(dev),
38
+ "padding_mask": padding_mask.to(dev),
39
+ "output_layer": 9, # layer 9
40
+ }
41
+ with torch.no_grad():
42
+ logits = con_model.extract_features(**inputs)
43
+ feats = con_model.final_proj(logits[0])
44
+ return feats
45
+
46
+
47
+ if __name__ == '__main__':
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ model_path = "../../checkpoints/checkpoint_best_legacy_500.pt" # checkpoint_best_legacy_500.pt
50
+ vec_model = load_model(model_path)
51
+ # 这个不用改,自动在根目录下所有wav的同文件夹生成其对应的npy
52
+ file_lists = list(Path("../../data/vecfox").rglob('*.wav'))
53
+ nums = len(file_lists)
54
+ count = 0
55
+ for wav_path in file_lists:
56
+ npy_path = wav_path.with_suffix(".npy")
57
+ npy_content = get_vec_units(vec_model, str(wav_path), device).cpu().numpy()[0]
58
+ np.save(str(npy_path), npy_content)
59
+ count += 1
60
+ print(f"hubert process:{round(count * 100 / nums, 2)}%")
network/vocoders/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from network.vocoders import hifigan
2
+ from network.vocoders import nsf_hifigan
network/vocoders/base_vocoder.py ADDED
@@ -0,0 +1,39 @@
1
+ import importlib
2
+ VOCODERS = {}
3
+
4
+
5
+ def register_vocoder(cls):
6
+ VOCODERS[cls.__name__.lower()] = cls
7
+ VOCODERS[cls.__name__] = cls
8
+ return cls
9
+
10
+
11
+ def get_vocoder_cls(hparams):
12
+ if hparams['vocoder'] in VOCODERS:
13
+ return VOCODERS[hparams['vocoder']]
14
+ else:
15
+ vocoder_cls = hparams['vocoder']
16
+ pkg = ".".join(vocoder_cls.split(".")[:-1])
17
+ cls_name = vocoder_cls.split(".")[-1]
18
+ vocoder_cls = getattr(importlib.import_module(pkg), cls_name)
19
+ return vocoder_cls
20
+
21
+
22
+ class BaseVocoder:
23
+ def spec2wav(self, mel):
24
+ """
25
+
26
+ :param mel: [T, 80]
27
+ :return: wav: [T']
28
+ """
29
+
30
+ raise NotImplementedError
31
+
32
+ @staticmethod
33
+ def wav2spec(wav_fn):
34
+ """
35
+
36
+ :param wav_fn: str
37
+ :return: wav, mel: [T, 80]
38
+ """
39
+ raise NotImplementedError
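A minimal sketch of how the registry above is meant to be used (not part of the commit; `MyVocoder` is a hypothetical example class):

```python
# Hedged sketch: a decorated subclass becomes resolvable by class name
# (or its lowercase form) through get_vocoder_cls.
@register_vocoder
class MyVocoder(BaseVocoder):
    def spec2wav(self, mel):
        raise NotImplementedError

assert get_vocoder_cls({'vocoder': 'myvocoder'}) is MyVocoder
```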
network/vocoders/hifigan.py ADDED
@@ -0,0 +1,83 @@
1
+ import glob
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ import librosa
7
+ import torch
8
+
9
+ import utils
10
+ from modules.hifigan.hifigan import HifiGanGenerator
11
+ from utils.hparams import hparams, set_hparams
12
+ from network.vocoders.base_vocoder import register_vocoder
13
+ from network.vocoders.pwg import PWG
14
+ from network.vocoders.vocoder_utils import denoise
15
+
16
+
17
+ def load_model(config_path, file_path):
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ ext = os.path.splitext(file_path)[-1]
20
+ if ext == '.pth':
21
+ if '.yaml' in config_path:
22
+ config = set_hparams(config_path, global_hparams=False)
23
+ elif '.json' in config_path:
24
+ config = json.load(open(config_path, 'r', encoding='utf-8'))
25
+ model = torch.load(file_path, map_location="cpu")
26
+ elif ext == '.ckpt':
27
+ ckpt_dict = torch.load(file_path, map_location="cpu")
28
+ if '.yaml' in config_path:
29
+ config = set_hparams(config_path, global_hparams=False)
30
+ state = ckpt_dict["state_dict"]["model_gen"]
31
+ elif '.json' in config_path:
32
+ config = json.load(open(config_path, 'r', encoding='utf-8'))
33
+ state = ckpt_dict["generator"]
34
+ model = HifiGanGenerator(config)
35
+ model.load_state_dict(state, strict=True)
36
+ model.remove_weight_norm()
37
+ model = model.eval().to(device)
38
+ print(f"| Loaded model parameters from {file_path}.")
39
+ print(f"| HifiGAN device: {device}.")
40
+ return model, config, device
41
+
42
+
43
+ total_time = 0
44
+
45
+
46
+ @register_vocoder
47
+ class HifiGAN(PWG):
48
+ def __init__(self):
49
+ base_dir = hparams['vocoder_ckpt']
50
+ config_path = f'{base_dir}/config.yaml'
51
+ if os.path.exists(config_path):
52
+ file_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.*'), key=
53
+ lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).*', x.replace('\\','/'))[0]))[-1]
54
+ print('| load HifiGAN: ', file_path)
55
+ self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
56
+ else:
57
+ config_path = f'{base_dir}/config.json'
58
+ ckpt = f'{base_dir}/generator_v1'
59
+ if os.path.exists(config_path):
60
+ self.model, self.config, self.device = load_model(config_path=config_path, file_path=ckpt)
61
+
62
+ def spec2wav(self, mel, **kwargs):
63
+ device = self.device
64
+ with torch.no_grad():
65
+ c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
66
+ with utils.Timer('hifigan', print_time=hparams['profile_infer']):
67
+ f0 = kwargs.get('f0')
68
+ if f0 is not None and hparams.get('use_nsf'):
69
+ f0 = torch.FloatTensor(f0[None, :]).to(device)
70
+ y = self.model(c, f0).view(-1)
71
+ else:
72
+ y = self.model(c).view(-1)
73
+ wav_out = y.cpu().numpy()
74
+ if hparams.get('vocoder_denoise_c', 0.0) > 0:
75
+ wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c'])
76
+ return wav_out
77
+
78
+ # @staticmethod
79
+ # def wav2spec(wav_fn, **kwargs):
80
+ # wav, _ = librosa.core.load(wav_fn, sr=hparams['audio_sample_rate'])
81
+ # wav_torch = torch.FloatTensor(wav)[None, :]
82
+ # mel = mel_spectrogram(wav_torch, hparams).numpy()[0]
83
+ # return wav, mel.T
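A minimal usage sketch for this class follows. It assumes `hparams` has already been populated via `set_hparams` and that `vocoder_ckpt` points at a trained HifiGAN directory; the shapes follow the `BaseVocoder` docstrings, and the mel values here are random placeholders:

```python
# Hypothetical usage; random `mel` values will not sound like speech.
import numpy as np

vocoder = HifiGAN()                                  # loads the newest model_ckpt_steps_*.ckpt
mel = np.random.randn(200, 80).astype(np.float32)    # [T, 80] mel frames
f0 = np.full(200, 220.0, dtype=np.float32)           # per-frame F0 in Hz, only used when use_nsf is set
wav = vocoder.spec2wav(mel, f0=f0)                   # 1-D numpy waveform
```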
network/vocoders/nsf_hifigan.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import torch
+ from modules.nsf_hifigan.models import load_model, Generator
+ from modules.nsf_hifigan.nvSTFT import load_wav_to_torch, STFT
+ from utils.hparams import hparams
+ from network.vocoders.base_vocoder import BaseVocoder, register_vocoder
+
+
+ @register_vocoder
+ class NsfHifiGAN(BaseVocoder):
+     def __init__(self, device=None):
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = device
+         model_path = hparams['vocoder_ckpt']
+         if os.path.exists(model_path):
+             print('| Load HifiGAN: ', model_path)
+             self.model, self.h = load_model(model_path, device=self.device)
+         else:
+             print('Error: HifiGAN model file is not found!')
+
+     def _check_config(self):
+         # Warn when the audio parameters the vocoder was trained with differ
+         # from the current hparams (this block was previously duplicated
+         # verbatim in both spec2wav variants).
+         for key, vocoder_value in [('audio_sample_rate', self.h.sampling_rate),
+                                    ('audio_num_mel_bins', self.h.num_mels),
+                                    ('fft_size', self.h.n_fft),
+                                    ('win_size', self.h.win_size),
+                                    ('hop_size', self.h.hop_size),
+                                    ('fmin', self.h.fmin),
+                                    ('fmax', self.h.fmax)]:
+             if vocoder_value != hparams[key]:
+                 print(f"Mismatch parameters: hparams['{key}']={hparams[key]} != {vocoder_value} (vocoder)")
+
+     def spec2wav_torch(self, mel, **kwargs):  # mel: [B, T, bins]
+         self._check_config()
+         with torch.no_grad():
+             c = mel.transpose(2, 1)  # [B, bins, T]
+             c = 2.30259 * c  # log10 mel -> natural-log mel (ln 10 ≈ 2.30259)
+             f0 = kwargs.get('f0')  # [B, T]
+             if f0 is not None and hparams.get('use_nsf'):
+                 y = self.model(c, f0).view(-1)
+             else:
+                 y = self.model(c).view(-1)
+         return y
+
+     def spec2wav(self, mel, **kwargs):
+         self._check_config()
+         with torch.no_grad():
+             c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(self.device)
+             c = 2.30259 * c  # log10 mel -> natural-log mel
+             f0 = kwargs.get('f0')
+             if f0 is not None and hparams.get('use_nsf'):
+                 f0 = torch.FloatTensor(f0[None, :]).to(self.device)
+                 y = self.model(c, f0).view(-1)
+             else:
+                 y = self.model(c).view(-1)
+         wav_out = y.cpu().numpy()
+         return wav_out
+
+     @staticmethod
+     def wav2spec(inp_path, device=None):
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         sampling_rate = hparams['audio_sample_rate']
+         num_mels = hparams['audio_num_mel_bins']
+         n_fft = hparams['fft_size']
+         win_size = hparams['win_size']
+         hop_size = hparams['hop_size']
+         fmin = hparams['fmin']
+         fmax = hparams['fmax']
+         stft = STFT(sampling_rate, num_mels, n_fft, win_size, hop_size, fmin, fmax)
+         with torch.no_grad():
+             wav_torch, _ = load_wav_to_torch(inp_path, target_sr=stft.target_sr)
+             mel_torch = stft.get_mel(wav_torch.unsqueeze(0).to(device)).squeeze(0).T
+             mel_torch = 0.434294 * mel_torch  # natural-log mel -> log10 mel (1/ln 10 ≈ 0.434294)
+         return wav_torch.cpu().numpy(), mel_torch.cpu().numpy()
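The two magic constants are just base-change factors: the acoustic model stores mel spectrograms in log10, while the NSF-HifiGAN checkpoint expects natural log, and ln x = ln 10 · log10 x. A quick self-contained check:

```python
# Sanity check of the base-change constants used above.
import numpy as np

x = np.random.rand(5) + 0.1
log10_mel = np.log10(x)
ln_mel = 2.30259 * log10_mel       # what spec2wav feeds the vocoder; equals np.log(x)
round_trip = 0.434294 * ln_mel     # what wav2spec hands back to the model
assert np.allclose(ln_mel, np.log(x), atol=1e-4)
assert np.allclose(round_trip, log10_mel, atol=1e-4)
```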
network/vocoders/pwg.py ADDED
@@ -0,0 +1,137 @@
+ import glob
+ import re
+
+ import librosa
+ import numpy as np
+ import torch
+ import yaml
+ from sklearn.preprocessing import StandardScaler
+ from torch import nn
+
+ from modules.parallel_wavegan.models import ParallelWaveGANGenerator
+ from modules.parallel_wavegan.utils import read_hdf5
+ from utils.hparams import hparams
+ from utils.pitch_utils import f0_to_coarse
+ from network.vocoders.base_vocoder import BaseVocoder, register_vocoder
+
+
+ def load_pwg_model(config_path, checkpoint_path, stats_path):
+     # load config
+     with open(config_path, encoding='utf-8') as f:
+         config = yaml.load(f, Loader=yaml.Loader)
+
+     # setup
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+     model = ParallelWaveGANGenerator(**config["generator_params"])
+
+     ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
+     if 'state_dict' not in ckpt_dict:  # official vocoder
+         # reuse ckpt_dict instead of loading the checkpoint a second time
+         model.load_state_dict(ckpt_dict["model"]["generator"])
+         scaler = StandardScaler()
+         if config["format"] == "hdf5":
+             scaler.mean_ = read_hdf5(stats_path, "mean")
+             scaler.scale_ = read_hdf5(stats_path, "scale")
+         elif config["format"] == "npy":
+             scaler.mean_ = np.load(stats_path)[0]
+             scaler.scale_ = np.load(stats_path)[1]
+         else:
+             raise ValueError("support only hdf5 or npy format.")
+     else:  # custom PWG vocoder
+         fake_task = nn.Module()
+         fake_task.model_gen = model
+         fake_task.load_state_dict(ckpt_dict["state_dict"], strict=False)
+         scaler = None
+
+     model.remove_weight_norm()
+     model = model.eval().to(device)
+     print(f"| Loaded model parameters from {checkpoint_path}.")
+     print(f"| PWG device: {device}.")
+     return model, scaler, config, device
+
+
+ @register_vocoder
+ class PWG(BaseVocoder):
+     def __init__(self):
+         if hparams['vocoder_ckpt'] == '':  # load LJSpeech PWG pretrained model
+             base_dir = 'wavegan_pretrained'
+             ckpts = glob.glob(f'{base_dir}/checkpoint-*steps.pkl')
+             ckpt = sorted(ckpts,
+                           key=lambda x: int(re.findall(rf'{base_dir}/checkpoint-(\d+)steps.pkl', x)[0]))[-1]
+             config_path = f'{base_dir}/config.yaml'
+             print('| load PWG: ', ckpt)
+             self.model, self.scaler, self.config, self.device = load_pwg_model(
+                 config_path=config_path,
+                 checkpoint_path=ckpt,
+                 stats_path=f'{base_dir}/stats.h5',
+             )
+         else:
+             base_dir = hparams['vocoder_ckpt']
+             print(base_dir)
+             config_path = f'{base_dir}/config.yaml'
+             ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
+                           key=lambda x: int(re.findall(rf'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+             print('| load PWG: ', ckpt)
+             self.scaler = None
+             self.model, _, self.config, self.device = load_pwg_model(
+                 config_path=config_path,
+                 checkpoint_path=ckpt,
+                 stats_path=f'{base_dir}/stats.h5',
+             )
+
+     def spec2wav(self, mel, **kwargs):
+         # start generation
+         config = self.config
+         device = self.device
+         pad_size = (config["generator_params"]["aux_context_window"],
+                     config["generator_params"]["aux_context_window"])
+         c = mel
+         if self.scaler is not None:
+             c = self.scaler.transform(c)
+
+         with torch.no_grad():
+             z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device)
+             c = np.pad(c, (pad_size, (0, 0)), "edge")
+             c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)
+             p = kwargs.get('f0')
+             if p is not None:
+                 p = f0_to_coarse(p)
+                 p = np.pad(p, (pad_size,), "edge")
+                 p = torch.LongTensor(p[None, :]).to(device)
+             y = self.model(z, c, p).view(-1)
+         wav_out = y.cpu().numpy()
+         return wav_out
+
+     @staticmethod
+     def wav2spec(wav_fn, return_linear=False):
+         from preprocessing.data_gen_utils import process_utterance
+         res = process_utterance(
+             wav_fn, fft_size=hparams['fft_size'],
+             hop_size=hparams['hop_size'],
+             win_length=hparams['win_size'],
+             num_mels=hparams['audio_num_mel_bins'],
+             fmin=hparams['fmin'],
+             fmax=hparams['fmax'],
+             sample_rate=hparams['audio_sample_rate'],
+             loud_norm=hparams['loud_norm'],
+             min_level_db=hparams['min_level_db'],
+             return_linear=return_linear, vocoder='pwg', eps=float(hparams.get('wav2spec_eps', 1e-10)))
+         if return_linear:
+             return res[0], res[1].T, res[2].T  # [T, 80], [T, n_fft]
+         else:
+             return res[0], res[1].T
+
+     @staticmethod
+     def wav2mfcc(wav_fn):
+         fft_size = hparams['fft_size']
+         hop_size = hparams['hop_size']
+         win_length = hparams['win_size']
+         sample_rate = hparams['audio_sample_rate']
+         wav, _ = librosa.core.load(wav_fn, sr=sample_rate)
+         mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13,
+                                     n_fft=fft_size, hop_length=hop_size,
+                                     win_length=win_length, pad_mode="constant", power=1.0)
+         mfcc_delta = librosa.feature.delta(mfcc, order=1)
+         mfcc_delta_delta = librosa.feature.delta(mfcc, order=2)
+         mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T
+         return mfcc
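The "official vocoder" branch of `load_pwg_model` rebuilds the feature normalizer from saved per-bin statistics rather than fitting it on data. A minimal sketch of priming a `StandardScaler` that way, under the assumption of an 80-bin mel and fabricated stats values:

```python
# Sketch: priming sklearn's StandardScaler from precomputed per-bin stats,
# as load_pwg_model does for the "npy" format. The stats values are made up.
import numpy as np
from sklearn.preprocessing import StandardScaler

stats = np.stack([np.zeros(80), np.ones(80)])  # row 0: means, row 1: scales
scaler = StandardScaler()
scaler.mean_ = stats[0]
scaler.scale_ = stats[1]

mel = np.random.randn(100, 80)                 # [T, 80]
normed = scaler.transform(mel)                 # what spec2wav applies before padding
```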
network/vocoders/vocoder_utils.py ADDED
@@ -0,0 +1,15 @@
+ import librosa
+ import numpy as np
+
+ from utils.hparams import hparams
+
+
+ def denoise(wav, v=0.1):
+     # Spectral subtraction: shave a constant noise floor `v` off every STFT
+     # magnitude bin, keep the original phases, and invert back to a waveform.
+     spec = librosa.stft(y=wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'],
+                         win_length=hparams['win_size'], pad_mode='constant')
+     spec_m = np.abs(spec)
+     spec_m = np.clip(spec_m - v, a_min=0, a_max=None)
+     spec_a = np.angle(spec)
+
+     return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=hparams['hop_size'],
+                          win_length=hparams['win_size'])
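A standalone toy run of the same idea, without the `hparams` dependency (the test signal, noise level, and STFT sizes are made up for illustration):

```python
# Self-contained spectral-subtraction demo mirroring denoise() above.
import librosa
import numpy as np

sr = 22050
t = np.linspace(0, 1.0, sr, endpoint=False)
wav = np.sin(2 * np.pi * 440 * t) + 0.01 * np.random.randn(sr)  # tone + noise

spec = librosa.stft(y=wav.astype(np.float32), n_fft=1024, hop_length=256)
mag = np.clip(np.abs(spec) - 0.1, 0, None)    # subtract a constant floor
clean = librosa.istft(mag * np.exp(1j * np.angle(spec)), hop_length=256)
```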
preprocessing/SVCpre.py ADDED
@@ -0,0 +1,63 @@
+ '''
+     item: one piece of data
+     item_name: data id
+     wav_fn: wave file path
+     txt: lyrics
+     ph: phoneme
+     tg_fn: text grid file path (unused)
+     spk: dataset name
+     wdb: word boundary
+     ph_durs: phoneme durations
+     midi: pitch as MIDI notes
+     midi_dur: MIDI note duration
+     is_slur: whether to sustain the voice across note changes (slur)
+ '''
+ import logging
+ from copy import deepcopy
+
+ from preprocessing.process_pipeline import File2Batch
+ from utils.hparams import hparams
+ from preprocessing.base_binarizer import BaseBinarizer
+
+ SVCSINGING_ITEM_ATTRIBUTES = ['wav_fn', 'spk_id']
+
+
+ class SVCBinarizer(BaseBinarizer):
+     def __init__(self, item_attributes=SVCSINGING_ITEM_ATTRIBUTES):
+         super().__init__(item_attributes)
+         print('speakers: ', set(item['spk_id'] for item in self.items.values()))
+         self.item_names = sorted(list(self.items.keys()))
+         self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+
+     def split_train_test_set(self, item_names):
+         # Either pick test items by configured name prefixes, or hold out the
+         # last five items; see the sketch after this file.
+         item_names = deepcopy(item_names)
+         if hparams['choose_test_manually']:
+             test_item_names = [x for x in item_names if any(x.startswith(ts) for ts in hparams['test_prefixes'])]
+         else:
+             test_item_names = item_names[-5:]
+         train_item_names = [x for x in item_names if x not in set(test_item_names)]
+         logging.info("train {}".format(len(train_item_names)))
+         logging.info("test {}".format(len(test_item_names)))
+         return train_item_names, test_item_names
+
+     @property
+     def train_item_names(self):
+         return self._train_item_names
+
+     @property
+     def valid_item_names(self):
+         return self._test_item_names
+
+     @property
+     def test_item_names(self):
+         return self._test_item_names
+
+     def load_meta_data(self):
+         self.items = File2Batch.file2temporary_dict()
+
+     def _phone_encoder(self):
+         from preprocessing.hubertinfer import Hubertencoder
+         return Hubertencoder(hparams['hubert_path'])
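To make the split behavior concrete, a small illustration of `split_train_test_set`'s two modes (the item names are invented):

```python
# Illustration of SVCBinarizer.split_train_test_set; all names are made up.
item_names = sorted([f'singerA_{i:03d}' for i in range(8)] + ['demo_000', 'demo_001'])

# choose_test_manually=True with test_prefixes=['demo'] picks by prefix:
test_manual = [x for x in item_names if x.startswith('demo')]   # ['demo_000', 'demo_001']

# choose_test_manually=False holds out the last five sorted items:
test_auto = item_names[-5:]                                     # singerA_003 .. singerA_007
train = [x for x in item_names if x not in set(test_auto)]
```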
preprocessing/base_binarizer.py ADDED
@@ -0,0 +1,237 @@
+ import os
+
+ os.environ["OMP_NUM_THREADS"] = "1"
+
+ import json
+ import random
+
+ import numpy as np
+ import yaml
+ # from resemblyzer import VoiceEncoder
+ from tqdm import tqdm
+
+ from preprocessing.data_gen_utils import get_mel2ph, get_pitch_parselmouth, build_phone_encoder, get_pitch_crepe
+ from utils.hparams import set_hparams, hparams
+ from utils.indexed_datasets import IndexedDatasetBuilder
+ from utils.multiprocess_utils import chunked_multiprocess_run
+
+
+ class BinarizationError(Exception):
+     pass
+
+
+ BASE_ITEM_ATTRIBUTES = ['txt', 'ph', 'wav_fn', 'tg_fn', 'spk_id']
+
+
+ class BaseBinarizer:
+     '''
+     Base class for data processing.
+     1. *process* and *process_data_split*:
+         process the entire dataset and generate the train-test split (supports parallel processing);
+     2. *process_item*:
+         process a single piece of data;
+     3. *get_pitch*:
+         infer the pitch using some algorithm;
+     4. *get_align*:
+         get the alignment in 'mel2ph' format (see https://arxiv.org/abs/1905.09263);
+     5. phoneme encoder, voice encoder, etc.
+
+     Subclasses should define:
+     1. *load_meta_data*:
+         how to read multiple datasets from files;
+     2. *train_item_names*, *valid_item_names*, *test_item_names*:
+         how to split the dataset;
+     3. *load_ph_set*:
+         the phoneme set.
+     '''
+
+     def __init__(self, item_attributes=BASE_ITEM_ATTRIBUTES):
+         self.binarization_args = hparams['binarization_args']
+         # self.pre_align_args = hparams['pre_align_args']
+
+         self.items = {}
+         # every item in self.items has some attributes
+         self.item_attributes = item_attributes
+
+         self.load_meta_data()
+         # sanity check: every key of an item dict must come from the given attribute list
+         assert all([attr in self.item_attributes for attr in list(self.items.values())[0].keys()])
+         self.item_names = sorted(list(self.items.keys()))
+
+         if self.binarization_args['shuffle']:
+             random.seed(1234)
+             random.shuffle(self.item_names)
+
+         # set default get_pitch algorithm
+         if hparams['use_crepe']:
+             self.get_pitch_algorithm = get_pitch_crepe
+         else:
+             self.get_pitch_algorithm = get_pitch_parselmouth
+
+     def load_meta_data(self):
+         raise NotImplementedError
+
+     @property
+     def train_item_names(self):
+         raise NotImplementedError
+
+     @property
+     def valid_item_names(self):
+         raise NotImplementedError
+
+     @property
+     def test_item_names(self):
+         raise NotImplementedError
+
+     def build_spk_map(self):
+         spk_map = set()
+         for item_name in self.item_names:
+             spk_name = self.items[item_name]['spk_id']
+             spk_map.add(spk_name)
+         spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
+         assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
+         return spk_map
+
+     def item_name2spk_id(self, item_name):
+         return self.spk_map[self.items[item_name]['spk_id']]
+
+     def _phone_encoder(self):
+         '''
+         use hubert encoder
+         '''
+         raise NotImplementedError
+         # unreachable: legacy phone-set logic kept for reference.
+         # It creates 'phone_set.json' if it doesn't exist.
+         ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
+         ph_set = []
+         if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+             self.load_ph_set(ph_set)
+             ph_set = sorted(set(ph_set))
+             json.dump(ph_set, open(ph_set_fn, 'w', encoding='utf-8'))
+             print("| Build phone set: ", ph_set)
+         else:
+             ph_set = json.load(open(ph_set_fn, 'r', encoding='utf-8'))
+             print("| Load phone set: ", ph_set)
+         return build_phone_encoder(hparams['binary_data_dir'])
+
+     def load_ph_set(self, ph_set):
+         raise NotImplementedError
+
+     def meta_data_iterator(self, prefix):
+         if prefix == 'valid':
+             item_names = self.valid_item_names
+         elif prefix == 'test':
+             item_names = self.test_item_names
+         else:
+             item_names = self.train_item_names
+         for item_name in item_names:
+             meta_data = self.items[item_name]
+             yield item_name, meta_data
+
+     def process(self):
+         os.makedirs(hparams['binary_data_dir'], exist_ok=True)
+         self.spk_map = self.build_spk_map()
+         print("| spk_map: ", self.spk_map)
+         spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
+         json.dump(self.spk_map, open(spk_map_fn, 'w', encoding='utf-8'))
+
+         self.phone_encoder = self._phone_encoder()
+         self.process_data_split('valid')
+         self.process_data_split('test')
+         self.process_data_split('train')
+
+     def process_data_split(self, prefix):
+         data_dir = hparams['binary_data_dir']
+         args = []
+         builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
+         lengths = []
+         f0s = []
+         total_sec = 0
+         # if self.binarization_args['with_spk_embed']:
+         #     voice_encoder = VoiceEncoder().cuda()
+
+         for item_name, meta_data in self.meta_data_iterator(prefix):
+             args.append([item_name, meta_data, self.binarization_args])
+         spec_min = []
+         spec_max = []
+         # code for single cpu processing
+         for i in tqdm(reversed(range(len(args))), total=len(args)):
+             a = args[i]
+             item = self.process_item(*a)
+             if item is None:
+                 continue
+             spec_min.append(item['spec_min'])
+             spec_max.append(item['spec_max'])
+             # item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
+             #     if self.binarization_args['with_spk_embed'] else None
+             if not self.binarization_args['with_wav'] and 'wav' in item:
+                 if hparams['debug']:
+                     print("del wav")
+                 del item['wav']
+             if hparams['debug']:
+                 print(item)
+             builder.add_item(item)
+             lengths.append(item['len'])
+             total_sec += item['sec']
+             # if item.get('f0') is not None:
+             #     f0s.append(item['f0'])
+         if prefix == 'train':
+             # write the per-bin spectrogram range back into the config,
+             # for the diffusion model's spectrogram normalization
+             spec_max = np.max(spec_max, 0)
+             spec_min = np.min(spec_min, 0)
+             print(spec_max.shape)
+             with open(hparams['config_path'], encoding='utf-8') as f:
+                 _hparams = yaml.safe_load(f)
+             _hparams['spec_max'] = spec_max.tolist()
+             _hparams['spec_min'] = spec_min.tolist()
+             with open(hparams['config_path'], 'w', encoding='utf-8') as f:
+                 yaml.safe_dump(_hparams, f)
+         builder.finalize()
+         np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
+         if len(f0s) > 0:
+             f0s = np.concatenate(f0s, 0)
+             f0s = f0s[f0s != 0]
+             np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
+         print(f"| {prefix} total duration: {total_sec:.3f}s")
+
+     def process_item(self, item_name, meta_data, binarization_args):
+         from preprocessing.process_pipeline import File2Batch
+         return File2Batch.temporary_dict2processed_input(item_name, meta_data, self.phone_encoder, binarization_args)
+
+     def get_align(self, meta_data, mel, phone_encoded, res):
+         raise NotImplementedError
+
+     def get_align_from_textgrid(self, meta_data, mel, phone_encoded, res):
+         '''
+         NOTE: this part of the script is *isolated* from the other scripts, which means
+         it may not be compatible with the current version.
+         '''
+         return
+         tg_fn, ph = meta_data['tg_fn'], meta_data['ph']
+         if tg_fn is not None and os.path.exists(tg_fn):
+             mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
+         else:
+             raise BinarizationError(f"Align not found")
+         if mel2ph.max() - 1 >= len(phone_encoded):
+             raise BinarizationError(
+                 f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
+         res['mel2ph'] = mel2ph
+         res['dur'] = dur
+
+     def get_f0cwt(self, f0, res):
+         '''
+         NOTE: this part of the script is *isolated* from the other scripts, which means
+         it may not be compatible with the current version.
+         '''
+         return
+         from utils.cwt import get_cont_lf0, get_lf0_cwt
+         uv, cont_lf0_lpf = get_cont_lf0(f0)
+         logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
+         cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
+         Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
+         if np.any(np.isnan(Wavelet_lf0)):
+             raise BinarizationError("NaN CWT")
+         res['cwt_spec'] = Wavelet_lf0
+         res['cwt_scales'] = scales
+         res['f0_mean'] = logf0s_mean_org
+         res['f0_std'] = logf0s_std_org
+
+
+ if __name__ == "__main__":
+     set_hparams()
+     BaseBinarizer().process()
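For orientation, here is a minimal subclass sketch showing the hooks `BaseBinarizer` expects a subclass to supply (everything below is invented for illustration; the real subclass in this commit is `SVCBinarizer` above):

```python
# Hypothetical minimal subclass; attribute values and item names are placeholders.
class ToyBinarizer(BaseBinarizer):
    def load_meta_data(self):
        # item_name -> dict whose keys must all appear in BASE_ITEM_ATTRIBUTES
        self.items = {
            'demo_000': {'txt': '...', 'ph': '...', 'wav_fn': 'raw/demo_000.wav',
                         'tg_fn': None, 'spk_id': 'demo'},
            'demo_001': {'txt': '...', 'ph': '...', 'wav_fn': 'raw/demo_001.wav',
                         'tg_fn': None, 'spk_id': 'demo'},
        }

    @property
    def train_item_names(self):
        return self.item_names[:-1]

    @property
    def valid_item_names(self):
        return self.item_names[-1:]

    @property
    def test_item_names(self):
        return self.item_names[-1:]
```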