Spaces: 10kwon / DiffSVC (Runtime error)

Commit 2bfc29a · committed by 10kwon
Parent(s): 10f66bf

DiffSVC

This view is limited to 50 files because it contains too many changes. See raw diff.
- batch.py +43 -0
- flask_api.py +54 -0
- infer.py +98 -0
- infer_tools/__init__.py +0 -0
- infer_tools/infer_tool.py +334 -0
- infer_tools/slicer.py +158 -0
- modules/commons/common_layers.py +671 -0
- modules/commons/espnet_positional_embedding.py +113 -0
- modules/commons/ssim.py +391 -0
- modules/fastspeech/fs2.py +255 -0
- modules/fastspeech/pe.py +149 -0
- modules/fastspeech/tts_modules.py +364 -0
- modules/hifigan/hifigan.py +365 -0
- modules/hifigan/mel_utils.py +80 -0
- modules/nsf_hifigan/env.py +15 -0
- modules/nsf_hifigan/models.py +549 -0
- modules/nsf_hifigan/nvSTFT.py +111 -0
- modules/nsf_hifigan/utils.py +67 -0
- modules/parallel_wavegan/__init__.py +0 -0
- modules/parallel_wavegan/layers/__init__.py +5 -0
- modules/parallel_wavegan/layers/causal_conv.py +56 -0
- modules/parallel_wavegan/layers/pqmf.py +129 -0
- modules/parallel_wavegan/layers/residual_block.py +129 -0
- modules/parallel_wavegan/layers/residual_stack.py +75 -0
- modules/parallel_wavegan/layers/tf_layers.py +129 -0
- modules/parallel_wavegan/layers/upsample.py +183 -0
- modules/parallel_wavegan/losses/__init__.py +1 -0
- modules/parallel_wavegan/losses/stft_loss.py +153 -0
- modules/parallel_wavegan/models/__init__.py +2 -0
- modules/parallel_wavegan/models/melgan.py +427 -0
- modules/parallel_wavegan/models/parallel_wavegan.py +434 -0
- modules/parallel_wavegan/models/source.py +538 -0
- modules/parallel_wavegan/optimizers/__init__.py +2 -0
- modules/parallel_wavegan/optimizers/radam.py +91 -0
- modules/parallel_wavegan/stft_loss.py +100 -0
- modules/parallel_wavegan/utils/__init__.py +1 -0
- modules/parallel_wavegan/utils/utils.py +169 -0
- network/diff/candidate_decoder.py +98 -0
- network/diff/diffusion.py +332 -0
- network/diff/net.py +135 -0
- network/hubert/hubert_model.py +276 -0
- network/hubert/vec_model.py +60 -0
- network/vocoders/__init__.py +2 -0
- network/vocoders/base_vocoder.py +39 -0
- network/vocoders/hifigan.py +83 -0
- network/vocoders/nsf_hifigan.py +92 -0
- network/vocoders/pwg.py +137 -0
- network/vocoders/vocoder_utils.py +15 -0
- preprocessing/SVCpre.py +63 -0
- preprocessing/base_binarizer.py +237 -0
batch.py
ADDED
@@ -0,0 +1,43 @@
import soundfile

from infer_tools import infer_tool
from infer_tools.infer_tool import Svc


def run_clip(svc_model, key, acc, use_pe, use_crepe, thre, use_gt_mel, add_noise_step, project_name='', f_name=None,
             file_path=None, out_path=None):
    raw_audio_path = f_name
    infer_tool.format_wav(raw_audio_path)
    _f0_tst, _f0_pred, _audio = svc_model.infer(raw_audio_path, key=key, acc=acc, singer=True, use_pe=use_pe,
                                                use_crepe=use_crepe,
                                                thre=thre, use_gt_mel=use_gt_mel, add_noise_step=add_noise_step)
    out_path = f'./singer_data/{f_name.split("/")[-1]}'
    soundfile.write(out_path, _audio, 44100, 'PCM_16')


if __name__ == '__main__':
    # Project folder name (the one used during training)
    project_name = "firefox"
    model_path = f'./checkpoints/{project_name}/clean_model_ckpt_steps_100000.ckpt'
    config_path = f'./checkpoints/{project_name}/config.yaml'

    # Multiple wav/ogg files are supported; put them in the raw folder, with extensions
    file_names = infer_tool.get_end_file("./batch", "wav")
    trans = [-6]  # Pitch shift in semitones, positive or negative; one entry per file above,
                  # missing entries are padded with the first value
    # Speed-up factor
    accelerate = 50
    hubert_gpu = True
    cut_time = 30

    # Do not modify below
    infer_tool.mkdir(["./batch", "./singer_data"])
    infer_tool.fill_a_to_b(trans, file_names)

    model = Svc(project_name, config_path, hubert_gpu, model_path)
    count = 0
    for f_name, tran in zip(file_names, trans):
        print(f_name)
        run_clip(model, key=tran, acc=accelerate, use_crepe=False, thre=0.05, use_pe=False, use_gt_mel=False,
                 add_noise_step=500, f_name=f_name, project_name=project_name)
        count += 1
        print(f"process:{round(count * 100 / len(file_names), 2)}%")
flask_api.py
ADDED
@@ -0,0 +1,54 @@
import io
import logging

import librosa
import soundfile
from flask import Flask, request, send_file
from flask_cors import CORS

from infer_tools.infer_tool import Svc
from utils.hparams import hparams

app = Flask(__name__)

CORS(app)

logging.getLogger('numba').setLevel(logging.WARNING)


@app.route("/voiceChangeModel", methods=["POST"])
def voice_change_model():
    request_form = request.form
    wave_file = request.files.get("sample", None)
    # Pitch-shift amount
    f_pitch_change = float(request_form.get("fPitchChange", 0))
    # Sample rate required by the DAW
    daw_sample = int(float(request_form.get("sampleRate", 0)))
    speaker_id = int(float(request_form.get("sSpeakId", 0)))
    # Get the wav file from the HTTP request and convert it
    input_wav_path = io.BytesIO(wave_file.read())
    # Model inference
    _f0_tst, _f0_pred, _audio = model.infer(input_wav_path, key=f_pitch_change, acc=accelerate, use_pe=False,
                                            use_crepe=False)
    tar_audio = librosa.resample(_audio, hparams["audio_sample_rate"], daw_sample)
    # Return the audio
    out_wav_path = io.BytesIO()
    soundfile.write(out_wav_path, tar_audio, daw_sample, format="wav")
    out_wav_path.seek(0)
    return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)


if __name__ == '__main__':
    # Project folder name (the one used during training)
    project_name = "firefox"
    model_path = f'./checkpoints/{project_name}/model_ckpt_steps_188000.ckpt'
    config_path = f'./checkpoints/{project_name}/config.yaml'

    # Speed-up factor
    accelerate = 50
    hubert_gpu = True

    model = Svc(project_name, config_path, hubert_gpu, model_path)

    # This port matches the VST plugin; changing it is not recommended
    app.run(port=6842, host="0.0.0.0", debug=False, threaded=False)
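For reference, a minimal client sketch for the endpoint above, assuming the `requests` package and a hypothetical input file `test.wav`; the form fields mirror exactly what `voice_change_model` reads:

import io

import requests
import soundfile

# Hypothetical input file; any wav the server can decode will do.
with open("test.wav", "rb") as f:
    response = requests.post(
        "http://127.0.0.1:6842/voiceChangeModel",
        files={"sample": f},          # read server-side via request.files.get("sample")
        data={
            "fPitchChange": 0,        # pitch shift in semitones
            "sampleRate": 44100,      # sample rate the caller wants back
            "sSpeakId": 0,            # parsed but unused by the handler above
        },
    )
audio, sr = soundfile.read(io.BytesIO(response.content))
print(sr, audio.shape)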
infer.py
ADDED
@@ -0,0 +1,98 @@
import io
import time
from pathlib import Path

import librosa
import numpy as np
import soundfile

from infer_tools import infer_tool
from infer_tools import slicer
from infer_tools.infer_tool import Svc
from utils.hparams import hparams

chunks_dict = infer_tool.read_temp("./infer_tools/new_chunks_temp.json")


def run_clip(svc_model, key, acc, use_pe, use_crepe, thre, use_gt_mel, add_noise_step, project_name='', f_name=None,
             file_path=None, out_path=None, slice_db=-40, **kwargs):
    print('code version:2022-12-04')
    use_pe = use_pe if hparams['audio_sample_rate'] == 24000 else False
    if file_path is None:
        raw_audio_path = f"./raw/{f_name}"
        clean_name = f_name[:-4]
    else:
        raw_audio_path = file_path
        clean_name = str(Path(file_path).name)[:-4]
    infer_tool.format_wav(raw_audio_path)
    wav_path = Path(raw_audio_path).with_suffix('.wav')
    global chunks_dict
    audio, sr = librosa.load(wav_path, mono=True, sr=None)
    wav_hash = infer_tool.get_md5(audio)
    if wav_hash in chunks_dict.keys():
        print("load chunks from temp")
        chunks = chunks_dict[wav_hash]["chunks"]
    else:
        chunks = slicer.cut(wav_path, db_thresh=slice_db)
        chunks_dict[wav_hash] = {"chunks": chunks, "time": int(time.time())}
        infer_tool.write_temp("./infer_tools/new_chunks_temp.json", chunks_dict)
    audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)

    count = 0
    f0_tst = []
    f0_pred = []
    audio = []
    for (slice_tag, data) in audio_data:
        print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
        length = int(np.ceil(len(data) / audio_sr * hparams['audio_sample_rate']))
        raw_path = io.BytesIO()
        soundfile.write(raw_path, data, audio_sr, format="wav")
        if hparams['debug']:
            print(np.mean(data), np.var(data))
        raw_path.seek(0)
        if slice_tag:
            print('jump empty segment')
            _f0_tst, _f0_pred, _audio = (
                np.zeros(int(np.ceil(length / hparams['hop_size']))),
                np.zeros(int(np.ceil(length / hparams['hop_size']))),
                np.zeros(length))
        else:
            _f0_tst, _f0_pred, _audio = svc_model.infer(raw_path, key=key, acc=acc, use_pe=use_pe, use_crepe=use_crepe,
                                                        thre=thre, use_gt_mel=use_gt_mel, add_noise_step=add_noise_step)
        fix_audio = np.zeros(length)
        fix_audio[:] = np.mean(_audio)
        fix_audio[:len(_audio)] = _audio[0 if len(_audio) < len(fix_audio) else len(_audio) - len(fix_audio):]
        f0_tst.extend(_f0_tst)
        f0_pred.extend(_f0_pred)
        audio.extend(list(fix_audio))
        count += 1
    if out_path is None:
        out_path = f'./results/{clean_name}_{key}key_{project_name}_{hparams["residual_channels"]}_{hparams["residual_layers"]}_{int(step / 1000)}k_{accelerate}x.{kwargs["format"]}'
    soundfile.write(out_path, audio, hparams["audio_sample_rate"], 'PCM_16', format=out_path.split('.')[-1])
    return np.array(f0_tst), np.array(f0_pred), audio


if __name__ == '__main__':
    # Project folder name (the one used during training)
    project_name = "yilanqiu"
    model_path = f'./checkpoints/{project_name}/model_ckpt_steps_246000.ckpt'
    config_path = f'./checkpoints/{project_name}/config.yaml'

    # Multiple wav/ogg files are supported; put them in the raw folder, with extensions
    file_names = ["青花瓷.wav"]
    trans = [0]  # Pitch shift in semitones, positive or negative; one entry per file above,
                 # missing entries are padded with the first value
    # Speed-up factor
    accelerate = 20
    hubert_gpu = True
    format = 'flac'
    step = int(model_path.split("_")[-1].split(".")[0])

    # Do not modify below
    infer_tool.mkdir(["./raw", "./results"])
    infer_tool.fill_a_to_b(trans, file_names)

    model = Svc(project_name, config_path, hubert_gpu, model_path)
    for f_name, tran in zip(file_names, trans):
        if "." not in f_name:
            f_name += ".wav"
        run_clip(model, key=tran, acc=accelerate, use_crepe=True, thre=0.05, use_pe=True, use_gt_mel=False,
                 add_noise_step=500, f_name=f_name, project_name=project_name, format=format)
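run_clip can also be driven from other code through its file_path/out_path parameters, which the __main__ block above does not exercise. A minimal sketch with placeholder project, checkpoint, and file names; passing an explicit out_path skips the auto-named branch, which is the only place that reads the module-level step and accelerate globals set in __main__:

from infer_tools.infer_tool import Svc
from infer import run_clip

# Placeholder project and checkpoint names; substitute your own.
project = "my_project"
model = Svc(project,
            f"./checkpoints/{project}/config.yaml",
            hubert_gpu=True,
            model_path=f"./checkpoints/{project}/model_ckpt_steps_100000.ckpt")

# With out_path given, the output format is taken from the file suffix.
f0_gt, f0_pred, audio = run_clip(model, key=0, acc=20, use_crepe=True, thre=0.05,
                                 use_pe=True, use_gt_mel=False, add_noise_step=500,
                                 file_path="./raw/input.wav", out_path="./results/output.flac")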
infer_tools/__init__.py
ADDED
File without changes
infer_tools/infer_tool.py
ADDED
@@ -0,0 +1,334 @@
import hashlib
import json
import os
import time
from io import BytesIO
from pathlib import Path

import librosa
import numpy as np
import soundfile
import torch

import utils
from modules.fastspeech.pe import PitchExtractor
from network.diff.candidate_decoder import FFT
from network.diff.diffusion import GaussianDiffusion
from network.diff.net import DiffNet
from network.vocoders.base_vocoder import VOCODERS, get_vocoder_cls
from preprocessing.data_gen_utils import get_pitch_parselmouth, get_pitch_crepe
from preprocessing.hubertinfer import Hubertencoder
from utils.hparams import hparams, set_hparams
from utils.pitch_utils import denorm_f0, norm_interp_f0

if os.path.exists("chunks_temp.json"):
    os.remove("chunks_temp.json")


def read_temp(file_name):
    if not os.path.exists(file_name):
        with open(file_name, "w") as f:
            f.write(json.dumps({"info": "temp_dict"}))
        return {}
    else:
        try:
            with open(file_name, "r") as f:
                data = f.read()
            data_dict = json.loads(data)
            if os.path.getsize(file_name) > 50 * 1024 * 1024:
                f_name = file_name.split("/")[-1]
                print(f"clean {f_name}")
                for wav_hash in list(data_dict.keys()):
                    if int(time.time()) - int(data_dict[wav_hash]["time"]) > 14 * 24 * 3600:
                        del data_dict[wav_hash]
        except Exception as e:
            print(e)
            print(f"{file_name} error, auto rebuild file")
            data_dict = {"info": "temp_dict"}
        return data_dict


f0_dict = read_temp("./infer_tools/f0_temp.json")


def write_temp(file_name, data):
    with open(file_name, "w") as f:
        f.write(json.dumps(data))


def timeit(func):
    def run(*args, **kwargs):
        t = time.time()
        res = func(*args, **kwargs)
        print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t))
        return res

    return run


def format_wav(audio_path):
    if Path(audio_path).suffix == '.wav':
        return
    raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
    soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)


def fill_a_to_b(a, b):
    if len(a) < len(b):
        for _ in range(0, len(b) - len(a)):
            a.append(a[0])


def get_end_file(dir_path, end):
    file_lists = []
    for root, dirs, files in os.walk(dir_path):
        files = [f for f in files if f[0] != '.']
        dirs[:] = [d for d in dirs if d[0] != '.']
        for f_file in files:
            if f_file.endswith(end):
                file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
    return file_lists


def mkdir(paths: list):
    for path in paths:
        if not os.path.exists(path):
            os.mkdir(path)


def get_md5(content):
    return hashlib.new("md5", content).hexdigest()


class Svc:
    def __init__(self, project_name, config_name, hubert_gpu, model_path):
        self.project_name = project_name
        self.DIFF_DECODERS = {
            'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
            'fft': lambda hp: FFT(
                hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
        }

        self.model_path = model_path
        self.dev = torch.device("cuda")

        self._ = set_hparams(config=config_name, exp_name=self.project_name, infer=True,
                             reset=True,
                             hparams_str='',
                             print_hparams=False)

        self.mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=Hubertencoder(hparams['hubert_path']),
            out_dims=self.mel_bins, denoise_fn=self.DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
        )
        self.load_ckpt()
        self.model.cuda()
        hparams['hubert_gpu'] = hubert_gpu
        self.hubert = Hubertencoder(hparams['hubert_path'])
        self.pe = PitchExtractor().cuda()
        utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
        self.pe.eval()
        self.vocoder = get_vocoder_cls(hparams)()

    def load_ckpt(self, model_name='model', force=True, strict=True):
        utils.load_ckpt(self.model, self.model_path, model_name, force, strict)

    def infer(self, in_path, key, acc, use_pe=True, use_crepe=True, thre=0.05, singer=False, **kwargs):
        batch = self.pre(in_path, acc, use_crepe, thre)
        spk_embed = batch.get('spk_embed') if not hparams['use_spk_id'] else batch.get('spk_ids')
        hubert = batch['hubert']
        ref_mels = batch["mels"]
        energy = batch['energy']
        mel2ph = batch['mel2ph']
        batch['f0'] = batch['f0'] + (key / 12)
        batch['f0'][batch['f0'] > np.log2(hparams['f0_max'])] = 0
        f0 = batch['f0']
        uv = batch['uv']

        @timeit
        def diff_infer():
            outputs = self.model(
                hubert.cuda(), spk_embed=spk_embed, mel2ph=mel2ph.cuda(), f0=f0.cuda(), uv=uv.cuda(),
                energy=energy.cuda(),
                ref_mels=ref_mels.cuda(),
                infer=True, **kwargs)
            return outputs

        outputs = diff_infer()
        batch['outputs'] = self.model.out2mel(outputs['mel_out'])
        batch['mel2ph_pred'] = outputs['mel2ph']
        batch['f0_gt'] = denorm_f0(batch['f0'], batch['uv'], hparams)
        if use_pe:
            batch['f0_pred'] = self.pe(outputs['mel_out'])['f0_denorm_pred'].detach()
        else:
            batch['f0_pred'] = outputs.get('f0_denorm')
        return self.after_infer(batch, singer, in_path)

    @timeit
    def after_infer(self, prediction, singer, in_path):
        for k, v in prediction.items():
            if type(v) is torch.Tensor:
                prediction[k] = v.cpu().numpy()

        # remove paddings
        mel_gt = prediction["mels"]
        mel_gt_mask = np.abs(mel_gt).sum(-1) > 0

        mel_pred = prediction["outputs"]
        mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
        mel_pred = mel_pred[mel_pred_mask]
        mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])

        f0_gt = prediction.get("f0_gt")
        f0_pred = prediction.get("f0_pred")
        if f0_pred is not None:
            f0_gt = f0_gt[mel_gt_mask]
            if len(f0_pred) > len(mel_pred_mask):
                f0_pred = f0_pred[:len(mel_pred_mask)]
            f0_pred = f0_pred[mel_pred_mask]
        torch.cuda.is_available() and torch.cuda.empty_cache()

        if singer:
            data_path = in_path.replace("batch", "singer_data")
            mel_path = data_path[:-4] + "_mel.npy"
            f0_path = data_path[:-4] + "_f0.npy"
            np.save(mel_path, mel_pred)
            np.save(f0_path, f0_pred)
        wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
        return f0_gt, f0_pred, wav_pred

    def temporary_dict2processed_input(self, item_name, temp_dict, use_crepe=True, thre=0.05):
        '''
            process data in temporary_dicts
        '''

        binarization_args = hparams['binarization_args']

        @timeit
        def get_pitch(wav, mel):
            # get ground truth f0 by self.get_pitch_algorithm
            global f0_dict
            if use_crepe:
                md5 = get_md5(wav)
                if f"{md5}_gt" in f0_dict.keys():
                    print("load temp crepe f0")
                    gt_f0 = np.array(f0_dict[f"{md5}_gt"]["f0"])
                    coarse_f0 = np.array(f0_dict[f"{md5}_coarse"]["f0"])
                else:
                    torch.cuda.is_available() and torch.cuda.empty_cache()
                    gt_f0, coarse_f0 = get_pitch_crepe(wav, mel, hparams, thre)
                    f0_dict[f"{md5}_gt"] = {"f0": gt_f0.tolist(), "time": int(time.time())}
                    f0_dict[f"{md5}_coarse"] = {"f0": coarse_f0.tolist(), "time": int(time.time())}
                    write_temp("./infer_tools/f0_temp.json", f0_dict)
            else:
                gt_f0, coarse_f0 = get_pitch_parselmouth(wav, mel, hparams)
            processed_input['f0'] = gt_f0
            processed_input['pitch'] = coarse_f0

        def get_align(mel, phone_encoded):
            mel2ph = np.zeros([mel.shape[0]], int)
            start_frame = 0
            ph_durs = mel.shape[0] / phone_encoded.shape[0]
            if hparams['debug']:
                print(mel.shape, phone_encoded.shape, mel.shape[0] / phone_encoded.shape[0])
            for i_ph in range(phone_encoded.shape[0]):
                end_frame = int(i_ph * ph_durs + ph_durs + 0.5)
                mel2ph[start_frame:end_frame + 1] = i_ph + 1
                start_frame = end_frame + 1

            processed_input['mel2ph'] = mel2ph

        if hparams['vocoder'] in VOCODERS:
            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(temp_dict['wav_fn'])
        else:
            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(temp_dict['wav_fn'])
        processed_input = {
            'item_name': item_name, 'mel': mel,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]
        }
        processed_input = {**temp_dict, **processed_input}  # merge two dicts

        if binarization_args['with_f0']:
            get_pitch(wav, mel)
        if binarization_args['with_hubert']:
            st = time.time()
            hubert_encoded = processed_input['hubert'] = self.hubert.encode(temp_dict['wav_fn'])
            et = time.time()
            dev = 'cuda' if hparams['hubert_gpu'] and torch.cuda.is_available() else 'cpu'
            print(f'hubert (on {dev}) time used {et - st}')

        if binarization_args['with_align']:
            get_align(mel, hubert_encoded)
        return processed_input

    def pre(self, wav_fn, accelerate, use_crepe=True, thre=0.05):
        if isinstance(wav_fn, BytesIO):
            item_name = self.project_name
        else:
            song_info = wav_fn.split('/')
            item_name = song_info[-1].split('.')[-2]
        temp_dict = {'wav_fn': wav_fn, 'spk_id': self.project_name}

        temp_dict = self.temporary_dict2processed_input(item_name, temp_dict, use_crepe, thre)
        hparams['pndm_speedup'] = accelerate
        batch = processed_input2batch([getitem(temp_dict)])
        return batch


def getitem(item):
    max_frames = hparams['max_frames']
    spec = torch.Tensor(item['mel'])[:max_frames]
    energy = (spec.exp() ** 2).sum(-1).sqrt()
    mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
    f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
    hubert = torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
    pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
    sample = {
        "item_name": item['item_name'],
        "hubert": hubert,
        "mel": spec,
        "pitch": pitch,
        "energy": energy,
        "f0": f0,
        "uv": uv,
        "mel2ph": mel2ph,
        "mel_nonpadding": spec.abs().sum(-1) > 0,
    }
    return sample


def processed_input2batch(samples):
    '''
        Args:
            samples: one batch of processed_input
        NOTE:
            the batch size is controlled by hparams['max_sentences']
    '''
    if len(samples) == 0:
        return {}
    item_names = [s['item_name'] for s in samples]
    hubert = utils.collate_2d([s['hubert'] for s in samples], 0.0)
    f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
    pitch = utils.collate_1d([s['pitch'] for s in samples])
    uv = utils.collate_1d([s['uv'] for s in samples])
    energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
    mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
        if samples[0]['mel2ph'] is not None else None
    mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
    mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])

    batch = {
        'item_name': item_names,
        'nsamples': len(samples),
        'hubert': hubert,
        'mels': mels,
        'mel_lengths': mel_lengths,
        'mel2ph': mel2ph,
        'energy': energy,
        'pitch': pitch,
        'f0': f0,
        'uv': uv,
    }
    return batch
infer_tools/slicer.py
ADDED
@@ -0,0 +1,158 @@
import time

import numpy as np
import torch
import torchaudio
from scipy.ndimage import maximum_filter1d, uniform_filter1d


def timeit(func):
    def run(*args, **kwargs):
        t = time.time()
        res = func(*args, **kwargs)
        print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t))
        return res

    return run


# @timeit
def _window_maximum(arr, win_sz):
    return maximum_filter1d(arr, size=win_sz)[win_sz // 2: win_sz // 2 + arr.shape[0] - win_sz + 1]


# @timeit
def _window_rms(arr, win_sz):
    filtered = np.sqrt(uniform_filter1d(np.power(arr, 2), win_sz) - np.power(uniform_filter1d(arr, win_sz), 2))
    return filtered[win_sz // 2: win_sz // 2 + arr.shape[0] - win_sz + 1]


def level2db(levels, eps=1e-12):
    return 20 * np.log10(np.clip(levels, a_min=eps, a_max=1))


def _apply_slice(audio, begin, end):
    if len(audio.shape) > 1:
        return audio[:, begin: end]
    else:
        return audio[begin: end]


class Slicer:
    def __init__(self,
                 sr: int,
                 db_threshold: float = -40,
                 min_length: int = 5000,
                 win_l: int = 300,
                 win_s: int = 20,
                 max_silence_kept: int = 500):
        self.db_threshold = db_threshold
        self.min_samples = round(sr * min_length / 1000)
        self.win_ln = round(sr * win_l / 1000)
        self.win_sn = round(sr * win_s / 1000)
        self.max_silence = round(sr * max_silence_kept / 1000)
        if not self.min_samples >= self.win_ln >= self.win_sn:
            raise ValueError('The following condition must be satisfied: min_length >= win_l >= win_s')
        if not self.max_silence >= self.win_sn:
            raise ValueError('The following condition must be satisfied: max_silence_kept >= win_s')

    @timeit
    def slice(self, audio):
        samples = audio
        if samples.shape[0] <= self.min_samples:
            return {"0": {"slice": False, "split_time": f"0,{len(audio)}"}}
        # get absolute amplitudes
        abs_amp = np.abs(samples - np.mean(samples))
        # calculate local maximum with large window
        win_max_db = level2db(_window_maximum(abs_amp, win_sz=self.win_ln))
        sil_tags = []
        left = right = 0
        while right < win_max_db.shape[0]:
            if win_max_db[right] < self.db_threshold:
                right += 1
            elif left == right:
                left += 1
                right += 1
            else:
                if left == 0:
                    split_loc_l = left
                else:
                    sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
                    rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
                    split_win_l = left + np.argmin(rms_db_left)
                    split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
                if len(sil_tags) != 0 and split_loc_l - sil_tags[-1][1] < self.min_samples \
                        and right < win_max_db.shape[0] - 1:
                    right += 1
                    left = right
                    continue
                if right == win_max_db.shape[0] - 1:
                    split_loc_r = right + self.win_ln
                else:
                    sil_right_n = min(self.max_silence, (right + self.win_ln - left) // 2)
                    rms_db_right = level2db(_window_rms(samples[right + self.win_ln - sil_right_n: right + self.win_ln],
                                                        win_sz=self.win_sn))
                    split_win_r = right + self.win_ln - sil_right_n + np.argmin(rms_db_right)
                    split_loc_r = split_win_r + np.argmin(abs_amp[split_win_r: split_win_r + self.win_sn])
                sil_tags.append((split_loc_l, split_loc_r))
                right += 1
                left = right
        if left != right:
            sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
            rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
            split_win_l = left + np.argmin(rms_db_left)
            split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
            sil_tags.append((split_loc_l, samples.shape[0]))
        if len(sil_tags) == 0:
            return {"0": {"slice": False, "split_time": f"0,{len(audio)}"}}
        else:
            chunks = []
            # the first silent section does not start at the beginning, so prepend the voiced segment
            if sil_tags[0][0]:
                chunks.append({"slice": False, "split_time": f"0,{sil_tags[0][0]}"})
            for i in range(0, len(sil_tags)):
                # mark the voiced segments (skip the first one)
                if i:
                    chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1]},{sil_tags[i][0]}"})
                # mark all the silent segments
                chunks.append({"slice": True, "split_time": f"{sil_tags[i][0]},{sil_tags[i][1]}"})
            # the last silent section does not reach the end, so append the trailing voiced segment
            if sil_tags[-1][1] != len(audio):
                chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1]},{len(audio)}"})
            chunk_dict = {}
            for i in range(len(chunks)):
                chunk_dict[str(i)] = chunks[i]
            return chunk_dict


def cut(audio_path, db_thresh=-30, min_len=5000, win_l=300, win_s=20, max_sil_kept=500):
    audio, sr = torchaudio.load(audio_path)
    if len(audio.shape) == 2 and audio.shape[1] >= 2:
        audio = torch.mean(audio, dim=0).unsqueeze(0)
    audio = audio.cpu().numpy()[0]

    slicer = Slicer(
        sr=sr,
        db_threshold=db_thresh,
        min_length=min_len,
        win_l=win_l,
        win_s=win_s,
        max_silence_kept=max_sil_kept
    )
    chunks = slicer.slice(audio)
    return chunks


def chunks2audio(audio_path, chunks):
    chunks = dict(chunks)
    audio, sr = torchaudio.load(audio_path)
    if len(audio.shape) == 2 and audio.shape[1] >= 2:
        audio = torch.mean(audio, dim=0).unsqueeze(0)
    audio = audio.cpu().numpy()[0]
    result = []
    for k, v in chunks.items():
        tag = v["split_time"].split(",")
        result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
    return result, sr
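For reference, a minimal usage sketch of the slicer with a hypothetical input path: cut returns a dict mapping chunk indices to {"slice": bool, "split_time": "begin,end"} records, where slice is True for silent stretches, and chunks2audio turns those records back into (is_silence, samples) pairs:

import soundfile

from infer_tools import slicer

# Hypothetical input file; silence is anything below the -40 dB threshold.
chunks = slicer.cut("./raw/input.wav", db_thresh=-40)
segments, sr = slicer.chunks2audio("./raw/input.wav", chunks)
for i, (is_silence, data) in enumerate(segments):
    if not is_silence:
        # keep only the voiced segments
        soundfile.write(f"./raw/input_part{i}.wav", data, sr)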
modules/commons/common_layers.py
ADDED
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import Parameter
|
5 |
+
import torch.onnx.operators
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import utils
|
8 |
+
|
9 |
+
|
10 |
+
class Reshape(nn.Module):
|
11 |
+
def __init__(self, *args):
|
12 |
+
super(Reshape, self).__init__()
|
13 |
+
self.shape = args
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
return x.view(self.shape)
|
17 |
+
|
18 |
+
|
19 |
+
class Permute(nn.Module):
|
20 |
+
def __init__(self, *args):
|
21 |
+
super(Permute, self).__init__()
|
22 |
+
self.args = args
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return x.permute(self.args)
|
26 |
+
|
27 |
+
|
28 |
+
class LinearNorm(torch.nn.Module):
|
29 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
30 |
+
super(LinearNorm, self).__init__()
|
31 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
32 |
+
|
33 |
+
torch.nn.init.xavier_uniform_(
|
34 |
+
self.linear_layer.weight,
|
35 |
+
gain=torch.nn.init.calculate_gain(w_init_gain))
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.linear_layer(x)
|
39 |
+
|
40 |
+
|
41 |
+
class ConvNorm(torch.nn.Module):
|
42 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
43 |
+
padding=None, dilation=1, bias=True, w_init_gain='linear'):
|
44 |
+
super(ConvNorm, self).__init__()
|
45 |
+
if padding is None:
|
46 |
+
assert (kernel_size % 2 == 1)
|
47 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
48 |
+
|
49 |
+
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
50 |
+
kernel_size=kernel_size, stride=stride,
|
51 |
+
padding=padding, dilation=dilation,
|
52 |
+
bias=bias)
|
53 |
+
|
54 |
+
torch.nn.init.xavier_uniform_(
|
55 |
+
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
56 |
+
|
57 |
+
def forward(self, signal):
|
58 |
+
conv_signal = self.conv(signal)
|
59 |
+
return conv_signal
|
60 |
+
|
61 |
+
|
62 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
|
63 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
64 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
65 |
+
if padding_idx is not None:
|
66 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
67 |
+
return m
|
68 |
+
|
69 |
+
|
70 |
+
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
|
71 |
+
if not export and torch.cuda.is_available():
|
72 |
+
try:
|
73 |
+
from apex.normalization import FusedLayerNorm
|
74 |
+
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
75 |
+
except ImportError:
|
76 |
+
pass
|
77 |
+
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
78 |
+
|
79 |
+
|
80 |
+
def Linear(in_features, out_features, bias=True):
|
81 |
+
m = nn.Linear(in_features, out_features, bias)
|
82 |
+
nn.init.xavier_uniform_(m.weight)
|
83 |
+
if bias:
|
84 |
+
nn.init.constant_(m.bias, 0.)
|
85 |
+
return m
|
86 |
+
|
87 |
+
|
88 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
89 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
90 |
+
|
91 |
+
Padding symbols are ignored.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
95 |
+
super().__init__()
|
96 |
+
self.embedding_dim = embedding_dim
|
97 |
+
self.padding_idx = padding_idx
|
98 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
99 |
+
init_size,
|
100 |
+
embedding_dim,
|
101 |
+
padding_idx,
|
102 |
+
)
|
103 |
+
self.register_buffer('_float_tensor', torch.FloatTensor(1))
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
107 |
+
"""Build sinusoidal embeddings.
|
108 |
+
|
109 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
110 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
111 |
+
"""
|
112 |
+
half_dim = embedding_dim // 2
|
113 |
+
emb = math.log(10000) / (half_dim - 1)
|
114 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
115 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
116 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
117 |
+
if embedding_dim % 2 == 1:
|
118 |
+
# zero pad
|
119 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
120 |
+
if padding_idx is not None:
|
121 |
+
emb[padding_idx, :] = 0
|
122 |
+
return emb
|
123 |
+
|
124 |
+
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
|
125 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
126 |
+
bsz, seq_len = input.shape[:2]
|
127 |
+
max_pos = self.padding_idx + 1 + seq_len
|
128 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
129 |
+
# recompute/expand embeddings if needed
|
130 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
131 |
+
max_pos,
|
132 |
+
self.embedding_dim,
|
133 |
+
self.padding_idx,
|
134 |
+
)
|
135 |
+
self.weights = self.weights.to(self._float_tensor)
|
136 |
+
|
137 |
+
if incremental_state is not None:
|
138 |
+
# positions is the same for every token when decoding a single step
|
139 |
+
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
140 |
+
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
141 |
+
|
142 |
+
positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
|
143 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
144 |
+
|
145 |
+
def max_positions(self):
|
146 |
+
"""Maximum number of supported positions."""
|
147 |
+
return int(1e5) # an arbitrary large number
|
148 |
+
|
149 |
+
|
150 |
+
class ConvTBC(nn.Module):
|
151 |
+
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
|
152 |
+
super(ConvTBC, self).__init__()
|
153 |
+
self.in_channels = in_channels
|
154 |
+
self.out_channels = out_channels
|
155 |
+
self.kernel_size = kernel_size
|
156 |
+
self.padding = padding
|
157 |
+
|
158 |
+
self.weight = torch.nn.Parameter(torch.Tensor(
|
159 |
+
self.kernel_size, in_channels, out_channels))
|
160 |
+
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
|
161 |
+
|
162 |
+
def forward(self, input):
|
163 |
+
return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
|
164 |
+
|
165 |
+
|
166 |
+
class MultiheadAttention(nn.Module):
|
167 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
|
168 |
+
add_bias_kv=False, add_zero_attn=False, self_attention=False,
|
169 |
+
encoder_decoder_attention=False):
|
170 |
+
super().__init__()
|
171 |
+
self.embed_dim = embed_dim
|
172 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
173 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
174 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
175 |
+
|
176 |
+
self.num_heads = num_heads
|
177 |
+
self.dropout = dropout
|
178 |
+
self.head_dim = embed_dim // num_heads
|
179 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
180 |
+
self.scaling = self.head_dim ** -0.5
|
181 |
+
|
182 |
+
self.self_attention = self_attention
|
183 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
184 |
+
|
185 |
+
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
|
186 |
+
'value to be of the same size'
|
187 |
+
|
188 |
+
if self.qkv_same_dim:
|
189 |
+
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
|
190 |
+
else:
|
191 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
192 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
193 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
194 |
+
|
195 |
+
if bias:
|
196 |
+
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
|
197 |
+
else:
|
198 |
+
self.register_parameter('in_proj_bias', None)
|
199 |
+
|
200 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
201 |
+
|
202 |
+
if add_bias_kv:
|
203 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
204 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
205 |
+
else:
|
206 |
+
self.bias_k = self.bias_v = None
|
207 |
+
|
208 |
+
self.add_zero_attn = add_zero_attn
|
209 |
+
|
210 |
+
self.reset_parameters()
|
211 |
+
|
212 |
+
self.enable_torch_version = False
|
213 |
+
if hasattr(F, "multi_head_attention_forward"):
|
214 |
+
self.enable_torch_version = True
|
215 |
+
else:
|
216 |
+
self.enable_torch_version = False
|
217 |
+
self.last_attn_probs = None
|
218 |
+
|
219 |
+
def reset_parameters(self):
|
220 |
+
if self.qkv_same_dim:
|
221 |
+
nn.init.xavier_uniform_(self.in_proj_weight)
|
222 |
+
else:
|
223 |
+
nn.init.xavier_uniform_(self.k_proj_weight)
|
224 |
+
nn.init.xavier_uniform_(self.v_proj_weight)
|
225 |
+
nn.init.xavier_uniform_(self.q_proj_weight)
|
226 |
+
|
227 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
228 |
+
if self.in_proj_bias is not None:
|
229 |
+
nn.init.constant_(self.in_proj_bias, 0.)
|
230 |
+
nn.init.constant_(self.out_proj.bias, 0.)
|
231 |
+
if self.bias_k is not None:
|
232 |
+
nn.init.xavier_normal_(self.bias_k)
|
233 |
+
if self.bias_v is not None:
|
234 |
+
nn.init.xavier_normal_(self.bias_v)
|
235 |
+
|
236 |
+
def forward(
|
237 |
+
self,
|
238 |
+
query, key, value,
|
239 |
+
key_padding_mask=None,
|
240 |
+
incremental_state=None,
|
241 |
+
need_weights=True,
|
242 |
+
static_kv=False,
|
243 |
+
attn_mask=None,
|
244 |
+
before_softmax=False,
|
245 |
+
need_head_weights=False,
|
246 |
+
enc_dec_attn_constraint_mask=None,
|
247 |
+
reset_attn_weight=None
|
248 |
+
):
|
249 |
+
"""Input shape: Time x Batch x Channel
|
250 |
+
|
251 |
+
Args:
|
252 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
253 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
254 |
+
padding elements are indicated by 1s.
|
255 |
+
need_weights (bool, optional): return the attention weights,
|
256 |
+
averaged over heads (default: False).
|
257 |
+
attn_mask (ByteTensor, optional): typically used to
|
258 |
+
implement causal attention, where the mask prevents the
|
259 |
+
attention from looking forward in time (default: None).
|
260 |
+
before_softmax (bool, optional): return the raw attention
|
261 |
+
weights and values before the attention softmax.
|
262 |
+
need_head_weights (bool, optional): return the attention
|
263 |
+
weights for each head. Implies *need_weights*. Default:
|
264 |
+
return the average attention weights over all heads.
|
265 |
+
"""
|
266 |
+
if need_head_weights:
|
267 |
+
need_weights = True
|
268 |
+
|
269 |
+
tgt_len, bsz, embed_dim = query.size()
|
270 |
+
assert embed_dim == self.embed_dim
|
271 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
272 |
+
|
273 |
+
if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
|
274 |
+
if self.qkv_same_dim:
|
275 |
+
return F.multi_head_attention_forward(query, key, value,
|
276 |
+
self.embed_dim, self.num_heads,
|
277 |
+
self.in_proj_weight,
|
278 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
279 |
+
self.add_zero_attn, self.dropout,
|
280 |
+
self.out_proj.weight, self.out_proj.bias,
|
281 |
+
self.training, key_padding_mask, need_weights,
|
282 |
+
attn_mask)
|
283 |
+
else:
|
284 |
+
return F.multi_head_attention_forward(query, key, value,
|
285 |
+
self.embed_dim, self.num_heads,
|
286 |
+
torch.empty([0]),
|
287 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
288 |
+
self.add_zero_attn, self.dropout,
|
289 |
+
self.out_proj.weight, self.out_proj.bias,
|
290 |
+
self.training, key_padding_mask, need_weights,
|
291 |
+
attn_mask, use_separate_proj_weight=True,
|
292 |
+
q_proj_weight=self.q_proj_weight,
|
293 |
+
k_proj_weight=self.k_proj_weight,
|
294 |
+
v_proj_weight=self.v_proj_weight)
|
295 |
+
|
296 |
+
if incremental_state is not None:
|
297 |
+
print('Not implemented error.')
|
298 |
+
exit()
|
299 |
+
else:
|
300 |
+
saved_state = None
|
301 |
+
|
302 |
+
if self.self_attention:
|
303 |
+
# self-attention
|
304 |
+
q, k, v = self.in_proj_qkv(query)
|
305 |
+
elif self.encoder_decoder_attention:
|
306 |
+
# encoder-decoder attention
|
307 |
+
q = self.in_proj_q(query)
|
308 |
+
if key is None:
|
309 |
+
assert value is None
|
310 |
+
k = v = None
|
311 |
+
else:
|
312 |
+
k = self.in_proj_k(key)
|
313 |
+
v = self.in_proj_v(key)
|
314 |
+
|
315 |
+
else:
|
316 |
+
q = self.in_proj_q(query)
|
317 |
+
k = self.in_proj_k(key)
|
318 |
+
v = self.in_proj_v(value)
|
319 |
+
q *= self.scaling
|
320 |
+
|
321 |
+
if self.bias_k is not None:
|
322 |
+
assert self.bias_v is not None
|
323 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
324 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
325 |
+
if attn_mask is not None:
|
326 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
327 |
+
if key_padding_mask is not None:
|
328 |
+
key_padding_mask = torch.cat(
|
329 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
330 |
+
|
331 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
332 |
+
if k is not None:
|
333 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
334 |
+
if v is not None:
|
335 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
336 |
+
|
337 |
+
if saved_state is not None:
|
338 |
+
print('Not implemented error.')
|
339 |
+
exit()
|
340 |
+
|
341 |
+
src_len = k.size(1)
|
342 |
+
|
343 |
+
# This is part of a workaround to get around fork/join parallelism
|
344 |
+
# not supporting Optional types.
|
345 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
346 |
+
key_padding_mask = None
|
347 |
+
|
348 |
+
if key_padding_mask is not None:
|
349 |
+
assert key_padding_mask.size(0) == bsz
|
350 |
+
assert key_padding_mask.size(1) == src_len
|
351 |
+
|
352 |
+
if self.add_zero_attn:
|
353 |
+
src_len += 1
|
354 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
355 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
356 |
+
if attn_mask is not None:
|
357 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
358 |
+
if key_padding_mask is not None:
|
359 |
+
key_padding_mask = torch.cat(
|
360 |
+
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
|
361 |
+
|
362 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
363 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
364 |
+
|
365 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
366 |
+
|
367 |
+
if attn_mask is not None:
|
368 |
+
if len(attn_mask.shape) == 2:
|
369 |
+
attn_mask = attn_mask.unsqueeze(0)
|
370 |
+
elif len(attn_mask.shape) == 3:
|
371 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
372 |
+
bsz * self.num_heads, tgt_len, src_len)
|
373 |
+
attn_weights = attn_weights + attn_mask
|
374 |
+
|
375 |
+
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
|
376 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
377 |
+
attn_weights = attn_weights.masked_fill(
|
378 |
+
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
|
379 |
+
-1e9,
|
380 |
+
)
|
381 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
382 |
+
|
383 |
+
if key_padding_mask is not None:
|
384 |
+
# don't attend to padding symbols
|
385 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
386 |
+
attn_weights = attn_weights.masked_fill(
|
387 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
388 |
+
-1e9,
|
389 |
+
)
|
390 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
391 |
+
|
392 |
+
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
393 |
+
|
394 |
+
if before_softmax:
|
395 |
+
return attn_weights, v
|
396 |
+
|
397 |
+
attn_weights_float = utils.softmax(attn_weights, dim=-1)
|
398 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
399 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
400 |
+
|
401 |
+
if reset_attn_weight is not None:
|
402 |
+
if reset_attn_weight:
|
403 |
+
self.last_attn_probs = attn_probs.detach()
|
404 |
+
else:
|
405 |
+
assert self.last_attn_probs is not None
|
406 |
+
attn_probs = self.last_attn_probs
|
407 |
+
attn = torch.bmm(attn_probs, v)
|
408 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
409 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
410 |
+
attn = self.out_proj(attn)
|
411 |
+
|
412 |
+
if need_weights:
|
413 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
414 |
+
if not need_head_weights:
|
415 |
+
# average attention weights over heads
|
416 |
+
attn_weights = attn_weights.mean(dim=0)
|
417 |
+
else:
|
418 |
+
attn_weights = None
|
419 |
+
|
420 |
+
return attn, (attn_weights, attn_logits)
|
421 |
+
|
422 |
+
def in_proj_qkv(self, query):
|
423 |
+
return self._in_proj(query).chunk(3, dim=-1)
|
424 |
+
|
425 |
+
def in_proj_q(self, query):
|
426 |
+
if self.qkv_same_dim:
|
427 |
+
return self._in_proj(query, end=self.embed_dim)
|
428 |
+
else:
|
429 |
+
bias = self.in_proj_bias
|
430 |
+
if bias is not None:
|
431 |
+
bias = bias[:self.embed_dim]
|
432 |
+
return F.linear(query, self.q_proj_weight, bias)
|
433 |
+
|
434 |
+
def in_proj_k(self, key):
|
435 |
+
if self.qkv_same_dim:
|
436 |
+
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
|
437 |
+
else:
|
438 |
+
weight = self.k_proj_weight
|
439 |
+
bias = self.in_proj_bias
|
440 |
+
if bias is not None:
|
441 |
+
bias = bias[self.embed_dim:2 * self.embed_dim]
|
442 |
+
return F.linear(key, weight, bias)
|
443 |
+
|
444 |
+
def in_proj_v(self, value):
|
445 |
+
if self.qkv_same_dim:
|
446 |
+
return self._in_proj(value, start=2 * self.embed_dim)
|
447 |
+
else:
|
448 |
+
weight = self.v_proj_weight
|
449 |
+
bias = self.in_proj_bias
|
450 |
+
if bias is not None:
|
451 |
+
bias = bias[2 * self.embed_dim:]
|
452 |
+
return F.linear(value, weight, bias)
|
453 |
+
|
454 |
+
def _in_proj(self, input, start=0, end=None):
|
455 |
+
weight = self.in_proj_weight
|
456 |
+
bias = self.in_proj_bias
|
457 |
+
weight = weight[start:end, :]
|
458 |
+
if bias is not None:
|
459 |
+
bias = bias[start:end]
|
460 |
+
return F.linear(input, weight, bias)
|
461 |
+
|
462 |
+
|
463 |
+
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
|
464 |
+
return attn_weights
|
465 |
+
|
466 |
+
|
class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class CustomSwish(nn.Module):
    def forward(self, input_tensor):
        return Swish.apply(input_tensor)


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class TransformerFFNLayer(nn.Module):
    def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
        super().__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.act = act
        if padding == 'SAME':
            self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
        elif padding == 'LEFT':
            self.ffn_1 = nn.Sequential(
                nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
                nn.Conv1d(hidden_size, filter_size, kernel_size)
            )
        self.ffn_2 = Linear(filter_size, hidden_size)
        if self.act == 'swish':
            self.swish_fn = CustomSwish()

    def forward(self, x, incremental_state=None):
        # x: T x B x C
        if incremental_state is not None:
            assert incremental_state is None, 'Nar-generation does not allow this.'
            exit(1)

        x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
        x = x * self.kernel_size ** -0.5

        if incremental_state is not None:
            x = x[-1:]
        if self.act == 'gelu':
            x = F.gelu(x)
        if self.act == 'relu':
            x = F.relu(x)
        if self.act == 'swish':
            x = self.swish_fn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.ffn_2(x)
        return x


class BatchNorm1dTBC(nn.Module):
    def __init__(self, c):
        super(BatchNorm1dTBC, self).__init__()
        self.bn = nn.BatchNorm1d(c)

    def forward(self, x):
        """

        :param x: [T, B, C]
        :return: [T, B, C]
        """
        x = x.permute(1, 2, 0)  # [B, C, T]
        x = self.bn(x)  # [B, C, T]
        x = x.permute(2, 0, 1)  # [T, B, C]
        return x


class EncSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
                 relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.num_heads = num_heads
        if num_heads > 0:
            if norm == 'ln':
                self.layer_norm1 = LayerNorm(c)
            elif norm == 'bn':
                self.layer_norm1 = BatchNorm1dTBC(c)
            self.self_attn = MultiheadAttention(
                self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
            )
        if norm == 'ln':
            self.layer_norm2 = LayerNorm(c)
        elif norm == 'bn':
            self.layer_norm2 = BatchNorm1dTBC(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)

    def forward(self, x, encoder_padding_mask=None, **kwargs):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
        if self.num_heads > 0:
            residual = x
            x = self.layer_norm1(x)
            x, _ = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask
            )
            x = F.dropout(x, self.dropout, training=self.training)
            x = residual + x
            x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]

        residual = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
        return x


class DecSALayer(nn.Module):
    def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
        super().__init__()
        self.c = c
        self.dropout = dropout
        self.layer_norm1 = LayerNorm(c)
        self.self_attn = MultiheadAttention(
            c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
        )
        self.layer_norm2 = LayerNorm(c)
        self.encoder_attn = MultiheadAttention(
            c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
        )
        self.layer_norm3 = LayerNorm(c)
        self.ffn = TransformerFFNLayer(
            c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)

    def forward(
            self,
            x,
            encoder_out=None,
            encoder_padding_mask=None,
            incremental_state=None,
            self_attn_mask=None,
            self_attn_padding_mask=None,
            attn_out=None,
            reset_attn_weight=None,
            **kwargs,
    ):
        layer_norm_training = kwargs.get('layer_norm_training', None)
        if layer_norm_training is not None:
            self.layer_norm1.training = layer_norm_training
            self.layer_norm2.training = layer_norm_training
            self.layer_norm3.training = layer_norm_training
        residual = x
        x = self.layer_norm1(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask
        )
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm2(x)
        if encoder_out is not None:
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                enc_dec_attn_constraint_mask=None,  # utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
                reset_attn_weight=reset_attn_weight
            )
            attn_logits = attn[1]
        else:
            assert attn_out is not None
            x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
            attn_logits = None
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x

        residual = x
        x = self.layer_norm3(x)
        x = self.ffn(x, incremental_state=incremental_state)
        x = F.dropout(x, self.dropout, training=self.training)
        x = residual + x
        # if len(attn_logits.size()) > 3:
        #     indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
        #     attn_logits = attn_logits.gather(1,
        #         indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
        return x, attn_logits
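[Editor's illustration, not part of the commit] A quick sanity check that the closed-form derivative in Swish.backward above matches autograd's gradient of x * sigmoid(x):

import torch

x = torch.randn(8, requires_grad=True)
(x * torch.sigmoid(x)).sum().backward()  # autograd gradient of swish
s = torch.sigmoid(x)
manual = s * (1 + x * (1 - s))           # the expression used in Swish.backward
assert torch.allclose(x.grad, manual.detach(), atol=1e-6)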
modules/commons/espnet_positional_embedding.py
ADDED
@@ -0,0 +1,113 @@
import math
import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
        reverse (bool): Whether to reverse the input position.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.
    See Sec. 3.2 of https://arxiv.org/abs/1809.08895
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset parameters."""
        self.alpha.data = torch.tensor(1.0)

    def forward(self, x):
        """Add positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See: Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self, x):
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.pe[:, : x.size(1)]
        return self.dropout(x) + self.dropout(pos_emb)
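[Editor's illustration, not part of the commit] A minimal usage sketch for the module above; the import path assumes this repository layout, and the shapes are arbitrary:

import torch
from modules.commons.espnet_positional_embedding import RelPositionalEncoding

pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(2, 100, 256)  # (batch, time, d_model)
out = pos_enc(x)              # scaled input plus the (broadcast) positional table
assert out.shape == x.shape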
modules/commons/ssim.py
ADDED
@@ -0,0 +1,391 @@
# '''
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
# '''
#
# import torch
# import torch.jit
# import torch.nn.functional as F
#
#
# @torch.jit.script
# def create_window(window_size: int, sigma: float, channel: int):
#     '''
#     Create 1-D gauss kernel
#     :param window_size: the size of gauss kernel
#     :param sigma: sigma of normal distribution
#     :param channel: input channel
#     :return: 1D kernel
#     '''
#     coords = torch.arange(window_size, dtype=torch.float)
#     coords -= window_size // 2
#
#     g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
#     g /= g.sum()
#
#     g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
#     return g
#
#
# @torch.jit.script
# def _gaussian_filter(x, window_1d, use_padding: bool):
#     '''
#     Blur input with 1-D kernel
#     :param x: batch of tensors to be blurred
#     :param window_1d: 1-D gauss kernel
#     :param use_padding: padding image before conv
#     :return: blurred tensors
#     '''
#     C = x.shape[1]
#     padding = 0
#     if use_padding:
#         window_size = window_1d.shape[3]
#         padding = window_size // 2
#     out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
#     out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
#     return out
#
#
# @torch.jit.script
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
#     '''
#     Calculate ssim index for X and Y
#     :param X: images [B, C, H, N_bins]
#     :param Y: images [B, C, H, N_bins]
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param use_padding: padding image before conv
#     :return:
#     '''
#
#     K1 = 0.01
#     K2 = 0.03
#     compensation = 1.0
#
#     C1 = (K1 * data_range) ** 2
#     C2 = (K2 * data_range) ** 2
#
#     mu1 = _gaussian_filter(X, window, use_padding)
#     mu2 = _gaussian_filter(Y, window, use_padding)
#     sigma1_sq = _gaussian_filter(X * X, window, use_padding)
#     sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
#     sigma12 = _gaussian_filter(X * Y, window, use_padding)
#
#     mu1_sq = mu1.pow(2)
#     mu2_sq = mu2.pow(2)
#     mu1_mu2 = mu1 * mu2
#
#     sigma1_sq = compensation * (sigma1_sq - mu1_sq)
#     sigma2_sq = compensation * (sigma2_sq - mu2_sq)
#     sigma12 = compensation * (sigma12 - mu1_mu2)
#
#     cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
#     # Fixed the issue that a negative value of cs_map caused ms_ssim to output NaN.
#     cs_map = cs_map.clamp_min(0.)
#     ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
#
#     ssim_val = ssim_map.mean(dim=(1, 2, 3))  # reduce along CHW
#     cs = cs_map.mean(dim=(1, 2, 3))
#
#     return ssim_val, cs
#
#
# @torch.jit.script
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
#     '''
#     interface of ms-ssim
#     :param X: a batch of images, (N,C,H,W)
#     :param Y: a batch of images, (N,C,H,W)
#     :param window: 1-D gauss kernel
#     :param data_range: value range of input images. (usually 1.0 or 255)
#     :param weights: weights for different levels
#     :param use_padding: padding image before conv
#     :param eps: used to avoid NaN gradients.
#     '''
#     levels = weights.shape[0]
#     cs_vals = []
#     ssim_vals = []
#     for _ in range(levels):
#         ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
#         # Used to fix an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
#         ssim_val = ssim_val.clamp_min(eps)
#         cs = cs.clamp_min(eps)
#         cs_vals.append(cs)
#
#         ssim_vals.append(ssim_val)
#         padding = (X.shape[2] % 2, X.shape[3] % 2)
#         X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
#         Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
#
#     cs_vals = torch.stack(cs_vals, dim=0)
#     ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
#     return ms_ssim_val
#
#
# class SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
#         '''
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels (default: 3)
#         :param use_padding: padding image before conv
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#         self.data_range = data_range
#         self.use_padding = use_padding
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
#         return r[0]
#
#
# class MS_SSIM(torch.jit.ScriptModule):
#     __constants__ = ['data_range', 'use_padding', 'eps']
#
#     def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
#                  levels=None, eps=1e-8):
#         '''
#         class for ms-ssim
#         :param window_size: the size of gauss kernel
#         :param window_sigma: sigma of normal distribution
#         :param data_range: value range of input images. (usually 1.0 or 255)
#         :param channel: input channels
#         :param use_padding: padding image before conv
#         :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
#         :param levels: number of downsampling
#         :param eps: used to fix an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
#         '''
#         super().__init__()
#         assert window_size % 2 == 1, 'Window size must be odd.'
#         self.data_range = data_range
#         self.use_padding = use_padding
#         self.eps = eps
#
#         window = create_window(window_size, window_sigma, channel)
#         self.register_buffer('window', window)
#
#         if weights is None:
#             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
#         weights = torch.tensor(weights, dtype=torch.float)
#
#         if levels is not None:
#             weights = weights[:levels]
#             weights = weights / weights.sum()
#
#         self.register_buffer('weights', weights)
#
#     @torch.jit.script_method
#     def forward(self, X, Y):
#         return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
#                        use_padding=self.use_padding, eps=self.eps)
#
#
# if __name__ == '__main__':
#     print('Simple Test')
#     im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
#     img1 = im / 255
#     img2 = img1 * 0.5
#
#     losser = SSIM(data_range=1.).cuda()
#     loss = losser(img1, img2).mean()
#
#     losser2 = MS_SSIM(data_range=1.).cuda()
#     loss2 = losser2(img1, img2).mean()
#
#     print(loss.item())
#     print(loss2.item())
#
# if __name__ == '__main__':
#     print('Training Test')
#     import cv2
#     import torch.optim
#     import numpy as np
#     import imageio
#     import time
#
#     out_test_video = False
#     # It is best not to write a GIF directly (it gets very large); write an mkv first and convert it to GIF with ffmpeg.
#     video_use_gif = False
#
#     im = cv2.imread('test_img1.jpg', 1)
#     t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
#
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ssim_test' + suffix, fps=fps)
#
#     # Test SSIM
#     print('Training SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ssim', r_im)
#         cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()
#
#     # Test MS-SSIM
#     if out_test_video:
#         if video_use_gif:
#             fps = 0.5
#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
#             suffix = '.gif'
#         else:
#             fps = 5
#             out_wh = (im.shape[1], im.shape[0])
#             suffix = '.mkv'
#         video_last_time = time.perf_counter()
#         video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
#
#     print('Training MS_SSIM')
#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
#     rand_im.requires_grad = True
#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
#     losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
#     ssim_score = 0
#     while ssim_score < 0.999:
#         optim.zero_grad()
#         loss = losser(rand_im, t_im)
#         (-loss).sum().backward()
#         ssim_score = loss.item()
#         optim.step()
#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
#         r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
#
#         if out_test_video:
#             if time.perf_counter() - video_last_time > 1. / fps:
#                 video_last_time = time.perf_counter()
#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
#                 if isinstance(out_frame, cv2.UMat):
#                     out_frame = out_frame.get()
#                 video.append_data(out_frame)
#
#         cv2.imshow('ms_ssim', r_im)
#         cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
#         cv2.waitKey(1)
#
#     if out_test_video:
#         video.close()

"""
Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
"""

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


window = None


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    global window
    if window is None:
        window = create_window(window_size, channel)
        if img1.is_cuda:
            window = window.cuda(img1.get_device())
        window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
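[Editor's illustration, not part of the commit] A usage sketch for the functional entry point above; identical inputs give an SSIM of 1.0. The import path assumes this repository layout:

import torch
from modules.commons.ssim import ssim

a = torch.rand(1, 1, 80, 200)  # e.g. mel spectrograms laid out as [B, C, H, W]
b = a.clone()
print(round(ssim(a, b).item(), 4))  # -> 1.0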
modules/fastspeech/fs2.py
ADDED
@@ -0,0 +1,255 @@
from modules.commons.common_layers import *
from modules.commons.common_layers import Embedding
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
    EnergyPredictor, FastspeechEncoder
from utils.cwt import cwt2f0
from utils.hparams import hparams
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0

FS_ENCODERS = {
    'fft': lambda hp: FastspeechEncoder(
        hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads']),
}

FS_DECODERS = {
    'fft': lambda hp: FastspeechDecoder(
        hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
}


class FastSpeech2(nn.Module):
    def __init__(self, dictionary, out_dims=None):
        super().__init__()
        # self.dictionary = dictionary
        self.padding_idx = 0
        if not hparams.get('no_fs2', False):
            self.enc_layers = hparams['enc_layers']
            self.dec_layers = hparams['dec_layers']
            self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams)
            self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
        self.hidden_size = hparams['hidden_size']
        # self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
        self.out_dims = out_dims
        if out_dims is None:
            self.out_dims = hparams['audio_num_mel_bins']
        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
        # ========= not used =========
        # if hparams['use_spk_id']:
        #     self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        #     if hparams['use_split_spk_id']:
        #         self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        #         self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
        # elif hparams['use_spk_embed']:
        #     self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        # self.dur_predictor = DurationPredictor(
        #     self.hidden_size,
        #     n_chans=predictor_hidden,
        #     n_layers=hparams['dur_predictor_layers'],
        #     dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
        #     kernel_size=hparams['dur_predictor_kernel'])
        # self.length_regulator = LengthRegulator()
        if hparams['use_pitch_embed']:
            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
            if hparams['pitch_type'] == 'cwt':
                h = hparams['cwt_hidden_size']
                cwt_out_dims = 10
                if hparams['use_uv']:
                    cwt_out_dims = cwt_out_dims + 1
                self.cwt_predictor = nn.Sequential(
                    nn.Linear(self.hidden_size, h),
                    PitchPredictor(
                        h,
                        n_chans=predictor_hidden,
                        n_layers=hparams['predictor_layers'],
                        dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
                        padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
                self.cwt_stats_layers = nn.Sequential(
                    nn.Linear(self.hidden_size, h), nn.ReLU(),
                    nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
                )
            else:
                self.pitch_predictor = PitchPredictor(
                    self.hidden_size,
                    n_chans=predictor_hidden,
                    n_layers=hparams['predictor_layers'],
                    dropout_rate=hparams['predictor_dropout'],
                    odim=2 if hparams['pitch_type'] == 'frame' else 1,
                    padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
        if hparams['use_energy_embed']:
            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
            # self.energy_predictor = EnergyPredictor(
            #     self.hidden_size,
            #     n_chans=predictor_hidden,
            #     n_layers=hparams['predictor_layers'],
            #     dropout_rate=hparams['predictor_dropout'], odim=1,
            #     padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    # def build_embedding(self, dictionary, embed_dim):
    #     num_embeddings = len(dictionary)
    #     emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
    #     return emb

    def forward(self, hubert, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=True,
                spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
        ret = {}
        if not hparams.get('no_fs2', False):
            encoder_out = self.encoder(hubert)  # [B, T, C]
        else:
            encoder_out = hubert
        src_nonpadding = (hubert != 0).any(-1)[:, :, None]

        # add ref style embed
        # Not implemented
        # variance encoder
        var_embed = 0

        # encoder_out_dur denotes encoder outputs for the duration predictor;
        # in speech adaptation, the duration predictor uses the old speaker embedding
        if hparams['use_spk_embed']:
            spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        elif hparams['use_spk_id']:
            spk_embed_id = spk_embed
            if spk_embed_dur_id is None:
                spk_embed_dur_id = spk_embed_id
            if spk_embed_f0_id is None:
                spk_embed_f0_id = spk_embed_id
            spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
            spk_embed_dur = spk_embed_f0 = spk_embed
            if hparams['use_split_spk_id']:
                spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
                spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
        else:
            spk_embed_dur = spk_embed_f0 = spk_embed = 0

        # add dur
        # dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding

        # mel2ph = self.add_dur(dur_inp, mel2ph, hubert, ret)
        ret['mel2ph'] = mel2ph

        decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])

        mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
        decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]

        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]

        # add pitch and energy embed
        pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
        if not hparams.get('no_fs2', False):
            if skip_decoder:
                return ret
            ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        return ret

    def add_dur(self, dur_input, mel2ph, hubert, ret):
        src_padding = (hubert == 0).all(-1)
        dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
        if mel2ph is None:
            dur, xs = self.dur_predictor.inference(dur_input, src_padding)
            ret['dur'] = xs
            ret['dur_choice'] = dur
            mel2ph = self.length_regulator(dur, src_padding).detach()
        else:
            ret['dur'] = self.dur_predictor(dur_input, src_padding)
        ret['mel2ph'] = mel2ph
        return mel2ph

    def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
        x = decoder_inp  # [B, T, H]
        x = self.decoder(x)
        x = self.mel_out(x)
        return x * tgt_nonpadding

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
        f0 = torch.cat(
            [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
        f0_norm = norm_f0(f0, None, hparams)
        return f0_norm

    def out2mel(self, out):
        return out

    def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
        # if hparams['pitch_type'] == 'ph':
        #     pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
        #     pitch_padding = (encoder_out.sum().abs() == 0)
        #     ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        #     ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
        #     pitch = f0_to_coarse(f0_denorm)  # start from 0 [B, T_txt]
        #     pitch = F.pad(pitch, [1, 0])
        #     pitch = torch.gather(pitch, 1, mel2ph)  # [B, T_mel]
        #     pitch_embedding = pitch_embed(pitch)
        #     return pitch_embedding

        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())

        pitch_padding = (mel2ph == 0)

        # if hparams['pitch_type'] == 'cwt':
        #     # NOTE: this part of the script is *isolated* from the other scripts, which means
        #     # it may not be compatible with the current version.
        #     pass
        #     # pitch_padding = None
        #     # ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
        #     # stats_out = self.cwt_stats_layers(encoder_out[:, 0, :])  # [B, 2]
        #     # mean = ret['f0_mean'] = stats_out[:, 0]
        #     # std = ret['f0_std'] = stats_out[:, 1]
        #     # cwt_spec = cwt_out[:, :, :10]
        #     # if f0 is None:
        #     #     std = std * hparams['cwt_std_scale']
        #     #     f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
        #     # if hparams['use_uv']:
        #     #     assert cwt_out.shape[-1] == 11
        #     #     uv = cwt_out[:, :, -1] > 0
        # elif hparams['pitch_ar']:
        #     ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if is_training else None)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        # else:
        #     ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
        #     if f0 is None:
        #         f0 = pitch_pred[:, :, 0]
        #     if hparams['use_uv'] and uv is None:
        #         uv = pitch_pred[:, :, 1] > 0
        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
        if pitch_padding is not None:
            f0[pitch_padding] = 0

        pitch = f0_to_coarse(f0_denorm, hparams)  # start from 0
        ret['pitch_pred'] = pitch.unsqueeze(-1)
        # print(ret['pitch_pred'].shape)
        # print(pitch.shape)
        pitch_embedding = self.pitch_embed(pitch)
        return pitch_embedding

    def add_energy(self, decoder_inp, energy, ret):
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy  # energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        # if energy is None:
        #     energy = energy_pred
        energy = torch.clamp(energy * 256 // 4, max=255).long()  # energy_to_coarse
        energy_embedding = self.energy_embed(energy)
        return energy_embedding

    @staticmethod
    def mel_norm(x):
        return (x + 5.5) / (6.3 / 2) - 1

    @staticmethod
    def mel_denorm(x):
        return (x + 1) * (6.3 / 2) - 5.5
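[Editor's illustration, not part of the commit] The gather in FastSpeech2.forward above expands per-token encoder states to frame level via mel2ph. A self-contained demo of that indexing with toy shapes:

import torch
import torch.nn.functional as F

encoder_out = torch.randn(1, 3, 4)              # [B, T_txt, H]
mel2ph = torch.tensor([[1, 1, 2, 3, 3, 3]])     # frame -> 1-based token index, 0 = padding
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])  # prepend a zero row so index 0 selects padding
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
frames = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T_mel, H]
assert frames.shape == (1, 6, 4)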
modules/fastspeech/pe.py
ADDED
@@ -0,0 +1,149 @@
from modules.commons.common_layers import *
from utils.hparams import hparams
from modules.fastspeech.tts_modules import PitchPredictor
from utils.pitch_utils import denorm_f0


class Prenet(nn.Module):
    def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
        super(Prenet, self).__init__()
        padding = kernel // 2
        self.layers = []
        self.strides = strides if strides is not None else [1] * n_layers
        for l in range(n_layers):
            self.layers.append(nn.Sequential(
                nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
                nn.ReLU(),
                nn.BatchNorm1d(out_dim)
            ))
            in_dim = out_dim
        self.layers = nn.ModuleList(self.layers)
        self.out_proj = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """

        :param x: [B, T, 80]
        :return: [L, B, T, H], [B, T, H]
        """
        # padding_mask = x.abs().sum(-1).eq(0).data  # [B, T]
        padding_mask = x.abs().sum(-1).eq(0).detach()
        nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :]  # [B, 1, T]
        x = x.transpose(1, 2)
        hiddens = []
        for i, l in enumerate(self.layers):
            nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
            x = l(x) * nonpadding_mask_TB
            hiddens.append(x)
        hiddens = torch.stack(hiddens, 0)  # [L, B, H, T]
        hiddens = hiddens.transpose(2, 3)  # [L, B, T, H]
        x = self.out_proj(x.transpose(1, 2))  # [B, T, H]
        x = x * nonpadding_mask_TB.transpose(1, 2)
        return hiddens, x


class ConvBlock(nn.Module):
    def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
        super().__init__()
        self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
        self.norm = norm
        if self.norm == 'bn':
            self.norm = nn.BatchNorm1d(n_chans)
        elif self.norm == 'in':
            self.norm = nn.InstanceNorm1d(n_chans, affine=True)
        elif self.norm == 'gn':
            self.norm = nn.GroupNorm(n_chans // 16, n_chans)
        elif self.norm == 'ln':
            self.norm = LayerNorm(n_chans // 16, n_chans)
        elif self.norm == 'wn':
            self.conv = torch.nn.utils.weight_norm(self.conv.conv)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """

        :param x: [B, C, T]
        :return: [B, C, T]
        """
        x = self.conv(x)
        if not isinstance(self.norm, str):
            if self.norm == 'none':
                pass
            elif self.norm == 'ln':
                x = self.norm(x.transpose(1, 2)).transpose(1, 2)
            else:
                x = self.norm(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x


class ConvStacks(nn.Module):
    def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
                 dropout=0, strides=None, res=True):
        super().__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.res = res
        self.in_proj = Linear(idim, n_chans)
        if strides is None:
            strides = [1] * n_layers
        else:
            assert len(strides) == n_layers
        for idx in range(n_layers):
            self.conv.append(ConvBlock(
                n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
        self.out_proj = Linear(n_chans, odim)

    def forward(self, x, return_hiddens=False):
        """

        :param x: [B, T, H]
        :return: [B, T, H]
        """
        x = self.in_proj(x)
        x = x.transpose(1, -1)  # (B, idim, Tmax)
        hiddens = []
        for f in self.conv:
            x_ = f(x)
            x = x + x_ if self.res else x_  # (B, C, Tmax)
            hiddens.append(x)
        x = x.transpose(1, -1)
        x = self.out_proj(x)  # (B, Tmax, H)
        if return_hiddens:
            hiddens = torch.stack(hiddens, 1)  # [B, L, C, T]
            return x, hiddens
        return x


class PitchExtractor(nn.Module):
    def __init__(self, n_mel_bins=80, conv_layers=2):
        super().__init__()
        self.hidden_size = hparams['hidden_size']
        self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        self.conv_layers = conv_layers

        self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
        if self.conv_layers > 0:
            self.mel_encoder = ConvStacks(
                idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
        self.pitch_predictor = PitchPredictor(
            self.hidden_size, n_chans=self.predictor_hidden,
            n_layers=5, dropout_rate=0.1, odim=2,
            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

    def forward(self, mel_input=None):
        ret = {}
        mel_hidden = self.mel_prenet(mel_input)[1]
        if self.conv_layers > 0:
            mel_hidden = self.mel_encoder(mel_hidden)

        ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)

        pitch_padding = mel_input.abs().sum(-1) == 0
        use_uv = hparams['pitch_type'] == 'frame'  # and hparams['use_uv']
        ret['f0_denorm_pred'] = denorm_f0(
            pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
            hparams, pitch_padding=pitch_padding)
        return ret
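[Editor's illustration, not part of the commit] A usage sketch for the Prenet above; all-zero frames are treated as padding, and the import path assumes this repository layout:

import torch
from modules.fastspeech.pe import Prenet

prenet = Prenet(in_dim=80, out_dim=256)
mel = torch.randn(2, 100, 80)  # [B, T, n_mel_bins]
hiddens, out = prenet(mel)
assert out.shape == (2, 100, 256)
assert hiddens.shape == (3, 2, 100, 256)  # one hidden state per conv layer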
modules/fastspeech/tts_modules.py
ADDED
@@ -0,0 +1,364 @@
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from modules.commons.espnet_positional_embedding import RelPositionalEncoding
|
9 |
+
from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
|
10 |
+
from utils.hparams import hparams
|
11 |
+
|
12 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 2000
|
13 |
+
DEFAULT_MAX_TARGET_POSITIONS = 2000
|
14 |
+
|
15 |
+
|
16 |
+
class TransformerEncoderLayer(nn.Module):
|
17 |
+
def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
|
18 |
+
super().__init__()
|
19 |
+
self.hidden_size = hidden_size
|
20 |
+
self.dropout = dropout
|
21 |
+
self.num_heads = num_heads
|
22 |
+
self.op = EncSALayer(
|
23 |
+
hidden_size, num_heads, dropout=dropout,
|
24 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
25 |
+
kernel_size=kernel_size
|
26 |
+
if kernel_size is not None else hparams['enc_ffn_kernel_size'],
|
27 |
+
padding=hparams['ffn_padding'],
|
28 |
+
norm=norm, act=hparams['ffn_act'])
|
29 |
+
|
30 |
+
def forward(self, x, **kwargs):
|
31 |
+
return self.op(x, **kwargs)
|
32 |
+
|
33 |
+
|
34 |
+
######################
|
35 |
+
# fastspeech modules
|
36 |
+
######################
|
37 |
+
class LayerNorm(torch.nn.LayerNorm):
|
38 |
+
"""Layer normalization module.
|
39 |
+
:param int nout: output dim size
|
40 |
+
:param int dim: dimension to be normalized
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, nout, dim=-1):
|
44 |
+
"""Construct an LayerNorm object."""
|
45 |
+
super(LayerNorm, self).__init__(nout, eps=1e-12)
|
46 |
+
self.dim = dim
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""Apply layer normalization.
|
50 |
+
:param torch.Tensor x: input tensor
|
51 |
+
:return: layer normalized tensor
|
52 |
+
:rtype torch.Tensor
|
53 |
+
"""
|
54 |
+
if self.dim == -1:
|
55 |
+
return super(LayerNorm, self).forward(x)
|
56 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
57 |
+
|
58 |
+
|
59 |
+
class DurationPredictor(torch.nn.Module):
|
60 |
+
"""Duration predictor module.
|
61 |
+
This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
62 |
+
The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
|
63 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
64 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
65 |
+
Note:
|
66 |
+
The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
|
67 |
+
the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
|
71 |
+
"""Initilize duration predictor module.
|
72 |
+
Args:
|
73 |
+
idim (int): Input dimension.
|
74 |
+
n_layers (int, optional): Number of convolutional layers.
|
75 |
+
n_chans (int, optional): Number of channels of convolutional layers.
|
76 |
+
kernel_size (int, optional): Kernel size of convolutional layers.
|
77 |
+
dropout_rate (float, optional): Dropout rate.
|
78 |
+
offset (float, optional): Offset value to avoid nan in log domain.
|
79 |
+
"""
|
80 |
+
super(DurationPredictor, self).__init__()
|
81 |
+
self.offset = offset
|
82 |
+
self.conv = torch.nn.ModuleList()
|
83 |
+
self.kernel_size = kernel_size
|
84 |
+
self.padding = padding
|
85 |
+
for idx in range(n_layers):
|
86 |
+
in_chans = idim if idx == 0 else n_chans
|
87 |
+
self.conv += [torch.nn.Sequential(
|
88 |
+
torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
|
89 |
+
if padding == 'SAME'
|
90 |
+
else (kernel_size - 1, 0), 0),
|
91 |
+
torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
|
92 |
+
torch.nn.ReLU(),
|
93 |
+
LayerNorm(n_chans, dim=1),
|
94 |
+
torch.nn.Dropout(dropout_rate)
|
95 |
+
)]
|
96 |
+
if hparams['dur_loss'] in ['mse', 'huber']:
|
97 |
+
odims = 1
|
98 |
+
elif hparams['dur_loss'] == 'mog':
|
99 |
+
odims = 15
|
100 |
+
elif hparams['dur_loss'] == 'crf':
|
101 |
+
odims = 32
|
102 |
+
from torchcrf import CRF
|
103 |
+
self.crf = CRF(odims, batch_first=True)
|
104 |
+
self.linear = torch.nn.Linear(n_chans, odims)
|
105 |
+
|
106 |
+
def _forward(self, xs, x_masks=None, is_inference=False):
|
107 |
+
xs = xs.transpose(1, -1) # (B, idim, Tmax)
|
108 |
+
for f in self.conv:
|
109 |
+
xs = f(xs) # (B, C, Tmax)
|
110 |
+
if x_masks is not None:
|
111 |
+
xs = xs * (1 - x_masks.float())[:, None, :]
|
112 |
+
|
113 |
+
xs = self.linear(xs.transpose(1, -1)) # [B, T, C]
|
114 |
+
xs = xs * (1 - x_masks.float())[:, :, None] # (B, T, C)
|
115 |
+
if is_inference:
|
116 |
+
return self.out2dur(xs), xs
|
117 |
+
else:
|
118 |
+
if hparams['dur_loss'] in ['mse']:
|
119 |
+
xs = xs.squeeze(-1) # (B, Tmax)
|
120 |
+
return xs
|
121 |
+
|
122 |
+
def out2dur(self, xs):
|
123 |
+
if hparams['dur_loss'] in ['mse']:
|
124 |
+
# NOTE: calculate in log domain
|
125 |
+
xs = xs.squeeze(-1) # (B, Tmax)
|
126 |
+
dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
|
127 |
+
elif hparams['dur_loss'] == 'mog':
|
128 |
+
return NotImplementedError
|
129 |
+
elif hparams['dur_loss'] == 'crf':
|
130 |
+
dur = torch.LongTensor(self.crf.decode(xs)).cuda()
|
131 |
+
return dur
|
132 |
+
|
133 |
+
def forward(self, xs, x_masks=None):
|
134 |
+
"""Calculate forward propagation.
|
135 |
+
Args:
|
136 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
137 |
+
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
|
138 |
+
Returns:
|
139 |
+
Tensor: Batch of predicted durations in log domain (B, Tmax).
|
140 |
+
"""
|
141 |
+
return self._forward(xs, x_masks, False)
|
142 |
+
|
143 |
+
def inference(self, xs, x_masks=None):
|
144 |
+
"""Inference duration.
|
145 |
+
Args:
|
146 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
147 |
+
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
|
148 |
+
Returns:
|
149 |
+
LongTensor: Batch of predicted durations in linear domain (B, Tmax).
|
150 |
+
"""
|
151 |
+
return self._forward(xs, x_masks, True)
|
152 |
+
|
153 |
+
|
class LengthRegulator(torch.nn.Module):
    def __init__(self, pad_value=0.0):
        super(LengthRegulator, self).__init__()
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Example (no batch dim version):
            1. dur = [2, 2, 3]
            2. token_idx = [[1], [2], [3]], dur_cumsum = [2, 4, 7], dur_cumsum_prev = [0, 2, 4]
            3. token_mask = [[1, 1, 0, 0, 0, 0, 0],
                             [0, 0, 1, 1, 0, 0, 0],
                             [0, 0, 0, 0, 1, 1, 1]]
            4. token_idx * token_mask = [[1, 1, 0, 0, 0, 0, 0],
                                         [0, 0, 2, 2, 0, 0, 0],
                                         [0, 0, 0, 0, 3, 3, 3]]
            5. (token_idx * token_mask).sum(0) = [1, 1, 2, 2, 3, 3, 3]

        :param dur: Batch of durations of each frame (B, T_txt)
        :param dur_padding: Batch of padding of each frame (B, T_txt)
        :param alpha: duration rescale coefficient
        :return: mel2ph (B, T_speech)
        """
        assert alpha > 0
        dur = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            dur = dur * (1 - dur_padding.long())
        token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
        dur_cumsum = torch.cumsum(dur, 1)
        dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)

        pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
        token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
        mel2ph = (token_idx * token_mask.long()).sum(1)
        return mel2ph
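The docstring example above is directly runnable; here it is as a toy batch of one.

import torch
lr = LengthRegulator()
dur = torch.tensor([[2, 2, 3]])   # token 1 lasts 2 frames, token 2 lasts 2, token 3 lasts 3
mel2ph = lr(dur)                  # tensor([[1, 1, 2, 2, 3, 3, 3]])
# Each output frame stores the 1-based index of the token it belongs to; 0 would mean padding.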
class PitchPredictor(torch.nn.Module):
    def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
                 dropout_rate=0.1, padding='SAME'):
        """Initialize pitch predictor module.

        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
        """
        super(PitchPredictor, self).__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(
                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                                       if padding == 'SAME'
                                       else (kernel_size - 1, 0), 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate)
            )]
        self.linear = torch.nn.Linear(n_chans, odim)
        self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
        self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))

    def forward(self, xs):
        """
        :param xs: [B, T, H]
        :return: [B, T, odim]
        """
        positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
        xs = xs + positions
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)
        # NOTE: calculate in log domain
        xs = self.linear(xs.transpose(1, -1))  # (B, Tmax, odim)
        return xs


class EnergyPredictor(PitchPredictor):
    pass
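A sketch of the predictor's in/out contract. The 256-dim input is arbitrary, and reading the two output channels as an f0 value plus a voiced/unvoiced channel follows the usual FastSpeech2-style convention; both are assumptions here, not facts fixed by this hunk.

import torch
pp = PitchPredictor(idim=256, odim=2)
h = torch.randn(2, 100, 256)                  # [B, T_mel, H] frame-level hidden states
out = pp(h)                                   # [B, T_mel, 2]
f0_pred, uv_pred = out[..., 0], out[..., 1]   # assumed channel layout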
def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
    B, _ = mel2ph.shape
    dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
    dur = dur[:, 1:]
    if max_dur is not None:
        dur = dur.clamp(max=max_dur)
    return dur
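mel2ph_to_dur inverts LengthRegulator: counting how many frames map to each token recovers the durations, as this round-trip on the earlier toy values shows.

import torch
mel2ph = torch.tensor([[1, 1, 2, 2, 3, 3, 3]])
dur = mel2ph_to_dur(mel2ph, T_txt=3)   # tensor([[2, 2, 3]])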
class FFTBlocks(nn.Module):
    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
                 use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
        super().__init__()
        self.num_layers = num_layers
        embed_dim = self.hidden_size = hidden_size
        self.dropout = dropout if dropout is not None else hparams['dropout']
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
            for _ in range(self.num_layers)
        ])
        if self.use_last_norm:
            if norm == 'ln':
                self.layer_norm = nn.LayerNorm(embed_dim)
            elif norm == 'bn':
                self.layer_norm = BatchNorm1dTBC(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]
        :return: [B, T, C] or [L, B, T, C]
        """
        padding_mask = x.abs().sum(-1).eq(0).detach() if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]
        return x
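A hypothetical smoke test of the block stack; the sizes are arbitrary, and passing `dropout` explicitly avoids the default path that reads the global `hparams`.

import torch
blocks = FFTBlocks(hidden_size=256, num_layers=4, dropout=0.1)
x = torch.randn(2, 50, 256)           # [B, T, C]; all-zero timesteps are treated as padding
y = blocks(x)                         # [B, T, C]
hs = blocks(x, return_hiddens=True)   # [L, B, T, C], one slice per layer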
class FastspeechEncoder(FFTBlocks):
    """
    Compared to FFTBlocks:
    - input is [B, T, H], not [B, T, C]
    - supports "relative" positional encoding
    """

    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False)  # use_pos_embed_alpha for compatibility
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        if hparams.get('rel_pos') is not None and hparams['rel_pos']:
            self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
        else:
            self.embed_positions = SinusoidalPositionalEmbedding(
                hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

    def forward(self, hubert):
        """
        :param hubert: [B, T, H]
        :return: {
            'encoder_out': [T x B x C]
        }
        """
        encoder_padding_mask = (hubert == 0).all(-1)
        x = self.forward_embedding(hubert)  # [B, T, H]
        x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
        return x

    def forward_embedding(self, hubert):
        # embed tokens and positions
        x = self.embed_scale * hubert
        if hparams['use_pos_embed']:
            positions = self.embed_positions(hubert)
            x = x + positions
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x


class FastspeechDecoder(FFTBlocks):
    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
        num_heads = hparams['num_heads'] if num_heads is None else num_heads
        hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
        kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
        num_layers = hparams['dec_layers'] if num_layers is None else num_layers
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
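A sketch of the encoder contract, assuming `hparams` is already loaded (hidden_size, enc_ffn_kernel_size, dec_layers, use_pos_embed, dropout, ...). Note that the content features must match `hidden_size`, since `forward_embedding` adds the positional encoding directly onto them; the 256-dim size below is an assumption for illustration.

import torch
enc = FastspeechEncoder()              # sizes come from hparams
hubert = torch.randn(1, 200, 256)      # [B, T, H] content features; all-zero frames count as padding
cond = enc(hubert)                     # [B, T, hidden_size]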
modules/hifigan/hifigan.py
ADDED
@@ -0,0 +1,365 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
from modules.parallel_wavegan.models.source import SourceModuleHnNSF
import numpy as np

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels, out_channels, bias):
        """Initialize 1x1 Conv1d module."""
        super(Conv1d1x1, self).__init__(in_channels, out_channels,
                                        kernel_size=1, padding=0,
                                        dilation=1, bias=bias)


class HifiGanGenerator(torch.nn.Module):
    def __init__(self, h, c_out=1):
        super(HifiGanGenerator, self).__init__()
        self.h = h
        self.num_kernels = len(h['resblock_kernel_sizes'])
        self.num_upsamples = len(h['upsample_rates'])

        if h['use_pitch_embed']:
            self.harmonic_num = 8
            self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
            self.m_source = SourceModuleHnNSF(
                sampling_rate=h['audio_sample_rate'],
                harmonic_num=self.harmonic_num)
            self.noise_convs = nn.ModuleList()
        self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
        resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
            c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
            if h['use_pitch_embed']:
                if i + 1 < len(h['upsample_rates']):
                    stride_f0 = np.prod(h['upsample_rates'][i + 1:])
                    self.noise_convs.append(Conv1d(
                        1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
                else:
                    self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h['upsample_initial_channel'] // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, f0=None):
        if f0 is not None:
            # harmonic-source signal, noise-source signal, uv flag
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            if f0 is not None:
                x_source = self.noise_convs[i](har_source)
                x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
        super(DiscriminatorP, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            from utils.hparams import hparams
            t = hparams['hop_size']
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2

        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x, mel):
        fmap = []
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_cond=False, c_in=1):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
            DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
        ])

    def forward(self, y, y_hat, mel=None):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y, mel)
            y_d_g, fmap_g = d(y_hat, mel)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
        super(DiscriminatorS, self).__init__()
        self.use_cond = use_cond
        if use_cond:
            t = np.prod(upsample_rates)
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x, mel):
        if self.use_cond:
            x_mel = self.cond_net(mel)
            x = torch.cat([x_mel, x], 1)
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self, use_cond=False, c_in=1):
        super(MultiScaleDiscriminator, self).__init__()
        from utils.hparams import hparams
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 16],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 32],
                           c_in=c_in),
            DiscriminatorS(use_cond=use_cond,
                           upsample_rates=[4, 4, hparams['hop_size'] // 64],
                           c_in=c_in),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=1),
            AvgPool1d(4, 2, padding=1)
        ])

    def forward(self, y, y_hat, mel=None):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y, mel)
            y_d_g, fmap_g = d(y_hat, mel)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    r_losses = 0
    g_losses = 0
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        r_losses += r_loss
        g_losses += g_loss
    r_losses = r_losses / len(disc_real_outputs)
    g_losses = g_losses / len(disc_real_outputs)
    return r_losses, g_losses


def cond_discriminator_loss(outputs):
    loss = 0
    for dg in outputs:
        g_loss = torch.mean(dg ** 2)
        loss += g_loss
    loss = loss / len(outputs)
    return loss


def generator_loss(disc_outputs):
    loss = 0
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        loss += l
    loss = loss / len(disc_outputs)
    return loss
modules/hifigan/mel_utils.py
ADDED
@@ -0,0 +1,80 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, hparams, center=False, complex=False):
    # hop_size: 512    # For 22050 Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
    # win_size: 2048   # For 22050 Hz, 1100 ~= 50 ms (if None, win_size = fft_size) (0.05 * sample_rate)
    # fmin: 55         # Set this to 55 if your speaker is male; if female, 95 should help remove noise.
    #                  # (Tune per dataset. Pitch info: male ~[65, 260], female ~[100, 525])
    # fmax: 10000      # To be increased/reduced depending on data.
    # fft_size: 2048   # Extra window size is filled with zero padding to match this parameter
    n_fft = hparams['fft_size']
    num_mels = hparams['audio_num_mel_bins']
    sampling_rate = hparams['audio_sample_rate']
    hop_size = hparams['hop_size']
    win_size = hparams['win_size']
    fmin = hparams['fmin']
    fmax = hparams['fmax']
    y = y.clamp(min=-1., max=1.)
    global mel_basis, hann_window
    key = str(fmax) + '_' + str(y.device)
    # NOTE: the original tested `fmax not in mel_basis`, which never matched the
    # composite key used for storage, so the cache was rebuilt on every call.
    if key not in mel_basis:
        # keyword arguments required by librosa >= 0.10 (positional form is removed)
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
                                mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=False)  # explicit for newer PyTorch; output is real-valued [..., 2]

    if not complex:
        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
        spec = torch.matmul(mel_basis[key], spec)
        spec = spectral_normalize_torch(spec)
    else:
        B, C, T, _ = spec.shape
        spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
    return spec
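A usage sketch for mel_spectrogram; the hparams values below are illustrative for a 44.1 kHz setup, not this repo's actual config.

import torch
hp = {
    'fft_size': 2048, 'audio_num_mel_bins': 80, 'audio_sample_rate': 44100,
    'hop_size': 512, 'win_size': 2048, 'fmin': 40, 'fmax': 16000,
}
wav = torch.randn(1, 44100).clamp(-1, 1)   # [B, T_samples] in [-1, 1]
mel = mel_spectrogram(wav, hp)             # [B, 80, T_frames], log-compressed magnitudes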
modules/nsf_hifigan/env.py
ADDED
@@ -0,0 +1,15 @@
import os
import shutil


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    t_path = os.path.join(path, config_name)
    if config != t_path:
        os.makedirs(path, exist_ok=True)
        shutil.copyfile(config, os.path.join(path, config_name))
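AttrDict lets a loaded JSON config be read with attribute access as well as key access, which is how the NSF-HiFiGAN code below consumes `h`:

h = AttrDict({'num_mels': 128, 'upsample_rates': [8, 8, 2, 2]})
assert h.num_mels == h['num_mels'] == 128   # same underlying dict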
modules/nsf_hifigan/models.py
ADDED
@@ -0,0 +1,549 @@
import os
import json
from .env import AttrDict
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from .utils import init_weights, get_padding

LRELU_SLOPE = 0.1


def load_model(model_path, device='cuda'):
    config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
    with open(config_file) as f:
        data = f.read()

    global h
    json_config = json.loads(data)
    h = AttrDict(json_config)

    generator = Generator(h).to(device)

    cp_dict = torch.load(model_path)
    generator.load_state_dict(cp_dict['generator'])
    generator.eval()
    generator.remove_weight_norm()
    del cp_dict
    return generator, h


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    # Plain HiFi-GAN generator (no NSF source); superseded by the NSF variant below.
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse

    def _f02uv(self, f0):
        # generate uv signal
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
        where dim indicates fundamental tone and overtones
        """
        # convert F0 to rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase: sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
        if not self.flag_for_pulse:
            # normal case

            # To prevent torch.cumsum numerical overflow,
            # it is necessary to add -1 whenever \sum_{k=1}^{n} rad_value_k > 1.
            # Buffer tmp_over_one_idx indicates the time step to add -1.
            # This will not change F0 of the sine because (x-1) * 2*pi = x * 2*pi
            tmp_over_one = torch.cumsum(rad_values, 1) % 1
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
                                tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
                              * 2 * np.pi)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0).
            # This is used for pulse-train generation.

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch items need to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of instantaneous phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
            f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                                 device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            # generate sine waveforms
            sine_waves = self._f02sine(f0_buf) * self.sine_amp

            # generate uv signal
            uv = self._f02uv(f0)

            # noise: for unvoiced parts should be similar to sine_amp
            #        (std = self.sine_amp / 3 -> max value ~ self.sine_amp)
            #        for voiced regions it is self.noise_std
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)

            # first: set the unvoiced part to 0 by uv
            # then: add the noise
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshold=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced parts is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshold=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshold)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for the harmonic branch
        sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for the noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class Generator(torch.nn.Module):
    # NOTE: this NSF-style Generator intentionally redefines (and shadows) the plain
    # Generator above, so `load_model` builds the NSF variant.
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h.upsample_rates))
        self.m_source = SourceModuleHnNSF(
            sampling_rate=h.sampling_rate,
            harmonic_num=8)
        self.noise_convs = nn.ModuleList()
        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            c_cur = h.upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
                                k, u, padding=(k - u) // 2)))
            if i + 1 < len(h.upsample_rates):
                stride_f0 = np.prod(h.upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, f0):
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # [B, T_samples, 1]
        har_source, noi_source, uv = self.m_source(f0)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, periods=None):
        super(MultiPeriodDiscriminator, self).__init__()
        self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList()
        for period in self.periods:
            self.discriminators.append(DiscriminatorP(period))

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            AvgPool1d(4, 2, padding=2),
            AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg ** 2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
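A sketch of the harmonic-plus-noise excitation on a toy F0 contour (values assumed). With harmonic_num=8 the sine generator produces nine harmonics including the fundamental, which the linear-plus-tanh layer merges into a single excitation channel.

import torch
source = SourceModuleHnNSF(sampling_rate=44100, harmonic_num=8)
f0 = torch.full((1, 44100, 1), 220.0)   # one second of A3, already at sample rate
f0[:, :4410] = 0.0                      # first 0.1 s unvoiced
sine_merge, noise, uv = source(f0)      # each [1, 44100, 1]
# uv is 1 where f0 exceeds the voiced threshold and 0 elsewhere;
# unvoiced regions of sine_merge are replaced by scaled Gaussian noise inside SineGen.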
modules/nsf_hifigan/nvSTFT.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
import math
|
2 |
+
import os
|
3 |
+
os.environ["LRU_CACHE_CAPACITY"] = "3"
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torch.utils.data
|
7 |
+
import numpy as np
|
8 |
+
import librosa
|
9 |
+
from librosa.util import normalize
|
10 |
+
from librosa.filters import mel as librosa_mel_fn
|
11 |
+
from scipy.io.wavfile import read
|
12 |
+
import soundfile as sf
|
13 |
+
|
14 |
+
def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
|
15 |
+
sampling_rate = None
|
16 |
+
try:
|
17 |
+
data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile.
|
18 |
+
except Exception as ex:
|
19 |
+
print(f"'{full_path}' failed to load.\nException:")
|
20 |
+
print(ex)
|
21 |
+
if return_empty_on_exception:
|
22 |
+
return [], sampling_rate or target_sr or 48000
|
23 |
+
else:
|
24 |
+
raise Exception(ex)
|
25 |
+
|
26 |
+
if len(data.shape) > 1:
|
27 |
+
data = data[:, 0]
|
28 |
+
    assert len(data) > 2  # check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)

    if np.issubdtype(data.dtype, np.integer):  # if audio data is type int
        max_mag = -np.iinfo(data.dtype).min  # maximum magnitude = min possible value of intXX
    else:  # if audio data is type fp32
        max_mag = max(np.amax(data), -np.amin(data))
        max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0)  # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32

    data = torch.FloatTensor(data.astype(np.float32))/max_mag

    if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:  # resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
        return [], sampling_rate or target_sr or 48000
    if target_sr is not None and sampling_rate != target_sr:
        data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
        sampling_rate = target_sr

    return data, sampling_rate

def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)

def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C

def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)

def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C

class STFT():
    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr

        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, center=False):
        sampling_rate = self.target_sr
        n_mels = self.n_mels
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        fmin = self.fmin
        fmax = self.fmax
        clip_val = self.clip_val

        if torch.min(y) < -1.:
            print('min value is ', torch.min(y))
        if torch.max(y) > 1.:
            print('max value is ', torch.max(y))

        if fmax not in self.mel_basis:
            mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
            self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
            self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device)

        y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect')
        y = y.squeeze(1)

        spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)],
                          center=center, pad_mode='reflect', normalized=False, onesided=True)
        # print(111,spec)
        spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
        # print(222,spec)
        spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec)
        # print(333,spec)
        spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
        # print(444,spec)
        return spec

    def __call__(self, audiopath):
        audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
        spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
        return spect

stft = STFT()
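
Usage sketch for the STFT front end above (illustrative; the dummy waveform and its length are assumptions, and the real-valued torch.stft call above requires a torch version that still supports it). With the defaults (sr=22050, n_fft=1024, hop_length=256, n_mels=80), one second of audio yields 86 mel frames:

import torch

mel_extractor = STFT()                        # defaults: 22050 Hz, 80 mel bins
wav = torch.randn(1, 22050).clamp(-1.0, 1.0)  # (B, T) dummy mono audio in [-1, 1]
mel = mel_extractor.get_mel(wav)              # (B, 80, n_frames) log-compressed mel
print(mel.shape)                              # torch.Size([1, 80, 86])
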
modules/nsf_hifigan/utils.py
ADDED
@@ -0,0 +1,67 @@
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def save_checkpoint(filepath, obj):
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")


def del_old_checkpoints(cp_dir, prefix, n_models=2):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)  # get checkpoint paths
    cp_list = sorted(cp_list)  # sort by iter
    if len(cp_list) > n_models:  # if more than n_models models are found
        for cp in cp_list[:-n_models]:  # delete the oldest models other than the latest n_models
            open(cp, 'w').close()  # empty file contents
            os.unlink(cp)  # delete file (move to trash when using Colab)


def scan_checkpoint(cp_dir, prefix):
    pattern = os.path.join(cp_dir, prefix + '????????')
    cp_list = glob.glob(pattern)
    if len(cp_list) == 0:
        return None
    return sorted(cp_list)[-1]
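
Illustrative sketch of the intended checkpoint life cycle for the helpers above; the directory, prefix, and payload are example values, not fixed by this module:

import os
import torch

cp_dir, prefix = "checkpoints", "g_"             # hypothetical directory and prefix
os.makedirs(cp_dir, exist_ok=True)
save_checkpoint(os.path.join(cp_dir, prefix + "00001000"), {"step": 1000})
latest = scan_checkpoint(cp_dir, prefix)         # newest path matching 'g_????????'
state = load_checkpoint(latest, torch.device("cpu"))
del_old_checkpoints(cp_dir, prefix, n_models=2)  # keep only the two newest files
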
modules/parallel_wavegan/__init__.py
ADDED
File without changes
modules/parallel_wavegan/layers/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .causal_conv import *  # NOQA
from .pqmf import *  # NOQA
from .residual_block import *  # NOQA
from modules.parallel_wavegan.layers.residual_stack import *  # NOQA
from .upsample import *  # NOQA
modules/parallel_wavegan/layers/causal_conv.py
ADDED
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Causal convolution layer modules."""


import torch


class CausalConv1d(torch.nn.Module):
    """CausalConv1d module with customized initialization."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}):
        """Initialize CausalConv1d module."""
        super(CausalConv1d, self).__init__()
        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
                                    dilation=dilation, bias=bias)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            Tensor: Output tensor (B, out_channels, T).

        """
        return self.conv(self.pad(x))[:, :, :x.size(2)]


class CausalConvTranspose1d(torch.nn.Module):
    """CausalConvTranspose1d module with customized initialization."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
        """Initialize CausalConvTranspose1d module."""
        super(CausalConvTranspose1d, self).__init__()
        self.deconv = torch.nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride, bias=bias)
        self.stride = stride

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T_in).

        Returns:
            Tensor: Output tensor (B, out_channels, T_out).

        """
        return self.deconv(x)[:, :, :-self.stride]
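
Quick shape check (illustrative sizes): CausalConv1d pads by (kernel_size - 1) * dilation and trims back to the input length, so the time axis is preserved and output frame t depends only on inputs up to t:

import torch

conv = CausalConv1d(in_channels=1, out_channels=1, kernel_size=3, dilation=2)
x = torch.randn(2, 1, 100)  # (B, C, T)
y = conv(x)
assert y.shape == x.shape   # pad-then-trim keeps T; receptive field looks only backward
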
modules/parallel_wavegan/layers/pqmf.py
ADDED
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Pseudo QMF modules."""

import numpy as np
import torch
import torch.nn.functional as F

from scipy.signal import kaiser


def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps.
        cutoff_ratio (float): Cut-off frequency ratio.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be an even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # make initial filter
    omega_c = np.pi * cutoff_ratio
    with np.errstate(invalid='ignore'):
        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
            / (np.pi * (np.arange(taps + 1) - 0.5 * taps))
    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    h = h_i * w

    return h


class PQMF(torch.nn.Module):
    """PQMF module.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Initialize PQMF module.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super(PQMF, self).__init__()

        # define filter coefficient
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        h_analysis = np.zeros((subbands, len(h_proto)))
        h_synthesis = np.zeros((subbands, len(h_proto)))
        for k in range(subbands):
            h_analysis[k] = 2 * h_proto * np.cos(
                (2 * k + 1) * (np.pi / (2 * subbands)) *
                (np.arange(taps + 1) - ((taps - 1) / 2)) +
                (-1) ** k * np.pi / 4)
            h_synthesis[k] = 2 * h_proto * np.cos(
                (2 * k + 1) * (np.pi / (2 * subbands)) *
                (np.arange(taps + 1) - ((taps - 1) / 2)) -
                (-1) ** k * np.pi / 4)

        # convert to tensor
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffer
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # filter for downsampling & upsampling
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        for k in range(subbands):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
        return F.conv1d(x, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
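
Illustrative analysis/synthesis round trip (batch and length are example values, with T divisible by the number of subbands; assumes a SciPy version that still exposes scipy.signal.kaiser). Since the bank is near-perfect-reconstruction, x_hat closely matches x:

import torch

pqmf = PQMF(subbands=4)
x = torch.randn(1, 1, 8192)    # (B, 1, T)
bands = pqmf.analysis(x)       # (B, 4, T // 4) subband signals
x_hat = pqmf.synthesis(bands)  # (B, 1, T), approximately equal to x
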
modules/parallel_wavegan/layers/residual_block.py
ADDED
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-

"""Residual block module in WaveNet.

This code is modified from https://github.com/r9y9/wavenet_vocoder.

"""

import math

import torch
import torch.nn.functional as F


class Conv1d(torch.nn.Conv1d):
    """Conv1d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv1d module."""
        super(Conv1d, self).__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class Conv1d1x1(Conv1d):
    """1x1 Conv1d with customized initialization."""

    def __init__(self, in_channels, out_channels, bias):
        """Initialize 1x1 Conv1d module."""
        super(Conv1d1x1, self).__init__(in_channels, out_channels,
                                        kernel_size=1, padding=0,
                                        dilation=1, bias=bias)


class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(self,
                 kernel_size=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 dropout=0.0,
                 dilation=1,
                 bias=True,
                 use_causal_conv=False
                 ):
        """Initialize ResidualBlock module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            residual_channels (int): Number of channels for residual connection.
            skip_channels (int): Number of channels for skip connection.
            aux_channels (int): Local conditioning channels i.e. auxiliary input dimension.
            dropout (float): Dropout probability.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            use_causal_conv (bool): Whether to use causal or non-causal convolution.

        """
        super(ResidualBlock, self).__init__()
        self.dropout = dropout
        # no future time stamps available
        if use_causal_conv:
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
            padding = (kernel_size - 1) // 2 * dilation
        self.use_causal_conv = use_causal_conv

        # dilation conv
        self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
                           padding=padding, dilation=dilation, bias=bias)

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
        else:
            self.conv1x1_aux = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)

    def forward(self, x, c):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).

        Returns:
            Tensor: Output tensor for residual connection (B, residual_channels, T).
            Tensor: Output tensor for skip connection (B, skip_channels, T).

        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv(x)

        # remove future time steps if using causal conv
        x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x

        # split into two parts for gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # for skip connection
        s = self.conv1x1_skip(x)

        # for residual connection
        x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)

        return x, s
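
Illustrative forward pass with the default sizes (input shapes are example values): the first output feeds the next block's residual path, the second is accumulated over blocks as the skip connection:

import torch

block = ResidualBlock()      # defaults: 64 residual / 128 gate / 80 aux channels
x = torch.randn(1, 64, 200)  # (B, residual_channels, T)
c = torch.randn(1, 80, 200)  # (B, aux_channels, T) local conditioning
res, skip = block(x, c)      # both (B, 64, T)
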
modules/parallel_wavegan/layers/residual_stack.py
ADDED
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Residual stack module in MelGAN."""

import torch

from . import CausalConv1d


class ResidualStack(torch.nn.Module):
    """Residual stack module introduced in MelGAN."""

    def __init__(self,
                 kernel_size=3,
                 channels=32,
                 dilation=1,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_causal_conv=False,
                 ):
        """Initialize ResidualStack module.

        Args:
            kernel_size (int): Kernel size of dilation convolution layer.
            channels (int): Number of channels of convolution layers.
            dilation (int): Dilation factor.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super(ResidualStack, self).__init__()

        # define residual stack part
        if not use_causal_conv:
            assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
            self.stack = torch.nn.Sequential(
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
                torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                torch.nn.Conv1d(channels, channels, 1, bias=bias),
            )
        else:
            self.stack = torch.nn.Sequential(
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                CausalConv1d(channels, channels, kernel_size, dilation=dilation,
                             bias=bias, pad=pad, pad_params=pad_params),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                torch.nn.Conv1d(channels, channels, 1, bias=bias),
            )

        # define extra layer for skip connection
        self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        return self.stack(c) + self.skip_layer(c)
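
Minimal illustrative call (sizes are example values): the stack adds its dilated-convolution branch to a 1x1 shortcut of the input, preserving the (B, channels, T) shape:

import torch

stack = ResidualStack(kernel_size=3, channels=32, dilation=3)
c = torch.randn(1, 32, 100)  # (B, channels, T)
out = stack(c)               # (B, 32, 100): dilated branch + 1x1 skip path
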
modules/parallel_wavegan/layers/tf_layers.py
ADDED
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-

# Copyright 2020 MINH ANH (@dathudeptrai)
# MIT License (https://opensource.org/licenses/MIT)

"""Tensorflow layer modules compatible with pytorch."""

import tensorflow as tf


class TFReflectionPad1d(tf.keras.layers.Layer):
    """Tensorflow ReflectionPad1d module."""

    def __init__(self, padding_size):
        """Initialize TFReflectionPad1d module.

        Args:
            padding_size (int): Padding size.

        """
        super(TFReflectionPad1d, self).__init__()
        self.padding_size = padding_size

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        return tf.pad(x, [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]], "REFLECT")


class TFConvTranspose1d(tf.keras.layers.Layer):
    """Tensorflow ConvTranspose1d module."""

    def __init__(self, channels, kernel_size, stride, padding):
        """Initialize TFConvTranspose1d module.

        Args:
            channels (int): Number of channels.
            kernel_size (int): Kernel size.
            stride (int): Stride width.
            padding (str): Padding type ("same" or "valid").

        """
        super(TFConvTranspose1d, self).__init__()
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
            filters=channels,
            kernel_size=(kernel_size, 1),
            strides=(stride, 1),
            padding=padding,
        )

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        x = self.conv1d_transpose(x)
        return x


class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module."""

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size.
            channels (int): Number of channels.
            dilation (int): Dilation size.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            padding (str): Padding type ("same" or "valid").

        """
        super(TFResidualStack, self).__init__()
        self.block = [
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            TFReflectionPad1d(dilation),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
        ]
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        _x = tf.identity(x)
        for i, layer in enumerate(self.block):
            _x = layer(_x)
        shortcut = self.shortcut(x)
        return shortcut + _x
modules/parallel_wavegan/layers/upsample.py
ADDED
@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-

"""Upsampling module.

This code is modified from https://github.com/r9y9/wavenet_vocoder.

"""

import numpy as np
import torch
import torch.nn.functional as F

from . import Conv1d


class Stretch2d(torch.nn.Module):
    """Stretch2d module."""

    def __init__(self, x_scale, y_scale, mode="nearest"):
        """Initialize Stretch2d module.

        Args:
            x_scale (int): X scaling factor (Time axis in spectrogram).
            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
            mode (str): Interpolation mode.

        """
        super(Stretch2d, self).__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).

        """
        return F.interpolate(
            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)


class Conv2d(torch.nn.Conv2d):
    """Conv2d module with customized initialization."""

    def __init__(self, *args, **kwargs):
        """Initialize Conv2d module."""
        super(Conv2d, self).__init__(*args, **kwargs)

    def reset_parameters(self):
        """Reset parameters."""
        self.weight.data.fill_(1. / np.prod(self.kernel_size))
        if self.bias is not None:
            torch.nn.init.constant_(self.bias, 0.0)


class UpsampleNetwork(torch.nn.Module):
    """Upsampling network module."""

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 use_causal_conv=False,
                 ):
        """Initialize upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.

        """
        super(UpsampleNetwork, self).__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_scales:
            # interpolation layer
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # conv layer
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Even freq axis kernel sizes are not supported."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers += [nonlinear]

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, C, T).

        Returns:
            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).

        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for f in self.up_layers:
            if self.use_causal_conv and isinstance(f, Conv2d):
                c = f(c)[..., :c.size(-1)]
            else:
                c = f(c)
        return c.squeeze(1)  # (B, C, T')


class ConvInUpsampleNetwork(torch.nn.Module):
    """Convolution + upsampling network module."""

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 aux_channels=80,
                 aux_context_window=0,
                 use_causal_conv=False
                 ):
        """Initialize convolution + upsampling network module.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for specified activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
            aux_channels (int): Number of channels of pre-convolutional layer.
            aux_context_window (int): Context window size of the pre-convolutional layer.
            use_causal_conv (bool): Whether to use causal structure.

        """
        super(ConvInUpsampleNetwork, self).__init__()
        self.aux_context_window = aux_context_window
        self.use_causal_conv = use_causal_conv and aux_context_window > 0
        # To capture wide-context information in conditional features
        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
        self.upsample = UpsampleNetwork(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c : Input tensor (B, C, T').

        Returns:
            Tensor: Upsampled tensor (B, C, T),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.

        """
        c_ = self.conv_in(c)
        c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
        return self.upsample(c)
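
Illustrative sketch of upsampling mel frames to the waveform rate (the frame count is an example value). With upsample_scales=[4, 4, 4, 4] the hop size is 256, and aux_context_window=2 makes conv_in trim two frames from each side before upsampling:

import torch

net = ConvInUpsampleNetwork(upsample_scales=[4, 4, 4, 4], aux_context_window=2)
c = torch.randn(1, 80, 100)  # (B, aux_channels, T') mel frames
out = net(c)                 # (B, 80, (100 - 2 * 2) * 256)
assert out.shape[-1] == (100 - 4) * 256
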
modules/parallel_wavegan/losses/__init__.py
ADDED
@@ -0,0 +1 @@
from .stft_loss import *  # NOQA
modules/parallel_wavegan/losses/stft_loss.py
ADDED
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""STFT-based Loss modules."""

import torch
import torch.nn.functional as F


def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).

    """
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)


class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super(SpectralConvergengeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value.

        """
        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")


class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize log STFT magnitude loss module."""
        super(LogSTFTMagnitudeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value.

        """
        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))


class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module."""
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.

        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.

        """
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLoss(fs, ss, wl, window)]

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.

        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss
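
Illustrative usage sketch (tensor sizes are example values; like the stft helper above, it assumes a torch version that still returns real-valued torch.stft output). Both returned losses are scalars and are typically summed, possibly weighted, into the generator objective:

import torch

criterion = MultiResolutionSTFTLoss()  # default FFT sizes [1024, 2048, 512]
y_hat = torch.randn(2, 16000)          # predicted signal (B, T)
y = torch.randn(2, 16000)              # ground-truth signal (B, T)
sc_loss, mag_loss = criterion(y_hat, y)
loss = sc_loss + mag_loss
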
modules/parallel_wavegan/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .melgan import *  # NOQA
from .parallel_wavegan import *  # NOQA
modules/parallel_wavegan/models/melgan.py
ADDED
@@ -0,0 +1,427 @@
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""MelGAN Modules."""

import logging

import numpy as np
import torch

from modules.parallel_wavegan.layers import CausalConv1d
from modules.parallel_wavegan.layers import CausalConvTranspose1d
from modules.parallel_wavegan.layers import ResidualStack


class MelGANGenerator(torch.nn.Module):
    """MelGAN generator module."""

    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 kernel_size=7,
                 channels=512,
                 bias=True,
                 upsample_scales=[8, 8, 2, 2],
                 stack_kernel_size=3,
                 stacks=3,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_final_nonlinear_activation=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 ):
        """Initialize MelGANGenerator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of initial and final conv layer.
            channels (int): Initial number of channels for conv layer.
            bias (bool): Whether to add bias parameter in convolution layers.
            upsample_scales (list): List of upsampling scales.
            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
            stacks (int): Number of stacks in a single residual stack.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_final_nonlinear_activation (torch.nn.Module): Activation function for the final layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super(MelGANGenerator, self).__init__()

        # check hyper parameters are valid
        assert channels >= np.prod(upsample_scales)
        assert channels % (2 ** len(upsample_scales)) == 0
        if not use_causal_conv:
            assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."

        # add initial layer
        layers = []
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(in_channels, channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]

        for i, upsample_scale in enumerate(upsample_scales):
            # add upsampling layer
            layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
            if not use_causal_conv:
                layers += [
                    torch.nn.ConvTranspose1d(
                        channels // (2 ** i),
                        channels // (2 ** (i + 1)),
                        upsample_scale * 2,
                        stride=upsample_scale,
                        padding=upsample_scale // 2 + upsample_scale % 2,
                        output_padding=upsample_scale % 2,
                        bias=bias,
                    )
                ]
            else:
                layers += [
                    CausalConvTranspose1d(
                        channels // (2 ** i),
                        channels // (2 ** (i + 1)),
                        upsample_scale * 2,
                        stride=upsample_scale,
                        bias=bias,
                    )
                ]

            # add residual stack
            for j in range(stacks):
                layers += [
                    ResidualStack(
                        kernel_size=stack_kernel_size,
                        channels=channels // (2 ** (i + 1)),
                        dilation=stack_kernel_size ** j,
                        bias=bias,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        pad=pad,
                        pad_params=pad_params,
                        use_causal_conv=use_causal_conv,
                    )
                ]

        # add final layer
        layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]
        if use_final_nonlinear_activation:
            layers += [torch.nn.Tanh()]

        # define the model as a single function
        self.melgan = torch.nn.Sequential(*layers)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, 1, T * prod(upsample_scales)).

        """
        return self.melgan(c)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows the official implementation manner.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)


class MelGANDiscriminator(torch.nn.Module):
    """MelGAN discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 ):
        """Initialize MelGAN discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_sizes (list): List of two kernel sizes. The product will be used for the first conv layer,
                and the first and the second kernel sizes will be used for the last two layers.
                For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15,
                and the last two layers' kernel sizes will be 5 and 3, respectively.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.

        """
        super(MelGANDiscriminator, self).__init__()
        self.layers = torch.nn.ModuleList()

        # check kernel size is valid
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1

        # add first layer
        self.layers += [
            torch.nn.Sequential(
                getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]

        # add downsample layers
        in_chs = channels
        for downsample_scale in downsample_scales:
            out_chs = min(in_chs * downsample_scale, max_downsample_channels)
            self.layers += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        in_chs, out_chs,
                        kernel_size=downsample_scale * 10 + 1,
                        stride=downsample_scale,
                        padding=downsample_scale * 5,
                        groups=in_chs // 4,
                        bias=bias,
                    ),
                    getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
                )
            ]
            in_chs = out_chs

        # add final layers
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_chs, out_chs, kernel_sizes[0],
                    padding=(kernel_sizes[0] - 1) // 2,
                    bias=bias,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]
        self.layers += [
            torch.nn.Conv1d(
                out_chs, out_channels, kernel_sizes[1],
                padding=(kernel_sizes[1] - 1) // 2,
                bias=bias,
            ),
        ]

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: List of output tensors of each layer.

        """
        outs = []
        for f in self.layers:
            x = f(x)
            outs += [x]

        return outs


class MelGANMultiScaleDiscriminator(torch.nn.Module):
    """MelGAN multi-scale discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 scales=3,
                 downsample_pooling="AvgPool1d",
                 # follow the official implementation setting
                 downsample_pooling_params={
                     "kernel_size": 4,
                     "stride": 2,
                     "padding": 1,
                     "count_include_pad": False,
                 },
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_weight_norm=True,
                 ):
        """Initialize MelGAN multi-scale discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            downsample_pooling (str): Pooling module name for downsampling of the inputs.
            downsample_pooling_params (dict): Parameters for the above pooling module.
            kernel_sizes (list): List of two kernel sizes. The product will be used for the first conv layer,
                and the first and the second kernel sizes will be used for the last two layers.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_weight_norm (bool): Whether to use weight norm.

        """
        super(MelGANMultiScaleDiscriminator, self).__init__()
        self.discriminators = torch.nn.ModuleList()

        # add discriminators
        for _ in range(scales):
            self.discriminators += [
                MelGANDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    channels=channels,
                    max_downsample_channels=max_downsample_channels,
                    bias=bias,
                    downsample_scales=downsample_scales,
                    nonlinear_activation=nonlinear_activation,
                    nonlinear_activation_params=nonlinear_activation_params,
                    pad=pad,
                    pad_params=pad_params,
                )
            ]
        self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: List of list of each discriminator outputs, which consists of each layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
            x = self.pooling(x)

        return outs

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows the official implementation manner.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)
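
Illustrative forward pass (the frame count is an example value): with the default upsample_scales=[8, 8, 2, 2], the generator expands each input frame into 8 * 8 * 2 * 2 = 256 waveform samples:

import torch

gen = MelGANGenerator()     # in_channels=80, 256x total upsampling
c = torch.randn(1, 80, 50)  # (B, in_channels, T) mel frames
wav = gen(c)                # (B, 1, T * 256)
assert wav.shape == (1, 1, 50 * 256)
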
modules/parallel_wavegan/models/parallel_wavegan.py
ADDED
@@ -0,0 +1,434 @@
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Parallel WaveGAN Modules."""

import logging
import math

import torch
from torch import nn

from modules.parallel_wavegan.layers import Conv1d
from modules.parallel_wavegan.layers import Conv1d1x1
from modules.parallel_wavegan.layers import ResidualBlock
from modules.parallel_wavegan.layers import upsample
from modules.parallel_wavegan import models


class ParallelWaveGANGenerator(torch.nn.Module):
    """Parallel WaveGAN Generator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 aux_context_window=2,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 upsample_conditional_features=True,
                 upsample_net="ConvInUpsampleNetwork",
                 upsample_params={"upsample_scales": [4, 4, 4, 4]},
                 use_pitch_embed=False,
                 ):
        """Initialize Parallel WaveGAN Generator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            aux_channels (int): Number of channels for auxiliary feature conv.
            aux_context_window (int): Context window size for auxiliary feature.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv layer.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            upsample_conditional_features (bool): Whether to use upsampling network.
            upsample_net (str): Upsampling network architecture.
            upsample_params (dict): Upsampling network parameters.
            use_pitch_embed (bool): Whether to use pitch embedding.

        """
        super(ParallelWaveGANGenerator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)

        # define conv + upsampling network
        if upsample_conditional_features:
            upsample_params.update({
                "use_causal_conv": use_causal_conv,
            })
            if upsample_net == "MelGANGenerator":
                assert aux_context_window == 0
                upsample_params.update({
                    "use_weight_norm": False,  # not to apply twice
                    "use_final_nonlinear_activation": False,
                })
                self.upsample_net = getattr(models, upsample_net)(**upsample_params)
            else:
                if upsample_net == "ConvInUpsampleNetwork":
                    upsample_params.update({
                        "aux_channels": aux_channels,
                        "aux_context_window": aux_context_window,
                    })
                self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
        else:
            self.upsample_net = None

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            torch.nn.ReLU(inplace=True),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        self.use_pitch_embed = use_pitch_embed
        if use_pitch_embed:
            self.pitch_embed = nn.Embedding(300, aux_channels, 0)
            self.c_proj = nn.Linear(2 * aux_channels, aux_channels)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x, c=None, pitch=None, **kwargs):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, C_in, T).
            c (Tensor): Local conditioning auxiliary features (B, C, T').
            pitch (Tensor): Local conditioning pitch (B, T').

        Returns:
            Tensor: Output tensor (B, C_out, T)

        """
        # perform upsampling
        if c is not None and self.upsample_net is not None:
            if self.use_pitch_embed:
                p = self.pitch_embed(pitch)
                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
            c = self.upsample_net(c)
            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    def remove_weight_norm(self):
        """Remove weight normalization from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size,
                                  dilation=lambda x: 2 ** x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)


class ParallelWaveGANDiscriminator(torch.nn.Module):
    """Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=10,
                 conv_channels=64,
                 dilation_factor=1,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 bias=True,
                 use_weight_norm=True,
                 ):
        """Initialize Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of conv layers.
            layers (int): Number of conv layers.
            conv_channels (int): Number of channels in conv layers.
            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
                the dilation will be 2, 4, 8, ..., and so on.
            nonlinear_activation (str): Nonlinear function after each conv.
            nonlinear_activation_params (dict): Nonlinear function parameters.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(ParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Even-number kernel size is not supported."
        assert dilation_factor > 0, "Dilation factor must be > 0."
        self.conv_layers = torch.nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor ** i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                Conv1d(conv_in_channels, conv_channels,
                       kernel_size=kernel_size, padding=padding,
                       dilation=dilation, bias=bias),
                getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = Conv1d(
            conv_in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T)

        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
    """Residual Parallel WaveGAN Discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 layers=30,
                 stacks=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 ):
        """Initialize Residual Parallel WaveGAN Discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of dilated convolution.
            layers (int): Number of residual block layers.
            stacks (int): Number of stacks i.e., dilation cycles.
            residual_channels (int): Number of channels in residual conv.
            gate_channels (int): Number of channels in gated conv.
            skip_channels (int): Number of channels in skip conv.
            dropout (float): Dropout rate. 0.0 means no dropout applied.
            bias (bool): Whether to use bias parameter in conv.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal structure.
            nonlinear_activation_params (dict): Nonlinear function parameters.

        """
        super(ResidualParallelWaveGANDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "Even-number kernel size is not supported."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        # check the number of layers and stacks
        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        # define first convolution
        self.first_conv = torch.nn.Sequential(
            Conv1d1x1(in_channels, residual_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, skip_channels, bias=True),
            getattr(torch.nn, nonlinear_activation)(
                inplace=True, **nonlinear_activation_params),
            Conv1d1x1(skip_channels, out_channels, bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input signal (B, 1, T).

        Returns:
            Tensor: Output tensor (B, 1, T)

        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Remove weight normalization from all of the layers."""
        def _remove_weight_norm(m):
            try:
                logging.debug(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
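
A minimal forward-pass sketch for the generator above (hypothetical, not part of the commit). It assumes the default `ConvInUpsampleNetwork` trims `aux_context_window` frames on each side before upsampling, as in the upstream parallel_wavegan implementation, so the conditioning input carries two extra frames per side:

import torch

B, T_mel, hop_size = 2, 50, 256          # hop 256 = product of upsample_scales [4, 4, 4, 4]
model = ParallelWaveGANGenerator()       # defaults: aux_channels=80, aux_context_window=2
c = torch.randn(B, 80, T_mel + 2 * 2)    # mel frames plus a 2-frame context on each side
x = torch.randn(B, 1, T_mel * hop_size)  # input noise, one sample per output sample
wav = model(x, c)                        # (B, 1, T_mel * hop_size)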
modules/parallel_wavegan/models/source.py
ADDED
@@ -0,0 +1,538 @@
import torch
import numpy as np
import sys
import torch.nn.functional as torch_nn_func


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)

    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse

    def _f02uv(self, f0):
        # generate uv signal
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # for normal case

            # To prevent torch.cumsum numerical overflow,
            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
            # Buffer tmp_over_one_idx indicates the time step to add -1.
            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
            tmp_over_one = torch.cumsum(rad_values, 1) % 1
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
                                tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
                              * 2 * np.pi)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batches need to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                                 device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            # generate sine waveforms
            sine_waves = self._f02sine(f0_buf) * self.sine_amp

            # generate uv signal
            # uv = torch.ones(f0.shape)
            # uv = uv * (f0 > self.voiced_threshold)
            uv = self._f02uv(f0)

            # noise: for unvoiced should be similar to sine_amp
            #        std = self.sine_amp/3 -> max value ~ self.sine_amp
            #        for voiced regions is self.noise_std
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)

            # first: set the unvoiced part to 0 by uv
            # then: additive noise
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class PulseGen(torch.nn.Module):
    """ Definition of Pulse train generator

    There are many ways to implement pulse generator.
    Here, PulseGen is based on SineGen. For a perfect
    """
    def __init__(self, samp_rate, pulse_amp=0.1,
                 noise_std=0.003, voiced_threshold=0):
        super(PulseGen, self).__init__()
        self.pulse_amp = pulse_amp
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.noise_std = noise_std
        self.l_sinegen = SineGen(self.sampling_rate, harmonic_num=0,
                                 sine_amp=self.pulse_amp, noise_std=0,
                                 voiced_threshold=self.voiced_threshold,
                                 flag_for_pulse=True)

    def forward(self, f0):
        """ Pulse train generator
        pulse_train, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output pulse_train: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)

        Note: self.l_sine doesn't make sure that the initial phase of
        a voiced segment is np.pi, so the first pulse in a voiced segment
        may not be at the first time step within that voiced segment
        """
        with torch.no_grad():
            sine_wav, uv, noise = self.l_sinegen(f0)

            # sine without additive noise
            pure_sine = sine_wav - noise

            # step t corresponds to a pulse if
            # sine[t] > sine[t+1] & sine[t] > sine[t-1]
            # & sine[t-1], sine[t+1], and sine[t] are voiced
            # or
            # sine[t] is voiced, sine[t-1] is unvoiced
            # we use torch.roll to simulate sine[t+1] and sine[t-1]
            sine_1 = torch.roll(pure_sine, shifts=1, dims=1)
            uv_1 = torch.roll(uv, shifts=1, dims=1)
            uv_1[:, 0, :] = 0
            sine_2 = torch.roll(pure_sine, shifts=-1, dims=1)
            uv_2 = torch.roll(uv, shifts=-1, dims=1)
            uv_2[:, -1, :] = 0

            loc = (pure_sine > sine_1) * (pure_sine > sine_2) \
                  * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \
                  + (uv_1 < 1) * (uv > 0)

            # pulse train without noise
            pulse_train = pure_sine * loc

            # additive noise to pulse train
            # note that noise from sinegen is zero in voiced regions
            pulse_noise = torch.randn_like(pure_sine) * self.noise_std

            # with additive noise on pulse, and unvoiced regions
            pulse_train += pulse_noise * loc + pulse_noise * (1 - uv)
        return pulse_train, sine_wav, uv, pulse_noise


class SignalsConv1d(torch.nn.Module):
    """ Filtering input signal with time invariant filter
    Note: FIRFilter conducted filtering given fixed FIR weight
          SignalsConv1d convolves two signals
    Note: this is based on torch.nn.functional.conv1d

    """

    def __init__(self):
        super(SignalsConv1d, self).__init__()

    def forward(self, signal, system_ir):
        """ output = forward(signal, system_ir)

        signal:    (batchsize, length1, dim)
        system_ir: (length2, dim)

        output:    (batchsize, length1, dim)
        """
        if signal.shape[-1] != system_ir.shape[-1]:
            print("Error: SignalsConv1d expects shape:")
            print("signal    (batchsize, length1, dim)")
            print("system_id (batchsize, length2, dim)")
            print("But received signal: {:s}".format(str(signal.shape)))
            print(" system_ir: {:s}".format(str(system_ir.shape)))
            sys.exit(1)
        padding_length = system_ir.shape[0] - 1
        groups = signal.shape[-1]

        # pad signal on the left
        signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1),
                                       (padding_length, 0))
        # prepare system impulse response as (dim, 1, length2)
        # also flip the impulse response
        ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0),
                        dims=[2])
        # convolve
        output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)
        return output.permute(0, 2, 1)


class CyclicNoiseGen_v1(torch.nn.Module):
    """ CyclicNoiseGen_v1
    Cyclic noise with a single parameter of beta.
    Pytorch v1 implementation assumes f_t is also fixed
    """

    def __init__(self, samp_rate,
                 noise_std=0.003, voiced_threshold=0):
        super(CyclicNoiseGen_v1, self).__init__()
        self.samp_rate = samp_rate
        self.noise_std = noise_std
        self.voiced_threshold = voiced_threshold

        self.l_pulse = PulseGen(samp_rate, pulse_amp=1.0,
                                noise_std=noise_std,
                                voiced_threshold=voiced_threshold)
        self.l_conv = SignalsConv1d()

    def noise_decay(self, beta, f0mean):
        """ decayed_noise = noise_decay(beta, f0mean)
        decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate)

        beta: (dim=1) or (batchsize=1, 1, dim=1)
        f0mean (batchsize=1, 1, dim=1)

        decayed_noise (batchsize=1, length, dim=1)
        """
        with torch.no_grad():
            # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T
            # truncate the noise when decayed by -40 dB
            length = 4.6 * self.samp_rate / f0mean
            length = length.int()
            time_idx = torch.arange(0, length, device=beta.device)
            time_idx = time_idx.unsqueeze(0).unsqueeze(2)
            time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])

        noise = torch.randn(time_idx.shape, device=beta.device)

        # due to Pytorch implementation, use f0_mean as the f0 factor
        decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate)
        return noise * self.noise_std * decay

    def forward(self, f0s, beta):
        """ Produce cyclic noise
        """
        # pulse train
        pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)
        pure_pulse = pulse_train - noise

        # decayed_noise (length, dim=1)
        if (uv < 1).all():
            # all unvoiced
            cyc_noise = torch.zeros_like(sine_wav)
        else:
            f0mean = f0s[uv > 0].mean()

            decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]
            # convolve
            cyc_noise = self.l_conv(pure_pulse, decayed_noise)

        # add noise in unvoiced segments
        cyc_noise = cyc_noise + noise * (1.0 - uv)
        return cyc_noise, pulse_train, sine_wav, uv, noise


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)

    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)

    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse

    def _f02uv(self, f0):
        # generate uv signal
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # for normal case

            # To prevent torch.cumsum numerical overflow,
            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
            # Buffer tmp_over_one_idx indicates the time step to add -1.
            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
            tmp_over_one = torch.cumsum(rad_values, 1) % 1
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
                                tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
                              * 2 * np.pi)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batches need to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                                 device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            # generate sine waveforms
            sine_waves = self._f02sine(f0_buf) * self.sine_amp

            # generate uv signal
            # uv = torch.ones(f0.shape)
            # uv = uv * (f0 > self.voiced_threshold)
            uv = self._f02uv(f0)

            # noise: for unvoiced should be similar to sine_amp
            #        std = self.sine_amp/3 -> max value ~ self.sine_amp
            #        for voiced regions is self.noise_std
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)

            # first: set the unvoiced part to 0 by uv
            # then: additive noise
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleCycNoise_v1(torch.nn.Module):
    """ SourceModuleCycNoise_v1
    SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz

    noise_std: std of Gaussian noise (default: 0.003)
    voiced_threshold: threshold to set U/V given F0 (default: 0)

    cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)
    F0_upsampled (batchsize, length, 1)
    beta (1)
    cyc (batchsize, length, 1)
    noise (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):
        super(SourceModuleCycNoise_v1, self).__init__()
        self.sampling_rate = sampling_rate
        self.noise_std = noise_std
        self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std,
                                           voiced_threshod)

    def forward(self, f0_upsamped, beta):
        """
        cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)
        F0_upsampled (batchsize, length, 1)
        beta (1)
        cyc (batchsize, length, 1)
        noise (batchsize, length, 1)
        uv (batchsize, length, 1)
        """
        # source for harmonic branch
        cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.noise_std / 3
        return cyc, noise, uv


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)

    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        # source for harmonic branch
        sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


if __name__ == '__main__':
    source = SourceModuleCycNoise_v1(24000)
    x = torch.randn(16, 25600, 1)
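
A small hypothetical sketch (not part of the commit) of how the hn-NSF source module above can be driven by an upsampled F0 contour; the sampling rate, harmonic count, and F0 values are illustrative placeholders:

import torch

source = SourceModuleHnNSF(sampling_rate=44100, harmonic_num=8)
f0 = torch.zeros(1, 44100, 1)           # (B, T, 1); zeros mark unvoiced steps
f0[:, 10000:30000, 0] = 220.0           # a voiced segment at 220 Hz
sine_merge, noise, uv = source(f0)      # each output has shape (1, 44100, 1)
# sine_merge mixes the fundamental and 8 overtones through a learned linear + tanh;
# uv is the voiced/unvoiced mask and noise feeds the noise branch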
modules/parallel_wavegan/optimizers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from torch.optim import *  # NOQA
from .radam import *  # NOQA
modules/parallel_wavegan/optimizers/radam.py
ADDED
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-

"""RAdam optimizer.

This code is derived from https://github.com/LiyuanLucasLiu/RAdam.
"""

import math
import torch

from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
    """Rectified Adam optimizer."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        """Initialize RAdam optimizer."""
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Set state."""
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Run one step."""
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])  # NOQA
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
                else:
                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)

                p.data.copy_(p_data_fp32)

        return loss
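
Usage is that of any torch optimizer; a hypothetical sketch (not part of the commit) follows. Note that the `Tensor.add_(scalar, tensor)` / `addcmul_(scalar, t1, t2)` call signatures used above were deprecated and later removed from PyTorch, so this file targets older releases:

import torch

net = torch.nn.Linear(10, 1)
opt = RAdam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

loss = net(torch.randn(4, 10)).pow(2).mean()
loss.backward()
opt.step()        # behaves like momentum SGD until the variance term N_sma reaches 5
opt.zero_grad()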
modules/parallel_wavegan/stft_loss.py
ADDED
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""STFT-based Loss modules."""
import librosa
import torch

from modules.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft


class STFTLoss(torch.nn.Module):
    """STFT loss module."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
                 use_mel_loss=False):
        """Initialize STFT loss module."""
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        self.use_mel_loss = use_mel_loss
        self.mel_basis = None

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.

        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        if self.use_mel_loss:
            if self.mel_basis is None:
                self.mel_basis = torch.from_numpy(librosa.filters.mel(22050, self.fft_size, 80)).cuda().T
            x_mag = x_mag @ self.mel_basis
            y_mag = y_mag @ self.mel_basis

        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi resolution STFT loss module."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window",
                 use_mel_loss=False):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.

        """
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLoss(fs, ss, wl, window, use_mel_loss)]

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.

        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(x, y)
            sc_loss += sc_l
            mag_loss += mag_l
        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss
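
A hypothetical usage sketch (not part of the commit). One caveat visible above: `use_mel_loss=True` hard-codes a 22.05 kHz, 80-bin librosa mel basis and a CUDA device in `STFTLoss.forward`:

import torch

criterion = MultiResolutionSTFTLoss()   # defaults: fft 1024/2048/512, hop 120/240/50
y_hat = torch.randn(2, 16000)           # predicted waveform (B, T)
y = torch.randn(2, 16000)               # reference waveform (B, T)
sc_loss, mag_loss = criterion(y_hat, y)
total_loss = sc_loss + mag_loss         # both terms are averaged over the three resolutions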
modules/parallel_wavegan/utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .utils import *  # NOQA
modules/parallel_wavegan/utils/utils.py
ADDED
@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Utility functions."""

import fnmatch
import logging
import os
import sys

import h5py
import numpy as np


def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root directory to search.
        query (str): Query to find.
        include_root_dir (bool): If False, root_dir name is not included.

    Returns:
        list: List of found filenames.

    """
    files = []
    for root, dirnames, filenames in os.walk(root_dir, followlinks=True):
        for filename in fnmatch.filter(filenames, query):
            files.append(os.path.join(root, filename))
    if not include_root_dir:
        files = [file_.replace(root_dir + "/", "") for file_ in files]

    return files


def read_hdf5(hdf5_name, hdf5_path):
    """Read hdf5 dataset.

    Args:
        hdf5_name (str): Filename of hdf5 file.
        hdf5_path (str): Dataset name in hdf5 file.

    Return:
        any: Dataset values.

    """
    if not os.path.exists(hdf5_name):
        logging.error(f"There is no such hdf5 file ({hdf5_name}).")
        sys.exit(1)

    hdf5_file = h5py.File(hdf5_name, "r")

    if hdf5_path not in hdf5_file:
        logging.error(f"There is no such data in the hdf5 file. ({hdf5_path})")
        sys.exit(1)

    hdf5_data = hdf5_file[hdf5_path][()]
    hdf5_file.close()

    return hdf5_data


def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
    """Write dataset to hdf5.

    Args:
        hdf5_name (str): Hdf5 dataset filename.
        hdf5_path (str): Dataset path in hdf5.
        write_data (ndarray): Data to write.
        is_overwrite (bool): Whether to overwrite dataset.

    """
    # convert to numpy array
    write_data = np.array(write_data)

    # check folder existence
    folder_name, _ = os.path.split(hdf5_name)
    if not os.path.exists(folder_name) and len(folder_name) != 0:
        os.makedirs(folder_name)

    # check hdf5 existence
    if os.path.exists(hdf5_name):
        # if already exists, open with r+ mode
        hdf5_file = h5py.File(hdf5_name, "r+")
        # check dataset existence
        if hdf5_path in hdf5_file:
            if is_overwrite:
                logging.warning("Dataset in hdf5 file already exists. "
                                "Recreating dataset in hdf5.")
                hdf5_file.__delitem__(hdf5_path)
            else:
                logging.error("Dataset in hdf5 file already exists. "
                              "If you want to overwrite, please set is_overwrite = True.")
                hdf5_file.close()
                sys.exit(1)
    else:
        # if not exists, open with w mode
        hdf5_file = h5py.File(hdf5_name, "w")

    # write data to hdf5
    hdf5_file.create_dataset(hdf5_path, data=write_data)
    hdf5_file.flush()
    hdf5_file.close()


class HDF5ScpLoader(object):
    """Loader class for a feats.scp file of hdf5 files.

    Examples:
        key1 /some/path/a.h5:feats
        key2 /some/path/b.h5:feats
        key3 /some/path/c.h5:feats
        key4 /some/path/d.h5:feats
        ...
        >>> loader = HDF5ScpLoader("hdf5.scp")
        >>> array = loader["key1"]

        key1 /some/path/a.h5
        key2 /some/path/b.h5
        key3 /some/path/c.h5
        key4 /some/path/d.h5
        ...
        >>> loader = HDF5ScpLoader("hdf5.scp", "feats")
        >>> array = loader["key1"]

    """

    def __init__(self, feats_scp, default_hdf5_path="feats"):
        """Initialize HDF5 scp loader.

        Args:
            feats_scp (str): Kaldi-style feats.scp file with hdf5 format.
            default_hdf5_path (str): Path in the hdf5 file. Not used if the scp line already contains it.

        """
        self.default_hdf5_path = default_hdf5_path
        with open(feats_scp, encoding='utf-8') as f:
            lines = [line.replace("\n", "") for line in f.readlines()]
        self.data = {}
        for line in lines:
            key, value = line.split()
            self.data[key] = value

    def get_path(self, key):
        """Get hdf5 file path for a given key."""
        return self.data[key]

    def __getitem__(self, key):
        """Get ndarray for a given key."""
        p = self.data[key]
        if ":" in p:
            return read_hdf5(*p.split(":"))
        else:
            return read_hdf5(p, self.default_hdf5_path)

    def __len__(self):
        """Return the length of the scp file."""
        return len(self.data)

    def __iter__(self):
        """Return the iterator of the scp file."""
        return iter(self.data)

    def keys(self):
        """Return the keys of the scp file."""
        return self.data.keys()
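
A write/read round trip with the hdf5 helpers above (hypothetical sketch, not part of the commit; the file and dataset names are placeholders):

import numpy as np

feats = np.random.randn(100, 80).astype(np.float32)
write_hdf5("tmp/a.h5", "feats", feats)     # creates tmp/ if it does not exist
restored = read_hdf5("tmp/a.h5", "feats")
assert np.allclose(feats, restored)
# an scp file with lines such as "key1 tmp/a.h5:feats" can then be consumed
# through HDF5ScpLoader, e.g. HDF5ScpLoader("hdf5.scp")["key1"]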
network/diff/candidate_decoder.py
ADDED
@@ -0,0 +1,98 @@
from modules.fastspeech.tts_modules import FastspeechDecoder
# from modules.fastspeech.fast_tacotron import DecoderRNN
# from modules.fastspeech.speedy_speech.speedy_speech import ConvBlocks
# from modules.fastspeech.conformer.conformer import ConformerDecoder
import torch
from torch.nn import functional as F
import torch.nn as nn
import math
from utils.hparams import hparams
from modules.commons.common_layers import Mish

Linear = nn.Linear


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class FFT(FastspeechDecoder):  # unused, because DiffSinger only uses FastspeechEncoder
    # NOTE: this part of the script is *isolated* from the other scripts, which means
    # it may not be compatible with the current version.

    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
        dim = hparams['residual_channels']
        self.input_projection = Conv1d(hparams['audio_num_mel_bins'], dim, 1)
        self.diffusion_embedding = SinusoidalPosEmb(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim)
        )
        self.get_mel_out = Linear(hparams['hidden_size'], 80, bias=True)
        self.get_decode_inp = Linear(hparams['hidden_size'] + dim + dim,
                                     hparams['hidden_size'])  # hs + dim + dim -> hs

    def forward(self, spec, diffusion_step, cond, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param spec: [B, 1, 80, T]
        :param diffusion_step: [B, 1]
        :param cond: [B, M, T]
        :return:
        """
        x = spec[:, 0]
        x = self.input_projection(x).permute([0, 2, 1])  # [B, T, residual_channel]
        diffusion_step = self.diffusion_embedding(diffusion_step)
        diffusion_step = self.mlp(diffusion_step)  # [B, dim]
        cond = cond.permute([0, 2, 1])  # [B, T, M]

        seq_len = cond.shape[1]  # [T_mel]
        time_embed = diffusion_step[:, None, :]  # [B, 1, dim]
        time_embed = time_embed.repeat([1, seq_len, 1])  # [B, T, dim]

        decoder_inp = torch.cat([x, cond, time_embed], dim=-1)  # [B, T, dim + H + dim]
        decoder_inp = self.get_decode_inp(decoder_inp)  # [B, T, H]
        x = decoder_inp

        '''
        Required x: [B, T, C]
        :return: [B, T, C] or [L, B, T, C]
        '''
        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1) * nonpadding_mask_TB
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)  # [L, T, B, C]
            x = x.transpose(1, 2)  # [L, B, T, C]
        else:
            x = x.transpose(0, 1)  # [B, T, C]

        x = self.get_mel_out(x).permute([0, 2, 1])  # [B, 80, T]
        return x[:, None, :, :]
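A quick shape check on SinusoidalPosEmb as defined above (a standalone sketch, not part of the commit; the batch of step indices and dim=64 are arbitrary):

import torch

emb = SinusoidalPosEmb(64)
t = torch.tensor([0, 10, 999])   # diffusion-step indices, shape [B]
out = emb(t)                     # sin half and cos half concatenated
assert out.shape == (3, 64)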
network/diff/diffusion.py
ADDED
@@ -0,0 +1,332 @@
from collections import deque
from functools import partial
from inspect import isfunction

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

from modules.fastspeech.fs2 import FastSpeech2
# from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
from utils.hparams import hparams
from training.train_pipeline import Batch2Loss


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# gaussian diffusion trainer class

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
    """
    linear schedule
    """
    betas = np.linspace(1e-4, max_beta, timesteps)
    return betas


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, a_min=0, a_max=0.999)


beta_schedule = {
    "cosine": cosine_beta_schedule,
    "linear": linear_beta_schedule,
}


class GaussianDiffusion(nn.Module):
    def __init__(self, phone_encoder, out_dims, denoise_fn,
                 timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None,
                 spec_min=None, spec_max=None):
        super().__init__()
        self.denoise_fn = denoise_fn
        # if hparams.get('use_midi') is not None and hparams['use_midi']:
        #     self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
        # else:
        self.fs2 = FastSpeech2(phone_encoder, out_dims)
        self.mel_bins = out_dims

        if exists(betas):
            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
        else:
            if 'schedule_type' in hparams.keys():
                betas = beta_schedule[hparams['schedule_type']](timesteps)
            else:
                betas = cosine_beta_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.K_step = K_step
        self.loss_type = loss_type

        self.noise_list = deque(maxlen=4)

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
        self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, cond, clip_denoised: bool):
        noise_pred = self.denoise_fn(x, t, cond=cond)
        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
        """
        Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
        """

        def get_x_pred(x, noise_t, t):
            a_t = extract(self.alphas_cumprod, t, x.shape)
            a_prev = extract(self.alphas_cumprod, torch.max(t - interval, torch.zeros_like(t)), x.shape)
            a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()

            x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (
                    a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
            x_pred = x + x_delta

            return x_pred

        noise_list = self.noise_list
        noise_pred = self.denoise_fn(x, t, cond=cond)

        if len(noise_list) == 0:
            x_pred = get_x_pred(x, noise_pred, t)
            noise_pred_prev = self.denoise_fn(x_pred, max(t - interval, 0), cond=cond)
            noise_pred_prime = (noise_pred + noise_pred_prev) / 2
        elif len(noise_list) == 1:
            noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
        elif len(noise_list) == 2:
            noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
        elif len(noise_list) >= 3:
            noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

        x_prev = get_x_pred(x, noise_pred_prime, t)
        noise_list.append(noise_pred)

        return x_prev

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, cond)

        if self.loss_type == 'l1':
            if nonpadding is not None:
                loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
            else:
                # print('are you sure w/o nonpadding?')
                loss = (noise - x_recon).abs().mean()

        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, x_recon)
        else:
            raise NotImplementedError()

        return loss

    def forward(self, hubert, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
        '''
        conditioning diffusion, use fastspeech2 encoder output as the condition
        '''
        ret = self.fs2(hubert, mel2ph, spk_embed, None, f0, uv, energy,
                       skip_decoder=True, infer=infer, **kwargs)
        cond = ret['decoder_inp'].transpose(1, 2)
        b, *_, device = *hubert.shape, hubert.device

        if not infer:
            Batch2Loss.module4(
                self.p_losses,
                self.norm_spec(ref_mels), cond, ret, self.K_step, b, device
            )
        else:
            '''
            ret['fs2_mel'] = ret['mel_out']
            fs2_mels = ret['mel_out']
            t = self.K_step
            fs2_mels = self.norm_spec(fs2_mels)
            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
            if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
                print('===> gaussian start.')
                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
                x = torch.randn(shape, device=device)
            '''
            if 'use_gt_mel' in kwargs.keys() and kwargs['use_gt_mel']:
                t = kwargs['add_noise_step']
                print('===> using ground truth mel as start, please make sure parameter "key==0" !')
                fs2_mels = ref_mels
                fs2_mels = self.norm_spec(fs2_mels)
                fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
                x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
                # for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
                #     x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
            else:
                t = self.K_step
                # print('===> gaussian start.')
                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
                x = torch.randn(shape, device=device)
            if hparams.get('pndm_speedup') and hparams['pndm_speedup'] > 1:
                self.noise_list = deque(maxlen=4)
                iteration_interval = hparams['pndm_speedup']
                for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
                              total=t // iteration_interval):
                    x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long),
                                           iteration_interval, cond)
            else:
                for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
                    x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
            x = x[:, 0].transpose(1, 2)
            if mel2ph is not None:  # for singing
                ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
            else:
                ret['mel_out'] = self.denorm_spec(x)
        return ret

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def denorm_spec(self, x):
        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min

    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)

    def out2mel(self, x):
        return x


class OfflineGaussianDiffusion(GaussianDiffusion):
    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
        b, *_, device = *txt_tokens.shape, txt_tokens.device

        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
                       skip_decoder=True, infer=True, **kwargs)
        cond = ret['decoder_inp'].transpose(1, 2)
        fs2_mels = ref_mels[1]
        ref_mels = ref_mels[0]

        if not infer:
            t = torch.randint(0, self.K_step, (b,), device=device).long()
            x = ref_mels
            x = self.norm_spec(x)
            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
            ret['diff_loss'] = self.p_losses(x, t, cond)
        else:
            t = self.K_step
            fs2_mels = self.norm_spec(fs2_mels)
            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]

            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())

            if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
                print('===> gaussian start.')
                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
                x = torch.randn(shape, device=device)
            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
            x = x[:, 0].transpose(1, 2)
            ret['mel_out'] = self.denorm_spec(x)

        return ret
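The two beta schedules above fully determine how quickly the mel is destroyed by noise. A standalone sanity sketch (not part of the commit) of the cosine schedule:

import numpy as np

betas = cosine_beta_schedule(1000)
alphas_cumprod = np.cumprod(1. - betas)
assert betas.shape == (1000,)
assert np.all(np.diff(alphas_cumprod) < 0)  # signal fraction decays monotonically
assert alphas_cumprod[-1] < 1e-3            # x_T is close to pure Gaussian noise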
network/diff/net.py
ADDED
@@ -0,0 +1,135 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt

from utils.hparams import hparams
from modules.commons.common_layers import Mish

Linear = nn.Linear
ConvTranspose2d = nn.ConvTranspose2d


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

    def override(self, attrs):
        if isinstance(attrs, dict):
            self.__dict__.update(**attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError
        return self


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


def Conv1d(*args, **kwargs):
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


@torch.jit.script
def silu(x):
    return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    def __init__(self, encoder_hidden, residual_channels, dilation):
        super().__init__()
        self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation)
        self.diffusion_projection = Linear(residual_channels, residual_channels)
        self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1)
        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x, conditioner, diffusion_step):
        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        conditioner = self.conditioner_projection(conditioner)
        y = x + diffusion_step

        y = self.dilated_conv(y) + conditioner

        gate, filter = torch.chunk(y, 2, dim=1)
        # Using torch.split instead of torch.chunk to avoid using onnx::Slice
        # gate, filter = torch.split(y, torch.div(y.shape[1], 2), dim=1)

        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        # Using torch.split instead of torch.chunk to avoid using onnx::Slice
        # residual, skip = torch.split(y, torch.div(y.shape[1], 2), dim=1)

        return (x + residual) / sqrt(2.0), skip


class DiffNet(nn.Module):
    def __init__(self, in_dims=80):
        super().__init__()
        self.params = params = AttrDict(
            # Model params
            encoder_hidden=hparams['hidden_size'],
            residual_layers=hparams['residual_layers'],
            residual_channels=hparams['residual_channels'],
            dilation_cycle_length=hparams['dilation_cycle_length'],
        )
        self.input_projection = Conv1d(in_dims, params.residual_channels, 1)
        self.diffusion_embedding = SinusoidalPosEmb(params.residual_channels)
        dim = params.residual_channels
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim)
        )
        self.residual_layers = nn.ModuleList([
            ResidualBlock(params.encoder_hidden, params.residual_channels, 2 ** (i % params.dilation_cycle_length))
            for i in range(params.residual_layers)
        ])
        self.skip_projection = Conv1d(params.residual_channels, params.residual_channels, 1)
        self.output_projection = Conv1d(params.residual_channels, in_dims, 1)
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, spec, diffusion_step, cond):
        """
        :param spec: [B, 1, M, T]
        :param diffusion_step: [B, 1]
        :param cond: [B, M, T]
        :return:
        """
        x = spec[:, 0]
        x = self.input_projection(x)  # x [B, residual_channel, T]

        x = F.relu(x)
        diffusion_step = self.diffusion_embedding(diffusion_step)
        diffusion_step = self.mlp(diffusion_step)
        skip = []
        for layer_id, layer in enumerate(self.residual_layers):
            x, skip_connection = layer(x, cond, diffusion_step)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)  # [B, 80, T]
        return x[:, None, :, :]
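A shape walk-through of one ResidualBlock from net.py (a standalone sketch, not part of the commit; encoder_hidden=256 and residual_channels=384 are illustrative values, not necessarily the project's config):

import torch

block = ResidualBlock(encoder_hidden=256, residual_channels=384, dilation=2)
x = torch.randn(4, 384, 100)      # [B, residual_channels, T]
cond = torch.randn(4, 256, 100)   # [B, encoder_hidden, T]
step = torch.randn(4, 384)        # diffusion-step embedding after the MLP, [B, residual_channels]
residual, skip = block(x, cond, step)
assert residual.shape == x.shape and skip.shape == x.shape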
network/hubert/hubert_model.py
ADDED
@@ -0,0 +1,276 @@
import copy
import os
import random
from typing import Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as t_func
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

from utils import hparams


class Hubert(nn.Module):
    def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
        super().__init__()
        self._mask = mask
        self.feature_extractor = FeatureExtractor()
        self.feature_projection = FeatureProjection()
        self.positional_embedding = PositionalConvEmbedding()
        self.norm = nn.LayerNorm(768)
        self.dropout = nn.Dropout(0.1)
        self.encoder = TransformerEncoder(
            nn.TransformerEncoderLayer(
                768, 12, 3072, activation="gelu", batch_first=True
            ),
            12,
        )
        self.proj = nn.Linear(768, 256)

        self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
        self.label_embedding = nn.Embedding(num_label_embeddings, 256)

    def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mask = None
        if self.training and self._mask:
            mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
            x[mask] = self.masked_spec_embed.to(x.dtype)
        return x, mask

    def encode(
            self, x: torch.Tensor, layer: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.feature_extractor(x)
        x = self.feature_projection(x.transpose(1, 2))
        x, mask = self.mask(x)
        x = x + self.positional_embedding(x)
        x = self.dropout(self.norm(x))
        x = self.encoder(x, output_layer=layer)
        return x, mask

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        logits = torch.cosine_similarity(
            x.unsqueeze(2),
            self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
            dim=-1,
        )
        return logits / 0.1

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x, mask = self.encode(x)
        x = self.proj(x)
        logits = self.logits(x)
        return logits, mask


class HubertSoft(Hubert):
    def __init__(self):
        super().__init__()

    # @torch.inference_mode()
    def units(self, wav: torch.Tensor) -> torch.Tensor:
        wav = torch.nn.functional.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
        x, _ = self.encode(wav)
        return self.proj(x)

    def forward(self, wav: torch.Tensor):
        return self.units(wav)


class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
        self.norm0 = nn.GroupNorm(512, 512)
        self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
        self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
        self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = t_func.gelu(self.norm0(self.conv0(x)))
        x = t_func.gelu(self.conv1(x))
        x = t_func.gelu(self.conv2(x))
        x = t_func.gelu(self.conv3(x))
        x = t_func.gelu(self.conv4(x))
        x = t_func.gelu(self.conv5(x))
        x = t_func.gelu(self.conv6(x))
        return x


class FeatureProjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(512)
        self.projection = nn.Linear(512, 768)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x)
        x = self.projection(x)
        x = self.dropout(x)
        return x


class PositionalConvEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(
            768,
            768,
            kernel_size=128,
            padding=128 // 2,
            groups=16,
        )
        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x.transpose(1, 2))
        x = t_func.gelu(x[:, :, :-1])
        return x.transpose(1, 2)


class TransformerEncoder(nn.Module):
    def __init__(
            self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
    ) -> None:
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
            self,
            src: torch.Tensor,
            mask: torch.Tensor = None,
            src_key_padding_mask: torch.Tensor = None,
            output_layer: Optional[int] = None,
    ) -> torch.Tensor:
        output = src
        for layer in self.layers[:output_layer]:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )
        return output


def _compute_mask(
        shape: Tuple[int, int],
        mask_prob: float,
        mask_length: int,
        device: torch.device,
        min_masks: int = 0,
) -> torch.Tensor:
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
        )

    # compute number of masked spans in batch
    num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
    num_masked_spans = max(num_masked_spans, min_masks)

    # make sure num masked indices <= sequence_length
    if num_masked_spans * mask_length > sequence_length:
        num_masked_spans = sequence_length // mask_length

    # SpecAugment mask to fill
    mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)

    # uniform distribution to sample from, make sure that offset samples are < sequence_length
    uniform_dist = torch.ones(
        (batch_size, sequence_length - (mask_length - 1)), device=device
    )

    # get random indices to mask
    mask_indices = torch.multinomial(uniform_dist, num_masked_spans)

    # expand masked indices to masked spans
    mask_indices = (
        mask_indices.unsqueeze(dim=-1)
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    offsets = (
        torch.arange(mask_length, device=device)[None, None, :]
        .expand((batch_size, num_masked_spans, mask_length))
        .reshape(batch_size, num_masked_spans * mask_length)
    )
    mask_idxs = mask_indices + offsets

    # scatter indices to mask
    mask = mask.scatter(1, mask_idxs, True)

    return mask


def hubert_soft(
        path: str
) -> HubertSoft:
    r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
    Args:
        path (str): path of a pretrained model
    """
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hubert = HubertSoft()
    checkpoint = torch.load(path)
    consume_prefix_in_state_dict_if_present(checkpoint, "module.")
    hubert.load_state_dict(checkpoint)
    hubert.eval().to(dev)
    return hubert


def get_units(hbt_soft, raw_wav_path, dev=torch.device('cuda')):
    wav, sr = librosa.load(raw_wav_path, sr=None)
    assert (sr >= 16000)
    if len(wav.shape) > 1:
        wav = librosa.to_mono(wav)
    if sr != 16000:
        wav16 = librosa.resample(wav, sr, 16000)
    else:
        wav16 = wav
    dev = torch.device("cuda" if (dev == torch.device('cuda') and torch.cuda.is_available()) else "cpu")
    torch.cuda.is_available() and torch.cuda.empty_cache()
    with torch.inference_mode():
        units = hbt_soft.units(torch.FloatTensor(wav16.astype(float)).unsqueeze(0).unsqueeze(0).to(dev))
        return units


def get_end_file(dir_path, end):
    file_list = []
    for root, dirs, files in os.walk(dir_path):
        files = [f for f in files if f[0] != '.']
        dirs[:] = [d for d in dirs if d[0] != '.']
        for f_file in files:
            if f_file.endswith(end):
                file_list.append(os.path.join(root, f_file).replace("\\", "/"))
    return file_list


if __name__ == '__main__':
    from pathlib import Path

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # path to the pretrained hubert model
    hbt_model = hubert_soft(str(list(Path(hparams['hubert_path']).home().rglob('*.pt'))[0]))
    # no need to change this: a matching .npy is generated next to every wav under the root directory
    file_lists = list(Path(hparams['raw_data_dir']).rglob('*.wav'))
    nums = len(file_lists)
    count = 0
    for wav_path in file_lists:
        npy_path = wav_path.with_suffix(".npy")
        npy_content = get_units(hbt_model, wav_path).cpu().numpy()[0]
        np.save(str(npy_path), npy_content)
        count += 1
        print(f"hubert process:{round(count * 100 / nums, 2)}%")
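_compute_mask above draws SpecAugment-style spans; Hubert.mask calls it with mask_prob=0.8, mask_length=10, min_masks=2. A standalone sketch of its behaviour with smaller, illustrative numbers (not the training defaults):

import torch

mask = _compute_mask((2, 100), mask_prob=0.08, mask_length=10,
                     device=torch.device("cpu"), min_masks=1)
assert mask.shape == (2, 100) and mask.dtype == torch.bool
# with these numbers exactly one 10-frame span is drawn per utterance
assert (mask.sum(dim=1) == 10).all()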
network/hubert/vec_model.py
ADDED
@@ -0,0 +1,60 @@
from pathlib import Path

import librosa
import numpy as np
import torch


def load_model(vec_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("load model(s) from {}".format(vec_path))
    from fairseq import checkpoint_utils
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [vec_path],
        suffix="",
    )
    model = models[0]
    model = model.to(device)
    model.eval()
    return model


def get_vec_units(con_model, audio_path, dev):
    audio, sampling_rate = librosa.load(audio_path)
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio.transpose(1, 0))
    if sampling_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)

    feats = torch.from_numpy(audio).float()
    if feats.dim() == 2:  # double channels
        feats = feats.mean(-1)
    assert feats.dim() == 1, feats.dim()
    feats = feats.view(1, -1)
    padding_mask = torch.BoolTensor(feats.shape).fill_(False)
    inputs = {
        "source": feats.to(dev),
        "padding_mask": padding_mask.to(dev),
        "output_layer": 9,  # layer 9
    }
    with torch.no_grad():
        logits = con_model.extract_features(**inputs)
        feats = con_model.final_proj(logits[0])
    return feats


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = "../../checkpoints/checkpoint_best_legacy_500.pt"  # checkpoint_best_legacy_500.pt
    vec_model = load_model(model_path)
    # no need to change this: a matching .npy is generated next to every wav under the root directory
    file_lists = list(Path("../../data/vecfox").rglob('*.wav'))
    nums = len(file_lists)
    count = 0
    for wav_path in file_lists:
        npy_path = wav_path.with_suffix(".npy")
        npy_content = get_vec_units(vec_model, str(wav_path), device).cpu().numpy()[0]
        np.save(str(npy_path), npy_content)
        count += 1
        print(f"hubert process:{round(count * 100 / nums, 2)}%")
network/vocoders/__init__.py
ADDED
@@ -0,0 +1,2 @@
from network.vocoders import hifigan
from network.vocoders import nsf_hifigan
network/vocoders/base_vocoder.py
ADDED
@@ -0,0 +1,39 @@
import importlib

VOCODERS = {}


def register_vocoder(cls):
    VOCODERS[cls.__name__.lower()] = cls
    VOCODERS[cls.__name__] = cls
    return cls


def get_vocoder_cls(hparams):
    if hparams['vocoder'] in VOCODERS:
        return VOCODERS[hparams['vocoder']]
    else:
        vocoder_cls = hparams['vocoder']
        pkg = ".".join(vocoder_cls.split(".")[:-1])
        cls_name = vocoder_cls.split(".")[-1]
        vocoder_cls = getattr(importlib.import_module(pkg), cls_name)
        return vocoder_cls


class BaseVocoder:
    def spec2wav(self, mel):
        """

        :param mel: [T, 80]
        :return: wav: [T']
        """
        raise NotImplementedError

    @staticmethod
    def wav2spec(wav_fn):
        """

        :param wav_fn: str
        :return: wav, mel: [T, 80]
        """
        raise NotImplementedError
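A minimal sketch of how the registry above is meant to be used (DummyVocoder is hypothetical, purely for illustration):

@register_vocoder
class DummyVocoder(BaseVocoder):
    def spec2wav(self, mel):
        return mel.reshape(-1)

# both the exact class name and its lowercase form resolve
assert get_vocoder_cls({'vocoder': 'DummyVocoder'}) is DummyVocoder
assert get_vocoder_cls({'vocoder': 'dummyvocoder'}) is DummyVocoder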
network/vocoders/hifigan.py
ADDED
@@ -0,0 +1,83 @@
import glob
import json
import os
import re

import librosa
import torch

import utils
from modules.hifigan.hifigan import HifiGanGenerator
from utils.hparams import hparams, set_hparams
from network.vocoders.base_vocoder import register_vocoder
from network.vocoders.pwg import PWG
from network.vocoders.vocoder_utils import denoise


def load_model(config_path, file_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ext = os.path.splitext(file_path)[-1]
    if ext == '.pth':
        if '.yaml' in config_path:
            config = set_hparams(config_path, global_hparams=False)
        elif '.json' in config_path:
            config = json.load(open(config_path, 'r', encoding='utf-8'))
        model = torch.load(file_path, map_location="cpu")
    elif ext == '.ckpt':
        ckpt_dict = torch.load(file_path, map_location="cpu")
        if '.yaml' in config_path:
            config = set_hparams(config_path, global_hparams=False)
            state = ckpt_dict["state_dict"]["model_gen"]
        elif '.json' in config_path:
            config = json.load(open(config_path, 'r', encoding='utf-8'))
            state = ckpt_dict["generator"]
        model = HifiGanGenerator(config)
        model.load_state_dict(state, strict=True)
        model.remove_weight_norm()
    model = model.eval().to(device)
    print(f"| Loaded model parameters from {file_path}.")
    print(f"| HifiGAN device: {device}.")
    return model, config, device


total_time = 0


@register_vocoder
class HifiGAN(PWG):
    def __init__(self):
        base_dir = hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        if os.path.exists(config_path):
            file_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.*'), key=
            lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).*', x.replace('\\', '/'))[0]))[-1]
            print('| load HifiGAN: ', file_path)
            self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
        else:
            config_path = f'{base_dir}/config.json'
            ckpt = f'{base_dir}/generator_v1'
            if os.path.exists(config_path):
                self.model, self.config, self.device = load_model(config_path=config_path, file_path=ckpt)

    def spec2wav(self, mel, **kwargs):
        device = self.device
        with torch.no_grad():
            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
            with utils.Timer('hifigan', print_time=hparams['profile_infer']):
                f0 = kwargs.get('f0')
                if f0 is not None and hparams.get('use_nsf'):
                    f0 = torch.FloatTensor(f0[None, :]).to(device)
                    y = self.model(c, f0).view(-1)
                else:
                    y = self.model(c).view(-1)
        wav_out = y.cpu().numpy()
        if hparams.get('vocoder_denoise_c', 0.0) > 0:
            wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c'])
        return wav_out

    # @staticmethod
    # def wav2spec(wav_fn, **kwargs):
    #     wav, _ = librosa.core.load(wav_fn, sr=hparams['audio_sample_rate'])
    #     wav_torch = torch.FloatTensor(wav)[None, :]
    #     mel = mel_spectrogram(wav_torch, hparams).numpy()[0]
    #     return wav, mel.T
network/vocoders/nsf_hifigan.py
ADDED
@@ -0,0 +1,92 @@
import os

import torch

from modules.nsf_hifigan.models import load_model, Generator
from modules.nsf_hifigan.nvSTFT import load_wav_to_torch, STFT
from utils.hparams import hparams
from network.vocoders.base_vocoder import BaseVocoder, register_vocoder


@register_vocoder
class NsfHifiGAN(BaseVocoder):
    def __init__(self, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        model_path = hparams['vocoder_ckpt']
        if os.path.exists(model_path):
            print('| Load HifiGAN: ', model_path)
            self.model, self.h = load_model(model_path, device=self.device)
        else:
            print('Error: HifiGAN model file is not found!')

    def spec2wav_torch(self, mel, **kwargs):  # mel: [B, T, bins]
        if self.h.sampling_rate != hparams['audio_sample_rate']:
            print('Mismatch parameters: hparams[\'audio_sample_rate\']=', hparams['audio_sample_rate'], '!=', self.h.sampling_rate, '(vocoder)')
        if self.h.num_mels != hparams['audio_num_mel_bins']:
            print('Mismatch parameters: hparams[\'audio_num_mel_bins\']=', hparams['audio_num_mel_bins'], '!=', self.h.num_mels, '(vocoder)')
        if self.h.n_fft != hparams['fft_size']:
            print('Mismatch parameters: hparams[\'fft_size\']=', hparams['fft_size'], '!=', self.h.n_fft, '(vocoder)')
        if self.h.win_size != hparams['win_size']:
            print('Mismatch parameters: hparams[\'win_size\']=', hparams['win_size'], '!=', self.h.win_size, '(vocoder)')
        if self.h.hop_size != hparams['hop_size']:
            print('Mismatch parameters: hparams[\'hop_size\']=', hparams['hop_size'], '!=', self.h.hop_size, '(vocoder)')
        if self.h.fmin != hparams['fmin']:
            print('Mismatch parameters: hparams[\'fmin\']=', hparams['fmin'], '!=', self.h.fmin, '(vocoder)')
        if self.h.fmax != hparams['fmax']:
            print('Mismatch parameters: hparams[\'fmax\']=', hparams['fmax'], '!=', self.h.fmax, '(vocoder)')
        with torch.no_grad():
            c = mel.transpose(2, 1)  # [B, T, bins]
            # log10 to log mel
            c = 2.30259 * c
            f0 = kwargs.get('f0')  # [B, T]
            if f0 is not None and hparams.get('use_nsf'):
                y = self.model(c, f0).view(-1)
            else:
                y = self.model(c).view(-1)
            return y

    def spec2wav(self, mel, **kwargs):
        if self.h.sampling_rate != hparams['audio_sample_rate']:
            print('Mismatch parameters: hparams[\'audio_sample_rate\']=', hparams['audio_sample_rate'], '!=', self.h.sampling_rate, '(vocoder)')
        if self.h.num_mels != hparams['audio_num_mel_bins']:
            print('Mismatch parameters: hparams[\'audio_num_mel_bins\']=', hparams['audio_num_mel_bins'], '!=', self.h.num_mels, '(vocoder)')
        if self.h.n_fft != hparams['fft_size']:
            print('Mismatch parameters: hparams[\'fft_size\']=', hparams['fft_size'], '!=', self.h.n_fft, '(vocoder)')
        if self.h.win_size != hparams['win_size']:
            print('Mismatch parameters: hparams[\'win_size\']=', hparams['win_size'], '!=', self.h.win_size, '(vocoder)')
        if self.h.hop_size != hparams['hop_size']:
            print('Mismatch parameters: hparams[\'hop_size\']=', hparams['hop_size'], '!=', self.h.hop_size, '(vocoder)')
        if self.h.fmin != hparams['fmin']:
            print('Mismatch parameters: hparams[\'fmin\']=', hparams['fmin'], '!=', self.h.fmin, '(vocoder)')
        if self.h.fmax != hparams['fmax']:
            print('Mismatch parameters: hparams[\'fmax\']=', hparams['fmax'], '!=', self.h.fmax, '(vocoder)')
        with torch.no_grad():
            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(self.device)
            # log10 to log mel
            c = 2.30259 * c
            f0 = kwargs.get('f0')
            if f0 is not None and hparams.get('use_nsf'):
                f0 = torch.FloatTensor(f0[None, :]).to(self.device)
                y = self.model(c, f0).view(-1)
            else:
                y = self.model(c).view(-1)
        wav_out = y.cpu().numpy()
        return wav_out

    @staticmethod
    def wav2spec(inp_path, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        sampling_rate = hparams['audio_sample_rate']
        num_mels = hparams['audio_num_mel_bins']
        n_fft = hparams['fft_size']
        win_size = hparams['win_size']
        hop_size = hparams['hop_size']
        fmin = hparams['fmin']
        fmax = hparams['fmax']
        stft = STFT(sampling_rate, num_mels, n_fft, win_size, hop_size, fmin, fmax)
        with torch.no_grad():
            wav_torch, _ = load_wav_to_torch(inp_path, target_sr=stft.target_sr)
            mel_torch = stft.get_mel(wav_torch.unsqueeze(0).to(device)).squeeze(0).T
            # log mel to log10 mel
            mel_torch = 0.434294 * mel_torch
            return wav_torch.cpu().numpy(), mel_torch.cpu().numpy()
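The constants 2.30259 and 0.434294 above convert between the log10-mel used by the acoustic model and the natural-log mel the NSF-HiFiGAN checkpoint expects; they are simply ln(10) and 1/ln(10). A one-line check:

import math

assert abs(math.log(10) - 2.30259) < 1e-5       # log10-mel -> ln-mel factor
assert abs(1 / math.log(10) - 0.434294) < 1e-6  # ln-mel -> log10-mel factor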
network/vocoders/pwg.py
ADDED
@@ -0,0 +1,137 @@
1 |
+
import glob
|
2 |
+
import re
|
3 |
+
import librosa
|
4 |
+
import torch
|
5 |
+
import yaml
|
6 |
+
from sklearn.preprocessing import StandardScaler
|
7 |
+
from torch import nn
|
8 |
+
from modules.parallel_wavegan.models import ParallelWaveGANGenerator
|
9 |
+
from modules.parallel_wavegan.utils import read_hdf5
|
10 |
+
from utils.hparams import hparams
|
11 |
+
from utils.pitch_utils import f0_to_coarse
|
12 |
+
from network.vocoders.base_vocoder import BaseVocoder, register_vocoder
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
def load_pwg_model(config_path, checkpoint_path, stats_path):
|
17 |
+
# load config
|
18 |
+
with open(config_path, encoding='utf-8') as f:
|
19 |
+
config = yaml.load(f, Loader=yaml.Loader)
|
20 |
+
|
21 |
+
# setup
|
22 |
+
if torch.cuda.is_available():
|
23 |
+
device = torch.device("cuda")
|
24 |
+
else:
|
25 |
+
device = torch.device("cpu")
|
26 |
+
model = ParallelWaveGANGenerator(**config["generator_params"])
|
27 |
+
|
28 |
+
ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
|
29 |
+
if 'state_dict' not in ckpt_dict: # official vocoder
|
30 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]["generator"])
|
31 |
+
scaler = StandardScaler()
|
32 |
+
if config["format"] == "hdf5":
|
33 |
+
scaler.mean_ = read_hdf5(stats_path, "mean")
|
34 |
+
scaler.scale_ = read_hdf5(stats_path, "scale")
|
35 |
+
elif config["format"] == "npy":
|
36 |
+
scaler.mean_ = np.load(stats_path)[0]
|
37 |
+
scaler.scale_ = np.load(stats_path)[1]
|
38 |
+
else:
|
39 |
+
raise ValueError("support only hdf5 or npy format.")
|
40 |
+
else: # custom PWG vocoder
|
41 |
+
fake_task = nn.Module()
|
42 |
+
fake_task.model_gen = model
|
43 |
+
fake_task.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"], strict=False)
|
44 |
+
scaler = None
|
45 |
+
|
46 |
+
model.remove_weight_norm()
|
47 |
+
model = model.eval().to(device)
|
48 |
+
print(f"| Loaded model parameters from {checkpoint_path}.")
|
49 |
+
print(f"| PWG device: {device}.")
|
50 |
+
return model, scaler, config, device
|
51 |
+
|
52 |
+
|
@register_vocoder
class PWG(BaseVocoder):
    def __init__(self):
        if hparams['vocoder_ckpt'] == '':  # load the LJSpeech PWG pretrained model
            base_dir = 'wavegan_pretrained'
            ckpts = glob.glob(f'{base_dir}/checkpoint-*steps.pkl')
            ckpt = sorted(ckpts, key=lambda x: int(
                re.findall(rf'{base_dir}/checkpoint-(\d+)steps.pkl', x)[0]))[-1]
            config_path = f'{base_dir}/config.yaml'
            print('| load PWG: ', ckpt)
            self.model, self.scaler, self.config, self.device = load_pwg_model(
                config_path=config_path,
                checkpoint_path=ckpt,
                stats_path=f'{base_dir}/stats.h5',
            )
        else:
            base_dir = hparams['vocoder_ckpt']
            print(base_dir)
            config_path = f'{base_dir}/config.yaml'
            ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=lambda x: int(
                re.findall(rf'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
            print('| load PWG: ', ckpt)
            self.scaler = None
            self.model, _, self.config, self.device = load_pwg_model(
                config_path=config_path,
                checkpoint_path=ckpt,
                stats_path=f'{base_dir}/stats.h5',
            )

    def spec2wav(self, mel, **kwargs):
        # start generation
        config = self.config
        device = self.device
        pad_size = (config["generator_params"]["aux_context_window"],
                    config["generator_params"]["aux_context_window"])
        c = mel
        if self.scaler is not None:
            c = self.scaler.transform(c)

        with torch.no_grad():
            z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device)
            c = np.pad(c, (pad_size, (0, 0)), "edge")
            c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)
            p = kwargs.get('f0')
            if p is not None:
                p = f0_to_coarse(p)
                p = np.pad(p, (pad_size,), "edge")
                p = torch.LongTensor(p[None, :]).to(device)
            y = self.model(z, c, p).view(-1)
        wav_out = y.cpu().numpy()
        return wav_out

    @staticmethod
    def wav2spec(wav_fn, return_linear=False):
        from preprocessing.data_gen_utils import process_utterance
        res = process_utterance(
            wav_fn, fft_size=hparams['fft_size'],
            hop_size=hparams['hop_size'],
            win_length=hparams['win_size'],
            num_mels=hparams['audio_num_mel_bins'],
            fmin=hparams['fmin'],
            fmax=hparams['fmax'],
            sample_rate=hparams['audio_sample_rate'],
            loud_norm=hparams['loud_norm'],
            min_level_db=hparams['min_level_db'],
            return_linear=return_linear, vocoder='pwg', eps=float(hparams.get('wav2spec_eps', 1e-10)))
        if return_linear:
            return res[0], res[1].T, res[2].T  # [T, 80], [T, n_fft]
        else:
            return res[0], res[1].T

    @staticmethod
    def wav2mfcc(wav_fn):
        fft_size = hparams['fft_size']
        hop_size = hparams['hop_size']
        win_length = hparams['win_size']
        sample_rate = hparams['audio_sample_rate']
        wav, _ = librosa.core.load(wav_fn, sr=sample_rate)
        mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13,
                                    n_fft=fft_size, hop_length=hop_size,
                                    win_length=win_length, pad_mode="constant", power=1.0)
        mfcc_delta = librosa.feature.delta(mfcc, order=1)
        mfcc_delta_delta = librosa.feature.delta(mfcc, order=2)
        mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T
        return mfcc
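
For reference, a hedged usage sketch of this vocoder class. It assumes hparams has already been initialized via set_hparams and that a PWG checkpoint is reachable; 'example.wav' and the f0 array are hypothetical inputs, not part of the repository:

# Hedged usage sketch, not part of the repo. Assumes hparams is initialized
# and a vocoder checkpoint exists; 'example.wav' is a hypothetical path.
vocoder = PWG()
wav, mel = PWG.wav2spec('example.wav')   # mel has shape [T, audio_num_mel_bins]
wav_out = vocoder.spec2wav(mel, f0=f0)   # f0: frame-level pitch, only if the generator is pitch-conditioned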
network/vocoders/vocoder_utils.py
ADDED
@@ -0,0 +1,15 @@
import librosa
import numpy as np

from utils.hparams import hparams


def denoise(wav, v=0.1):
    # Simple spectral subtraction: subtract a constant floor v from the
    # magnitude spectrogram, then resynthesize with the original phase.
    spec = librosa.stft(y=wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'],
                        win_length=hparams['win_size'], pad_mode='constant')
    spec_m = np.abs(spec)
    spec_m = np.clip(spec_m - v, a_min=0, a_max=None)
    spec_a = np.angle(spec)

    return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=hparams['hop_size'],
                         win_length=hparams['win_size'])
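
Since denoise is plain spectral subtraction, it is cheap enough to run on vocoder output as a post-processing step. A hedged sketch, assuming wav_out is a float waveform at the project sample rate:

# Hedged usage sketch; wav_out is assumed to come from a vocoder's spec2wav.
from network.vocoders.vocoder_utils import denoise

wav_clean = denoise(wav_out, v=0.1)  # larger v strips more noise floor but can dull quiet harmonics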
preprocessing/SVCpre.py
ADDED
@@ -0,0 +1,63 @@
'''
    item: one piece of data
    item_name: data id
    wav_fn: wave file path
    txt: lyrics
    ph: phoneme
    tg_fn: text grid file path (unused)
    spk: dataset name
    wdb: word boundary
    ph_durs: phoneme durations
    midi: pitch as midi notes
    midi_dur: midi duration
    is_slur: keep singing upon note changes
'''
import logging
from copy import deepcopy

from preprocessing.base_binarizer import BaseBinarizer
from preprocessing.process_pipeline import File2Batch
from utils.hparams import hparams

SVCSINGING_ITEM_ATTRIBUTES = ['wav_fn', 'spk_id']


class SVCBinarizer(BaseBinarizer):
    def __init__(self, item_attributes=SVCSINGING_ITEM_ATTRIBUTES):
        super().__init__(item_attributes)
        print('speakers: ', set(item['spk_id'] for item in self.items.values()))
        self.item_names = sorted(list(self.items.keys()))
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
        # self._valid_item_names = []

    def split_train_test_set(self, item_names):
        item_names = deepcopy(item_names)
        if hparams['choose_test_manually']:
            test_item_names = [x for x in item_names if any(x.startswith(ts) for ts in hparams['test_prefixes'])]
        else:
            test_item_names = item_names[-5:]
        train_item_names = [x for x in item_names if x not in set(test_item_names)]
        logging.info("train {}".format(len(train_item_names)))
        logging.info("test {}".format(len(test_item_names)))
        return train_item_names, test_item_names

    @property
    def train_item_names(self):
        return self._train_item_names

    @property
    def valid_item_names(self):
        return self._test_item_names

    @property
    def test_item_names(self):
        return self._test_item_names

    def load_meta_data(self):
        self.items = File2Batch.file2temporary_dict()

    def _phone_encoder(self):
        from preprocessing.hubertinfer import Hubertencoder
        return Hubertencoder(hparams['hubert_path'])
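
Note how split_train_test_set behaves: with choose_test_manually enabled, items are routed to the test set by name prefix; otherwise the last five sorted items become the test set. A hedged sketch of the two hparams keys it reads (normally set in config.yaml; the prefix values here are invented):

# Hypothetical hparams fragment showing only the keys read by
# split_train_test_set; the prefix values are invented for illustration.
hparams['choose_test_manually'] = True
hparams['test_prefixes'] = ['singer_a_001', 'singer_b']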
preprocessing/base_binarizer.py
ADDED
@@ -0,0 +1,237 @@
import os

os.environ["OMP_NUM_THREADS"] = "1"

import json
import random

import numpy as np
import yaml
# from resemblyzer import VoiceEncoder
from tqdm import tqdm

from preprocessing.data_gen_utils import get_mel2ph, get_pitch_parselmouth, build_phone_encoder, get_pitch_crepe
from utils.hparams import set_hparams, hparams
from utils.indexed_datasets import IndexedDatasetBuilder
from utils.multiprocess_utils import chunked_multiprocess_run


class BinarizationError(Exception):
    pass


BASE_ITEM_ATTRIBUTES = ['txt', 'ph', 'wav_fn', 'tg_fn', 'spk_id']


class BaseBinarizer:
    '''
    Base class for data processing.
    1. *process* and *process_data_split*:
        process the entire dataset and generate the train-test split (supports parallel processing);
    2. *process_item*:
        process a single piece of data;
    3. *get_pitch*:
        infer the pitch using some algorithm;
    4. *get_align*:
        get the alignment in 'mel2ph' format (see https://arxiv.org/abs/1905.09263);
    5. phoneme encoder, voice encoder, etc.

    Subclasses should define:
    1. *load_meta_data*:
        how to read multiple datasets from files;
    2. *train_item_names*, *valid_item_names*, *test_item_names*:
        how to split the dataset;
    3. *load_ph_set*:
        the phoneme set.
    '''

    def __init__(self, item_attributes=BASE_ITEM_ATTRIBUTES):
        self.binarization_args = hparams['binarization_args']
        # self.pre_align_args = hparams['pre_align_args']

        self.items = {}
        # every item in self.items has some attributes
        self.item_attributes = item_attributes

        self.load_meta_data()
        # sanity check: the keys of each item dict may only take values from the given attribute list
        assert all([attr in self.item_attributes for attr in list(self.items.values())[0].keys()])
        self.item_names = sorted(list(self.items.keys()))

        if self.binarization_args['shuffle']:
            random.seed(1234)
            random.shuffle(self.item_names)

        # set the default get_pitch algorithm
        if hparams['use_crepe']:
            self.get_pitch_algorithm = get_pitch_crepe
        else:
            self.get_pitch_algorithm = get_pitch_parselmouth

    def load_meta_data(self):
        raise NotImplementedError

    @property
    def train_item_names(self):
        raise NotImplementedError

    @property
    def valid_item_names(self):
        raise NotImplementedError

    @property
    def test_item_names(self):
        raise NotImplementedError

    def build_spk_map(self):
        spk_map = set()
        for item_name in self.item_names:
            spk_name = self.items[item_name]['spk_id']
            spk_map.add(spk_name)
        spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
        assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
        return spk_map

    def item_name2spk_id(self, item_name):
        return self.spk_map[self.items[item_name]['spk_id']]

    def _phone_encoder(self):
        '''
        use hubert encoder
        '''
        raise NotImplementedError
        '''
        create 'phone_set.json' file if it doesn't exist
        '''
        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
        ph_set = []
        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
            self.load_ph_set(ph_set)
            ph_set = sorted(set(ph_set))
            json.dump(ph_set, open(ph_set_fn, 'w', encoding='utf-8'))
            print("| Build phone set: ", ph_set)
        else:
            ph_set = json.load(open(ph_set_fn, 'r', encoding='utf-8'))
            print("| Load phone set: ", ph_set)
        return build_phone_encoder(hparams['binary_data_dir'])

    def load_ph_set(self, ph_set):
        raise NotImplementedError

    def meta_data_iterator(self, prefix):
        if prefix == 'valid':
            item_names = self.valid_item_names
        elif prefix == 'test':
            item_names = self.test_item_names
        else:
            item_names = self.train_item_names
        for item_name in item_names:
            meta_data = self.items[item_name]
            yield item_name, meta_data

    def process(self):
        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
        self.spk_map = self.build_spk_map()
        print("| spk_map: ", self.spk_map)
        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
        json.dump(self.spk_map, open(spk_map_fn, 'w', encoding='utf-8'))

        self.phone_encoder = self._phone_encoder()
        self.process_data_split('valid')
        self.process_data_split('test')
        self.process_data_split('train')

    def process_data_split(self, prefix):
        data_dir = hparams['binary_data_dir']
        args = []
        builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
        lengths = []
        f0s = []
        total_sec = 0
        # if self.binarization_args['with_spk_embed']:
        #     voice_encoder = VoiceEncoder().cuda()

        for item_name, meta_data in self.meta_data_iterator(prefix):
            args.append([item_name, meta_data, self.binarization_args])
        spec_min = []
        spec_max = []
        # code for single-cpu processing
        for i in tqdm(reversed(range(len(args))), total=len(args)):
            a = args[i]
            item = self.process_item(*a)
            if item is None:
                continue
            spec_min.append(item['spec_min'])
            spec_max.append(item['spec_max'])
            # item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
            #     if self.binarization_args['with_spk_embed'] else None
            if not self.binarization_args['with_wav'] and 'wav' in item:
                if hparams['debug']:
                    print("del wav")
                del item['wav']
            if hparams['debug']:
                print(item)
            builder.add_item(item)
            lengths.append(item['len'])
            total_sec += item['sec']
            # if item.get('f0') is not None:
            #     f0s.append(item['f0'])
        if prefix == 'train':
            # write per-bin spectrogram statistics back into the config file
            spec_max = np.max(spec_max, 0)
            spec_min = np.min(spec_min, 0)
            print(spec_max.shape)
            with open(hparams['config_path'], encoding='utf-8') as f:
                _hparams = yaml.safe_load(f)
            _hparams['spec_max'] = spec_max.tolist()
            _hparams['spec_min'] = spec_min.tolist()
            with open(hparams['config_path'], 'w', encoding='utf-8') as f:
                yaml.safe_dump(_hparams, f)
        builder.finalize()
        np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
        if len(f0s) > 0:
            f0s = np.concatenate(f0s, 0)
            f0s = f0s[f0s != 0]
            np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
        print(f"| {prefix} total duration: {total_sec:.3f}s")

    def process_item(self, item_name, meta_data, binarization_args):
        from preprocessing.process_pipeline import File2Batch
        return File2Batch.temporary_dict2processed_input(item_name, meta_data, self.phone_encoder, binarization_args)

    def get_align(self, meta_data, mel, phone_encoded, res):
        raise NotImplementedError

    def get_align_from_textgrid(self, meta_data, mel, phone_encoded, res):
        '''
        NOTE: this part of the script is *isolated* from the other scripts, which means
        it may not be compatible with the current version.
        '''
        return
        tg_fn, ph = meta_data['tg_fn'], meta_data['ph']
        if tg_fn is not None and os.path.exists(tg_fn):
            mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
        else:
            raise BinarizationError("Align not found")
        if mel2ph.max() - 1 >= len(phone_encoded):
            raise BinarizationError(
                f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
        res['mel2ph'] = mel2ph
        res['dur'] = dur

    def get_f0cwt(self, f0, res):
        '''
        NOTE: this part of the script is *isolated* from the other scripts, which means
        it may not be compatible with the current version.
        '''
        return
        from utils.cwt import get_cont_lf0, get_lf0_cwt
        uv, cont_lf0_lpf = get_cont_lf0(f0)
        logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
        cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
        Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
        if np.any(np.isnan(Wavelet_lf0)):
            raise BinarizationError("NaN CWT")
        res['cwt_spec'] = Wavelet_lf0
        res['cwt_scales'] = scales
        res['f0_mean'] = logf0s_mean_org
        res['f0_std'] = logf0s_std_org


if __name__ == "__main__":
    set_hparams()
    BaseBinarizer().process()
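
Note that process_data_split also aggregates per-bin spec_min/spec_max over the training split and writes them back into config.yaml for spectrogram normalization. Since BaseBinarizer leaves load_meta_data unimplemented, preprocessing is normally driven through a concrete subclass such as SVCBinarizer. A hedged sketch of such an entry point (the script name and config path are illustrative, and set_hparams is assumed to pick up a --config argument as in the __main__ block above):

# Hypothetical preprocessing entry point mirroring the __main__ block above,
# but using the concrete SVCBinarizer; the config path is illustrative only.
from utils.hparams import set_hparams
from preprocessing.SVCpre import SVCBinarizer

if __name__ == '__main__':
    set_hparams()  # e.g. python preprocess.py --config checkpoints/firefox/config.yaml
    SVCBinarizer().process()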