Commit 14cff58
Parent(s): 640063e
Upload 45 files
- .gitattributes +8 -26
- .gitignore +160 -0
- MANIFEST.in +4 -0
- README.md +28 -6
- app.py +86 -0
- data/azusa/azusa.pt +3 -0
- data/encoder.pt +3 -0
- data/g_hifigan.pt +3 -0
- data/ltyai/ltyai.pt +3 -0
- data/nanmei/nanmei.pt +3 -0
- data/tianyi/tianyi.pt +3 -0
- data/wavernn.pt +3 -0
- mockingbirdforuse/__init__.py +120 -0
- mockingbirdforuse/encoder/__init__.py +0 -0
- mockingbirdforuse/encoder/audio.py +121 -0
- mockingbirdforuse/encoder/hparams.py +42 -0
- mockingbirdforuse/encoder/inference.py +154 -0
- mockingbirdforuse/encoder/model.py +145 -0
- mockingbirdforuse/log.py +40 -0
- mockingbirdforuse/synthesizer/__init__.py +0 -0
- mockingbirdforuse/synthesizer/gst_hyperparameters.py +19 -0
- mockingbirdforuse/synthesizer/hparams.py +113 -0
- mockingbirdforuse/synthesizer/inference.py +151 -0
- mockingbirdforuse/synthesizer/models/global_style_token.py +175 -0
- mockingbirdforuse/synthesizer/models/tacotron.py +678 -0
- mockingbirdforuse/synthesizer/utils/__init__.py +46 -0
- mockingbirdforuse/synthesizer/utils/cleaners.py +91 -0
- mockingbirdforuse/synthesizer/utils/logmmse.py +245 -0
- mockingbirdforuse/synthesizer/utils/numbers.py +70 -0
- mockingbirdforuse/synthesizer/utils/symbols.py +20 -0
- mockingbirdforuse/synthesizer/utils/text.py +74 -0
- mockingbirdforuse/vocoder/__init__.py +0 -0
- mockingbirdforuse/vocoder/distribution.py +136 -0
- mockingbirdforuse/vocoder/hifigan/__init__.py +0 -0
- mockingbirdforuse/vocoder/hifigan/hparams.py +37 -0
- mockingbirdforuse/vocoder/hifigan/inference.py +32 -0
- mockingbirdforuse/vocoder/hifigan/models.py +460 -0
- mockingbirdforuse/vocoder/wavernn/__init__.py +0 -0
- mockingbirdforuse/vocoder/wavernn/audio.py +118 -0
- mockingbirdforuse/vocoder/wavernn/hparams.py +53 -0
- mockingbirdforuse/vocoder/wavernn/inference.py +56 -0
- mockingbirdforuse/vocoder/wavernn/models/deepmind_version.py +180 -0
- mockingbirdforuse/vocoder/wavernn/models/fatchord_version.py +445 -0
- packages.txt +3 -0
- requirements.txt +13 -0
.gitattributes
CHANGED
@@ -1,34 +1,16 @@
-*.
-*.
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tar.gz filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
MANIFEST.in
ADDED
@@ -0,0 +1,4 @@
include assets/*
include inputs/*
include LICENSE
include requirements.txt
README.md
CHANGED
@@ -1,12 +1,34 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: MockingBird
+emoji: 🏃
+colorFrom: blue
+colorTo: red
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.1.7
 app_file: app.py
 pinned: false
 ---
 
-
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio` or `streamlit`
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+Path is relative to the root of the repository.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
app.py
ADDED
@@ -0,0 +1,86 @@
import os
import httpx
import torch
import gradio as gr
from tempfile import NamedTemporaryFile
from pathlib import Path

from mockingbirdforuse import MockingBird


mockingbird = MockingBird()
mockingbird_path = Path(os.path.dirname(__file__)) / "data"
base_url = "https://al.smoe.top/d/Home/source/mockingbird/"

for sy in ["encoder.pt", "g_hifigan.pt", "wavernn.pt"]:
    if not os.path.exists(os.path.join(mockingbird_path, sy)):
        torch.hub.download_url_to_file(f"{base_url}/{sy}", mockingbird_path / sy)

for model in ["azusa", "nanmei", "ltyai", "tianyi"]:
    model_path = mockingbird_path / model
    model_path.mkdir(parents=True, exist_ok=True)
    for file_name in ["record.wav", f"{model}.pt"]:
        if not os.path.exists(os.path.join(model_path, file_name)):
            torch.hub.download_url_to_file(
                f"{base_url}/{model}/{file_name}", model_path / file_name
            )

mockingbird.load_model(
    Path(os.path.join(mockingbird_path, "encoder.pt")),
    Path(os.path.join(mockingbird_path, "g_hifigan.pt")),
    Path(os.path.join(mockingbird_path, "wavernn.pt")),
)


def inference(
    text: str,
    model_name: str,
    vocoder_type: str = "HifiGan",
    style_idx: int = 0,
    min_stop_token: int = 9,
    steps: int = 2000,
):
    model_path = mockingbird_path / model_name
    mockingbird.set_synthesizer(Path(os.path.join(model_path, f"{model_name}.pt")))
    fd = NamedTemporaryFile(suffix=".wav", delete=False)
    record = mockingbird.synthesize(
        text=str(text),
        input_wav=model_path / "record.wav",
        vocoder_type=vocoder_type,
        style_idx=style_idx,
        min_stop_token=min_stop_token,
        steps=steps,
    )
    with open(fd.name, "wb") as file:
        file.write(record.getvalue())
    return fd.name


title = "MockingBird"
description = "🚀AI拟声: 5秒内克隆您的声音并生成任意语音内容 Clone a voice in 5 seconds to generate arbitrary speech in real-time"
article = "<a href='https://github.com/babysor/MockingBird'>Github Repo</a></p>"

gr.Interface(
    inference,
    [
        gr.Textbox(label="Input"),
        gr.Radio(
            ["azusa", "nanmei", "ltyai", "tianyi"],
            label="model type",
            value="azusa",
        ),
        gr.Radio(
            ["HifiGan", "WaveRNN"],
            label="Vocoder type",
            value="HifiGan",
        ),
        gr.Slider(minimum=-1, maximum=9, step=1, label="style idx", value=0),
        gr.Slider(minimum=3, maximum=9, label="min stop token", value=9),
        gr.Slider(minimum=200, maximum=2000, label="steps", value=2000),
    ],
    gr.Audio(type="filepath", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["阿梓不是你的电子播放器", "azusa", "HifiGan", 0, 9, 2000], ["不是", "nanmei", "HifiGan", 0, 9, 2000]],
).launch()
data/azusa/azusa.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f5cc81057c8c7a5c8000ac8f5dd0335f878484640e69e2bb1f7a84d9b0bbf90
size 526153469
data/encoder.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:57715adc6f36047166ab06e37b904240aee2f4d10fc88f78ed91510cf4b38666
size 17095158
data/g_hifigan.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c5b29830f9b42c481c108cb0b89d56f380928d4d46e1d30d65c92340ddc694e
size 51985448
data/ltyai/ltyai.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4bd4b759a30efd70d0064628c3b107aa7cd9d0bff8a36a242946a46d7c5235c
size 526153021
data/nanmei/nanmei.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95e90985b4c6b8090d8b328e7b23078eb00cffa0464ca9982464f0000b44a2a9
size 526153469
data/tianyi/tianyi.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f9140c057ad8f4243e47a18103e773e0f823c4423927eec67dd47a3c3e9a9293
size 526153469
data/wavernn.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d7a6861589e927e0fbdaa5849ca022258fe2b58a20cc7bfb8fb598ccf936169
size 53845290
mockingbirdforuse/__init__.py
ADDED
@@ -0,0 +1,120 @@
import re
import librosa
import numpy as np
from io import BytesIO
from pathlib import Path
from scipy.io import wavfile
from typing import List, Literal, Optional

from .encoder.inference import Encoder, preprocess_wav
from .synthesizer.inference import Synthesizer
from .vocoder.hifigan.inference import HifiGanVocoder
from .vocoder.wavernn.inference import WaveRNNVocoder
from .log import logger


def process_text(text: str) -> List[str]:
    punctuation = "!,。、,?!,"  # punctuate and split/clean text
    processed_texts = []
    text = re.sub(r"[{}]+".format(punctuation), "\n", text)
    for processed_text in text.split("\n"):
        if processed_text:
            processed_texts.append(processed_text.strip())
    return processed_texts


class MockingBird:
    def __init__(self):
        self.encoder: Optional[Encoder] = None
        self.gan_vocoder: Optional[HifiGanVocoder] = None
        self.rnn_vocoder: Optional[WaveRNNVocoder] = None
        self.synthesizer: Optional[Synthesizer] = None

    def load_model(
        self,
        encoder_path: Path,
        gan_vocoder_path: Optional[Path] = None,
        rnn_vocoder_path: Optional[Path] = None,
    ):
        """
        设置 Encoder模型 和 Vocoder模型 路径

        Args:
            encoder_path (Path): Encoder模型路径
            gan_vocoder_path (Path): HifiGan Vocoder模型路径,可选,需要用到 HifiGan 类型时必须填写
            rnn_vocoder_path (Path): WaveRNN Vocoder模型路径,可选,需要用到 WaveRNN 类型时必须填写
        """
        self.encoder = Encoder(encoder_path)
        if gan_vocoder_path:
            self.gan_vocoder = HifiGanVocoder(gan_vocoder_path)
        if rnn_vocoder_path:
            self.rnn_vocoder = WaveRNNVocoder(rnn_vocoder_path)

    def set_synthesizer(self, synthesizer_path: Path):
        """
        设置Synthesizer模型路径

        Args:
            synthesizer_path (Path): Synthesizer模型路径
        """
        self.synthesizer = Synthesizer(synthesizer_path)
        logger.info(f"using synthesizer model: {synthesizer_path}")

    def synthesize(
        self,
        text: str,
        input_wav: Path,
        vocoder_type: Literal["HifiGan", "WaveRNN"] = "HifiGan",
        style_idx: int = 0,
        min_stop_token: int = 5,
        steps: int = 1000,
    ) -> BytesIO:
        """
        生成语音

        Args:
            text (str): 目标文字
            input_wav (Path): 目标录音路径
            vocoder_type (HifiGan / WaveRNN): Vocoder模型,默认使用HifiGan
            style_idx (int, optional): Style 范围 -1~9,默认为 0
            min_stop_token (int, optional): Accuracy(精度) 范围3~9,默认为 5
            steps (int, optional): MaxLength(最大句长) 范围200~2000,默认为 1000
        """
        if not self.encoder:
            raise Exception("Please set encoder path first")

        if not self.synthesizer:
            raise Exception("Please set synthesizer path first")

        # Load input wav
        wav, sample_rate = librosa.load(input_wav)

        encoder_wav = preprocess_wav(wav, sample_rate)
        embed, _, _ = self.encoder.embed_utterance(encoder_wav, return_partials=True)

        # Load input text
        texts = process_text(text)

        # synthesize and vocode
        embeds = [embed] * len(texts)
        specs = self.synthesizer.synthesize_spectrograms(
            texts,
            embeds,
            style_idx=style_idx,
            min_stop_token=min_stop_token,
            steps=steps,
        )
        spec = np.concatenate(specs, axis=1)
        if vocoder_type == "WaveRNN":
            if not self.rnn_vocoder:
                raise Exception("Please set wavernn vocoder path first")
            wav, sample_rate = self.rnn_vocoder.infer_waveform(spec)
        else:
            if not self.gan_vocoder:
                raise Exception("Please set hifigan vocoder path first")
            wav, sample_rate = self.gan_vocoder.infer_waveform(spec)

        # Return cooked wav
        out = BytesIO()
        wavfile.write(out, sample_rate, wav.astype(np.float32))
        return out
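For reference, a minimal usage sketch of the MockingBird API defined above, mirroring what app.py does. The checkpoint filenames are the files uploaded under data/ in this commit; where they sit on disk is an assumption for the sketch, not part of the library.

from pathlib import Path
from mockingbirdforuse import MockingBird

# Assumed local copies of the checkpoints from this commit's data/ directory.
mb = MockingBird()
mb.load_model(
    encoder_path=Path("data/encoder.pt"),
    gan_vocoder_path=Path("data/g_hifigan.pt"),
    rnn_vocoder_path=Path("data/wavernn.pt"),
)
mb.set_synthesizer(Path("data/azusa/azusa.pt"))

# synthesize() returns a BytesIO holding a WAV stream.
wav_bytes = mb.synthesize(
    text="阿梓不是你的电子播放器",
    input_wav=Path("data/azusa/record.wav"),
    vocoder_type="HifiGan",
    style_idx=0,
    min_stop_token=9,
    steps=2000,
)
with open("output.wav", "wb") as f:
    f.write(wav_bytes.getvalue())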
mockingbirdforuse/encoder/__init__.py
ADDED
File without changes
mockingbirdforuse/encoder/audio.py
ADDED
@@ -0,0 +1,121 @@
import struct
import librosa
import webrtcvad
import numpy as np
from pathlib import Path
from typing import Optional, Union
from scipy.ndimage.morphology import binary_dilation

from .hparams import hparams as hp


def preprocess_wav(
    fpath_or_wav: Union[str, Path, np.ndarray],
    source_sr: Optional[int] = None,
    normalize: Optional[bool] = True,
    trim_silence: Optional[bool] = True,
):
    """
    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.

    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
    just .wav), either the waveform as a numpy array of floats.
    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
    preprocessing. After preprocessing, the waveform's sampling rate will match the data
    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
    this argument will be ignored.
    """
    # Load the wav from disk if needed
    if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
        wav, source_sr = librosa.load(str(fpath_or_wav))
    else:
        wav = fpath_or_wav

    # Resample the wav if needed
    if source_sr is not None and source_sr != hp.sampling_rate:
        wav = librosa.resample(wav, orig_sr=source_sr, target_sr=hp.sampling_rate)

    # Apply the preprocessing: normalize volume and shorten long silences
    if normalize:
        wav = normalize_volume(wav, hp.audio_norm_target_dBFS, increase_only=True)
    if webrtcvad and trim_silence:
        wav = trim_long_silences(wav)

    return wav


def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this not a log-mel spectrogram.
    """
    frames = librosa.feature.melspectrogram(
        y=wav,
        sr=hp.sampling_rate,
        n_fft=int(hp.sampling_rate * hp.mel_window_length / 1000),
        hop_length=int(hp.sampling_rate * hp.mel_window_step / 1000),
        n_mels=hp.mel_n_channels,
    )
    return frames.astype(np.float32).T


def trim_long_silences(wav):
    """
    Ensures that segments without voice in the waveform remain no longer than a
    threshold determined by the VAD parameters in params.py.

    :param wav: the raw waveform as a numpy array of floats
    :return: the same waveform with silences trimmed away (length <= original wav length)
    """
    # Compute the voice detection window size
    samples_per_window = (hp.vad_window_length * hp.sampling_rate) // 1000

    # Trim the end of the audio to have a multiple of the window size
    wav = wav[: len(wav) - (len(wav) % samples_per_window)]

    # Convert the float waveform to 16-bit mono PCM
    int16_max = (2**15) - 1
    pcm_wave = struct.pack(
        "%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)
    )

    # Perform voice activation detection
    voice_flags = []
    vad = webrtcvad.Vad(mode=3)
    for window_start in range(0, len(wav), samples_per_window):
        window_end = window_start + samples_per_window
        voice_flags.append(
            vad.is_speech(
                pcm_wave[window_start * 2 : window_end * 2],
                sample_rate=hp.sampling_rate,
            )
        )
    voice_flags = np.array(voice_flags)

    # Smooth the voice detection with a moving average
    def moving_average(array, width):
        array_padded = np.concatenate(
            (np.zeros((width - 1) // 2), array, np.zeros(width // 2))
        )
        ret = np.cumsum(array_padded, dtype=float)
        ret[width:] = ret[width:] - ret[:-width]
        return ret[width - 1 :] / width

    audio_mask = moving_average(voice_flags, hp.vad_moving_average_width)
    audio_mask = np.round(audio_mask).astype(np.bool8)

    # Dilate the voiced regions
    audio_mask = binary_dilation(audio_mask, np.ones(hp.vad_max_silence_length + 1))
    audio_mask = np.repeat(audio_mask, samples_per_window)

    return wav[audio_mask == True]


def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
        return wav
    return wav * (10 ** (dBFS_change / 20))
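A small sketch of how these helpers fit together when preparing audio for the speaker encoder. The input filename is hypothetical; any format librosa can read works.

from pathlib import Path
from mockingbirdforuse.encoder import audio
from mockingbirdforuse.encoder.hparams import hparams as hp

# Hypothetical reference clip; preprocess_wav resamples to hp.sampling_rate (16 kHz),
# normalizes the volume, and trims long silences with WebRTC VAD.
wav = audio.preprocess_wav(Path("reference.wav"))
# Mel frames for the encoder: float32 array of shape (n_frames, hp.mel_n_channels).
mels = audio.wav_to_mel_spectrogram(wav)
print(hp.sampling_rate, mels.shape)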
mockingbirdforuse/encoder/hparams.py
ADDED
@@ -0,0 +1,42 @@
from dataclasses import dataclass


@dataclass
class HParams:
    ## Mel-filterbank
    mel_window_length = 25  # In milliseconds
    mel_window_step = 10  # In milliseconds
    mel_n_channels = 40

    ## Audio
    sampling_rate = 16000
    # Number of spectrogram frames in a partial utterance
    partials_n_frames = 160  # 1600 ms
    # Number of spectrogram frames at inference
    inference_n_frames = 80  # 800 ms

    ## Voice Activation Detection
    # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
    # This sets the granularity of the VAD. Should not need to be changed.
    vad_window_length = 30  # In milliseconds
    # Number of frames to average together when performing the moving average smoothing.
    # The larger this value, the larger the VAD variations must be to not get smoothed out.
    vad_moving_average_width = 8
    # Maximum number of consecutive silent frames a segment can have.
    vad_max_silence_length = 6

    ## Audio volume normalization
    audio_norm_target_dBFS = -30

    ## Model parameters
    model_hidden_size = 256
    model_embedding_size = 256
    model_num_layers = 3

    ## Training parameters
    learning_rate_init = 1e-4
    speakers_per_batch = 64
    utterances_per_speaker = 10


hparams = HParams()
mockingbirdforuse/encoder/inference.py
ADDED
@@ -0,0 +1,154 @@
import torch
import numpy as np
from pathlib import Path

from . import audio
from .model import SpeakerEncoder
from .audio import preprocess_wav  # We want to expose this function from here
from .hparams import hparams as hp
from ..log import logger


class Encoder:
    def __init__(self, model_path: Path):
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model = SpeakerEncoder(self._device, torch.device("cpu"))
        checkpoint = torch.load(model_path, self._device)
        self._model.load_state_dict(checkpoint["model_state"])
        self._model.eval()
        logger.info(
            f"Loaded encoder {model_path.name} trained to step {checkpoint['step']}"
        )

    def embed_frames_batch(self, frames_batch):
        """
        Computes embeddings for a batch of mel spectrogram.

        :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
        (batch_size, n_frames, n_channels)
        :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
        """

        frames = torch.from_numpy(frames_batch).to(self._device)
        embed = self._model.forward(frames).detach().cpu().numpy()
        return embed

    def compute_partial_slices(
        self,
        n_samples,
        partial_utterance_n_frames=hp.partials_n_frames,
        min_pad_coverage=0.75,
        overlap=0.5,
        rate=None,
    ):
        """
        Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
        partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
        spectrogram slices are returned, so as to make each partial utterance waveform correspond to
        its spectrogram. This function assumes that the mel spectrogram parameters used are those
        defined in params_data.py.

        The returned ranges may be indexing further than the length of the waveform. It is
        recommended that you pad the waveform with zeros up to wave_slices[-1].stop.

        :param n_samples: the number of samples in the waveform
        :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
        utterance
        :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
        enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
        then the last partial utterance will be considered, as if we padded the audio. Otherwise,
        it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
        utterance, this parameter is ignored so that the function always returns at least 1 slice.
        :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
        utterances are entirely disjoint.
        :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
        respectively the waveform and the mel spectrogram with these slices to obtain the partial
        utterances.
        """
        assert 0 <= overlap < 1
        assert 0 < min_pad_coverage <= 1

        if rate != None:
            samples_per_frame = int((hp.sampling_rate * hp.mel_window_step / 1000))
            n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
            frame_step = int(np.round((hp.sampling_rate / rate) / samples_per_frame))
        else:
            samples_per_frame = int((hp.sampling_rate * hp.mel_window_step / 1000))
            n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
            frame_step = max(
                int(np.round(partial_utterance_n_frames * (1 - overlap))), 1
            )

        assert 0 < frame_step, "The rate is too high"
        assert (
            frame_step <= hp.partials_n_frames
        ), "The rate is too low, it should be %f at least" % (
            hp.sampling_rate / (samples_per_frame * hp.partials_n_frames)
        )

        # Compute the slices
        wav_slices, mel_slices = [], []
        steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
        for i in range(0, steps, frame_step):
            mel_range = np.array([i, i + partial_utterance_n_frames])
            wav_range = mel_range * samples_per_frame
            mel_slices.append(slice(*mel_range))
            wav_slices.append(slice(*wav_range))

        # Evaluate whether extra padding is warranted or not
        last_wav_range = wav_slices[-1]
        coverage = (n_samples - last_wav_range.start) / (
            last_wav_range.stop - last_wav_range.start
        )
        if coverage < min_pad_coverage and len(mel_slices) > 1:
            mel_slices = mel_slices[:-1]
            wav_slices = wav_slices[:-1]

        return wav_slices, mel_slices

    def embed_utterance(
        self, wav, using_partials: bool = True, return_partials: bool = False, **kwargs
    ):
        """
        Computes an embedding for a single utterance.

        :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
        :param using_partials: if True, then the utterance is split in partial utterances of
        <partial_utterance_n_frames> frames and the utterance embedding is computed from their
        normalized average. If False, the utterance is instead computed from feeding the entire
        spectogram to the network.
        :param return_partials: if True, the partial embeddings will also be returned along with the
        wav slices that correspond to the partial embeddings.
        :param kwargs: additional arguments to compute_partial_splits()
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
        <return_partials> is True, the partial utterances as a numpy array of float32 of shape
        (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
        returned. If <using_partials> is simultaneously set to False, both these values will be None
        instead.
        """
        # Process the entire utterance if not using partials
        if not using_partials:
            frames = audio.wav_to_mel_spectrogram(wav)
            embed = self.embed_frames_batch(frames[None, ...])[0]
            if return_partials:
                return embed, None, None
            return embed

        # Compute where to split the utterance into partials and pad if necessary
        wave_slices, mel_slices = self.compute_partial_slices(len(wav), **kwargs)
        max_wave_length = wave_slices[-1].stop
        if max_wave_length >= len(wav):
            wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

        # Split the utterance into partials
        frames = audio.wav_to_mel_spectrogram(wav)
        frames_batch = np.array([frames[s] for s in mel_slices])
        partial_embeds = self.embed_frames_batch(frames_batch)

        # Compute the utterance embedding from the partial embeddings
        raw_embed = np.mean(partial_embeds, axis=0)
        embed = raw_embed / np.linalg.norm(raw_embed, 2)

        if return_partials:
            return embed, partial_embeds, wave_slices
        return embed
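For reference, a minimal sketch of driving this Encoder directly, as the package's __init__.py does. The encoder checkpoint is the data/encoder.pt file from this commit; the reference clip path is hypothetical.

from pathlib import Path
from mockingbirdforuse.encoder.inference import Encoder, preprocess_wav

# data/encoder.pt is the checkpoint uploaded in this commit; "reference.wav" is a placeholder.
encoder = Encoder(Path("data/encoder.pt"))
wav = preprocess_wav(Path("reference.wav"))
# L2-normalized float32 vector of length hp.model_embedding_size (256).
embed = encoder.embed_utterance(wav)
print(embed.shape)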
mockingbirdforuse/encoder/model.py
ADDED
@@ -0,0 +1,145 @@
import torch
import numpy as np
from torch import nn
from scipy.optimize import brentq
from sklearn.metrics import roc_curve
from scipy.interpolate import interp1d
from torch.nn.parameter import Parameter
from torch.nn.utils.clip_grad import clip_grad_norm_

from .hparams import hparams as hp


class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network defition
        self.lstm = nn.LSTM(
            input_size=hp.mel_n_channels,
            hidden_size=hp.model_hidden_size,
            num_layers=hp.model_num_layers,
            batch_first=True,
        ).to(device)
        self.linear = nn.Linear(
            in_features=hp.model_hidden_size, out_features=hp.model_embedding_size
        ).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values)
        self.similarity_weight = Parameter(torch.tensor([10.0])).to(loss_device)
        self.similarity_bias = Parameter(torch.tensor([-5.0])).to(loss_device)

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it
        embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / (
            torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5
        )

        # Exclusive centroids (1 per utterance)
        centroids_excl = torch.sum(embeds, dim=1, keepdim=True) - embeds
        centroids_excl /= utterances_per_speaker - 1
        centroids_excl = centroids_excl.clone() / (
            torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5
        )

        # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(
            speakers_per_batch, utterances_per_speaker, speakers_per_batch
        ).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int32)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        ## Even more vectorized version (slower maybe because of transpose)
        # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
        # ).to(self.loss_device)
        # eye = np.eye(speakers_per_batch, dtype=np.int)
        # mask = np.where(1 - eye)
        # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
        # mask = np.where(eye)
        # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
        # sim_matrix2 = sim_matrix2.transpose(1, 2)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape(
            (speakers_per_batch * utterances_per_speaker, speakers_per_batch)
        )
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int32)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)

        return loss, eer
mockingbirdforuse/log.py
ADDED
@@ -0,0 +1,40 @@
import sys
import loguru

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from loguru import Logger

logger: "Logger" = loguru.logger


class Filter:
    def __init__(self) -> None:
        self.level: Union[int, str] = "DEBUG"

    def __call__(self, record):
        module_name: str = record["name"]
        record["name"] = module_name.split(".")[0]
        levelno = (
            logger.level(self.level).no if isinstance(self.level, str) else self.level
        )
        return record["level"].no >= levelno


logger.remove()
default_filter: Filter = Filter()
default_format: str = (
    "<g>{time:MM-DD HH:mm:ss}</g> "
    "[<lvl>{level}</lvl>] "
    "<c><u>{name}</u></c> | "
    "{message}"
)
logger.add(
    sys.stdout,
    level=0,
    colorize=True,
    diagnose=False,
    filter=default_filter,
    format=default_format,
)
mockingbirdforuse/synthesizer/__init__.py
ADDED
File without changes
mockingbirdforuse/synthesizer/gst_hyperparameters.py
ADDED
@@ -0,0 +1,19 @@
from dataclasses import dataclass


@dataclass
class GSTHyperparameters:
    E = 512

    # reference encoder
    ref_enc_filters = [32, 32, 64, 64, 128, 128]

    # style token layer
    token_num = 10
    # token_emb_size = 256
    num_heads = 8

    n_mels = 256  # Number of Mel banks to generate


hparams = GSTHyperparameters()
mockingbirdforuse/synthesizer/hparams.py
ADDED
@@ -0,0 +1,113 @@
from dataclasses import dataclass


@dataclass
class HParams:
    ### Signal Processing (used in both synthesizer and vocoder)
    sample_rate = 16000
    n_fft = 800
    num_mels = 80
    hop_size = 200
    """Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)"""
    win_size = 800
    """Tacotron uses 50 ms frame length (set to sample_rate * 0.050)"""
    fmin = 55
    min_level_db = -100
    ref_level_db = 20
    max_abs_value = 4.0
    """Gradient explodes if too big, premature convergence if too small."""
    preemphasis = 0.97
    """Filter coefficient to use if preemphasize is True"""
    preemphasize = True

    ### Tacotron Text-to-Speech (TTS)
    tts_embed_dims = 512
    """Embedding dimension for the graphemes/phoneme inputs"""
    tts_encoder_dims = 256
    tts_decoder_dims = 128
    tts_postnet_dims = 512
    tts_encoder_K = 5
    tts_lstm_dims = 1024
    tts_postnet_K = 5
    tts_num_highways = 4
    tts_dropout = 0.5
    tts_cleaner_names = ["basic_cleaners"]
    tts_stop_threshold = -3.4
    """
    Value below which audio generation ends.
    For example, for a range of [-4, 4], this
    will terminate the sequence at the first
    frame that has all values < -3.4
    """

    ### Tacotron Training
    tts_schedule = [
        (2, 1e-3, 10_000, 12),
        (2, 5e-4, 15_000, 12),
        (2, 2e-4, 20_000, 12),
        (2, 1e-4, 30_000, 12),
        (2, 5e-5, 40_000, 12),
        (2, 1e-5, 60_000, 12),
        (2, 5e-6, 160_000, 12),
        (2, 3e-6, 320_000, 12),
        (2, 1e-6, 640_000, 12),
    ]
    """
    Progressive training schedule
    (r, lr, step, batch_size)
    r = reduction factor (# of mel frames synthesized for each decoder iteration)
    lr = learning rate
    """

    tts_clip_grad_norm = 1.0
    """clips the gradient norm to prevent explosion - set to None if not needed"""
    tts_eval_interval = 500
    """
    Number of steps between model evaluation (sample generation)
    Set to -1 to generate after completing epoch, or 0 to disable
    """
    tts_eval_num_samples = 1
    """Makes this number of samples"""
    tts_finetune_layers = []
    """For finetune usage, if set, only selected layers will be trained, available: encoder,encoder_proj,gst,decoder,postnet,post_proj"""

    ### Data Preprocessing
    max_mel_frames = 900
    rescale = True
    rescaling_max = 0.9
    synthesis_batch_size = 16
    """For vocoder preprocessing and inference."""

    ### Mel Visualization and Griffin-Lim
    signal_normalization = True
    power = 1.5
    griffin_lim_iters = 60

    ### Audio processing options
    fmax = 7600
    """Should not exceed (sample_rate // 2)"""
    allow_clipping_in_normalization = True
    """Used when signal_normalization = True"""
    clip_mels_length = True
    """If true, discards samples exceeding max_mel_frames"""
    use_lws = False
    """Fast spectrogram phase recovery using local weighted sums"""
    symmetric_mels = True
    """Sets mel range to [-max_abs_value, max_abs_value] if True, and [0, max_abs_value] if False"""
    trim_silence = True
    """Use with sample_rate of 16000 for best results"""

    ### SV2TTS
    speaker_embedding_size = 256
    """Dimension for the speaker embedding"""
    silence_min_duration_split = 0.4
    """Duration in seconds of a silence for an utterance to be split"""
    utterance_min_duration = 1.6
    """Duration in seconds below which utterances are discarded"""
    use_gst = True
    """Whether to use global style token"""
    use_ser_for_gst = True
    """Whether to use speaker embedding referenced for global style token"""


hparams = HParams()
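A quick sanity check of the frame parameters documented above, assuming the 16 kHz sample rate set in this file:

# hop_size and win_size correspond to the 12.5 ms frame shift and 50 ms frame
# length described in the docstrings above.
sample_rate = 16000
assert int(sample_rate * 0.0125) == 200  # hop_size
assert int(sample_rate * 0.050) == 800   # win_size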
mockingbirdforuse/synthesizer/inference.py
ADDED
@@ -0,0 +1,151 @@
1 |
+
import torch
|
2 |
+
import librosa
|
3 |
+
import numpy as np
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Union, List
|
6 |
+
from pypinyin import lazy_pinyin, Style
|
7 |
+
|
8 |
+
from .hparams import hparams as hp
|
9 |
+
from .utils.symbols import symbols
|
10 |
+
from .models.tacotron import Tacotron
|
11 |
+
from .utils.text import text_to_sequence
|
12 |
+
from .utils.logmmse import denoise, profile_noise
|
13 |
+
from ..log import logger
|
14 |
+
|
15 |
+
|
16 |
+
class Synthesizer:
|
17 |
+
def __init__(self, model_path: Path):
|
18 |
+
# Check for GPU
|
19 |
+
if torch.cuda.is_available():
|
20 |
+
self.device = torch.device("cuda")
|
21 |
+
else:
|
22 |
+
self.device = torch.device("cpu")
|
23 |
+
logger.info(f"Synthesizer using device: {self.device}")
|
24 |
+
|
25 |
+
self._model = Tacotron(
|
26 |
+
embed_dims=hp.tts_embed_dims,
|
27 |
+
num_chars=len(symbols),
|
28 |
+
encoder_dims=hp.tts_encoder_dims,
|
29 |
+
decoder_dims=hp.tts_decoder_dims,
|
30 |
+
n_mels=hp.num_mels,
|
31 |
+
fft_bins=hp.num_mels,
|
32 |
+
postnet_dims=hp.tts_postnet_dims,
|
33 |
+
encoder_K=hp.tts_encoder_K,
|
34 |
+
lstm_dims=hp.tts_lstm_dims,
|
35 |
+
postnet_K=hp.tts_postnet_K,
|
36 |
+
num_highways=hp.tts_num_highways,
|
37 |
+
dropout=hp.tts_dropout,
|
38 |
+
stop_threshold=hp.tts_stop_threshold,
|
39 |
+
speaker_embedding_size=hp.speaker_embedding_size,
|
40 |
+
).to(self.device)
|
41 |
+
|
42 |
+
self._model.load(model_path, self.device)
|
43 |
+
self._model.eval()
|
44 |
+
|
45 |
+
logger.info(
|
46 |
+
'Loaded synthesizer "%s" trained to step %d'
|
47 |
+
% (model_path.name, self._model.state_dict()["step"])
|
48 |
+
)
|
49 |
+
|
50 |
+
def synthesize_spectrograms(
|
51 |
+
self,
|
52 |
+
texts: List[str],
|
53 |
+
embeddings: Union[np.ndarray, List[np.ndarray]],
|
54 |
+
return_alignments=False,
|
55 |
+
style_idx=0,
|
56 |
+
min_stop_token=5,
|
57 |
+
steps=2000,
|
58 |
+
):
|
59 |
+
"""
|
60 |
+
Synthesizes mel spectrograms from texts and speaker embeddings.
|
61 |
+
|
62 |
+
:param texts: a list of N text prompts to be synthesized
|
63 |
+
:param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
|
64 |
+
:param return_alignments: if True, a matrix representing the alignments between the
|
65 |
+
characters
|
66 |
+
and each decoder output step will be returned for each spectrogram
|
67 |
+
:return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
|
68 |
+
sequence length of spectrogram i, and possibly the alignments.
|
69 |
+
"""
|
70 |
+
|
71 |
+
logger.debug("Read " + str(texts))
|
72 |
+
texts = [
|
73 |
+
" ".join(lazy_pinyin(v, style=Style.TONE3, neutral_tone_with_five=True))
|
74 |
+
for v in texts
|
75 |
+
]
|
76 |
+
logger.debug("Synthesizing " + str(texts))
|
77 |
+
# Preprocess text inputs
|
78 |
+
inputs = [text_to_sequence(text, hp.tts_cleaner_names) for text in texts]
|
79 |
+
if not isinstance(embeddings, list):
|
80 |
+
embeddings = [embeddings]
|
81 |
+
|
82 |
+
# Batch inputs
|
83 |
+
batched_inputs = [
|
84 |
+
inputs[i : i + hp.synthesis_batch_size]
|
85 |
+
for i in range(0, len(inputs), hp.synthesis_batch_size)
|
86 |
+
]
|
87 |
+
batched_embeds = [
|
88 |
+
embeddings[i : i + hp.synthesis_batch_size]
|
89 |
+
for i in range(0, len(embeddings), hp.synthesis_batch_size)
|
90 |
+
]
|
91 |
+
|
92 |
+
specs = []
|
93 |
+
alignments = []
|
94 |
+
for i, batch in enumerate(batched_inputs, 1):
|
95 |
+
logger.debug(f"\n| Generating {i}/{len(batched_inputs)}")
|
96 |
+
|
97 |
+
# Pad texts so they are all the same length
|
98 |
+
text_lens = [len(text) for text in batch]
|
99 |
+
max_text_len = max(text_lens)
|
100 |
+
chars = [pad1d(text, max_text_len) for text in batch]
|
101 |
+
chars = np.stack(chars)
|
102 |
+
|
103 |
+
# Stack speaker embeddings into 2D array for batch processing
|
104 |
+
speaker_embeds = np.stack(batched_embeds[i - 1])
|
105 |
+
|
106 |
+
# Convert to tensor
|
107 |
+
chars = torch.tensor(chars).long().to(self.device)
|
108 |
+
speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
|
109 |
+
|
110 |
+
# Inference
|
111 |
+
_, mels, alignments = self._model.generate(
|
112 |
+
chars,
|
113 |
+
speaker_embeddings,
|
114 |
+
style_idx=style_idx,
|
115 |
+
min_stop_token=min_stop_token,
|
116 |
+
steps=steps,
|
117 |
+
)
|
118 |
+
mels = mels.detach().cpu().numpy()
|
119 |
+
for m in mels:
|
120 |
+
# Trim silence from end of each spectrogram
|
121 |
+
while np.max(m[:, -1]) < hp.tts_stop_threshold:
|
122 |
+
m = m[:, :-1]
|
123 |
+
specs.append(m)
|
124 |
+
|
125 |
+
logger.debug("\n\nDone.\n")
|
126 |
+
return (specs, alignments) if return_alignments else specs
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def load_preprocess_wav(fpath):
|
130 |
+
"""
|
131 |
+
Loads and preprocesses an audio file under the same conditions the audio files were used to
|
132 |
+
train the synthesizer.
|
133 |
+
"""
|
134 |
+
wav = librosa.load(path=str(fpath), sr=hp.sample_rate)[0]
|
135 |
+
if hp.rescale:
|
136 |
+
wav = wav / np.abs(wav).max() * hp.rescaling_max
|
137 |
+
# denoise
|
138 |
+
if len(wav) > hp.sample_rate * (0.3 + 0.1):
|
139 |
+
noise_wav = np.concatenate(
|
140 |
+
[
|
141 |
+
wav[: int(hp.sample_rate * 0.15)],
|
142 |
+
wav[-int(hp.sample_rate * 0.15) :],
|
143 |
+
]
|
144 |
+
)
|
145 |
+
profile = profile_noise(noise_wav, hp.sample_rate)
|
146 |
+
wav = denoise(wav, profile)
|
147 |
+
return wav
|
148 |
+
|
149 |
+
|
150 |
+
def pad1d(x, max_len, pad_value=0):
|
151 |
+
return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
|
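The padding-and-batching step above (pad1d plus np.stack) can be checked in isolation. The following standalone sketch uses made-up symbol IDs and only numpy; it is an illustration of that step, not a file in the repository.

import numpy as np

def pad1d(x, max_len, pad_value=0):
    # Right-pad a 1-D sequence of symbol IDs to max_len, as in the inference code above.
    return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)

# Three encoded texts of unequal length are padded to the longest one and stacked,
# giving a rectangular (batch, max_text_len) array ready to become a LongTensor.
inputs = [np.array([5, 3, 8]), np.array([1, 2]), np.array([9, 9, 9, 9])]
max_text_len = max(len(t) for t in inputs)
chars = np.stack([pad1d(t, max_text_len) for t in inputs])
print(chars.shape)  # (3, 4)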
mockingbirdforuse/synthesizer/models/global_style_token.py
ADDED
@@ -0,0 +1,175 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.init as init
|
4 |
+
from torch.nn.parameter import Parameter
|
5 |
+
import torch.nn.functional as tFunctional
|
6 |
+
|
7 |
+
from ..hparams import hparams as hp
|
8 |
+
from ..gst_hyperparameters import hparams as gst_hp
|
9 |
+
|
10 |
+
|
11 |
+
class GlobalStyleToken(nn.Module):
|
12 |
+
"""
|
13 |
+
inputs: style mel spectrograms [batch_size, num_spec_frames, num_mel]
|
14 |
+
speaker_embedding: optional speaker embedding [batch_size, speaker_embedding_dim]
|
15 |
+
outputs: [batch_size, embedding_dim]
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, speaker_embedding_dim=None):
|
19 |
+
|
20 |
+
super().__init__()
|
21 |
+
self.encoder = ReferenceEncoder()
|
22 |
+
self.stl = STL(speaker_embedding_dim)
|
23 |
+
|
24 |
+
def forward(self, inputs, speaker_embedding=None):
|
25 |
+
enc_out = self.encoder(inputs)
|
26 |
+
# concat speaker_embedding according to https://github.com/mozilla/TTS/blob/master/TTS/tts/layers/gst_layers.py
|
27 |
+
if hp.use_ser_for_gst and speaker_embedding is not None:
|
28 |
+
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
|
29 |
+
style_embed = self.stl(enc_out)
|
30 |
+
|
31 |
+
return style_embed
|
32 |
+
|
33 |
+
|
34 |
+
class ReferenceEncoder(nn.Module):
|
35 |
+
"""
|
36 |
+
inputs --- [N, Ty/r, n_mels*r] mels
|
37 |
+
outputs --- [N, ref_enc_gru_size]
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self):
|
41 |
+
|
42 |
+
super().__init__()
|
43 |
+
K = len(gst_hp.ref_enc_filters)
|
44 |
+
filters = [1] + gst_hp.ref_enc_filters
|
45 |
+
convs = [
|
46 |
+
nn.Conv2d(
|
47 |
+
in_channels=filters[i],
|
48 |
+
out_channels=filters[i + 1],
|
49 |
+
kernel_size=(3, 3),
|
50 |
+
stride=(2, 2),
|
51 |
+
padding=(1, 1),
|
52 |
+
)
|
53 |
+
for i in range(K)
|
54 |
+
]
|
55 |
+
self.convs = nn.ModuleList(convs)
|
56 |
+
self.bns = nn.ModuleList(
|
57 |
+
[nn.BatchNorm2d(num_features=gst_hp.ref_enc_filters[i]) for i in range(K)]
|
58 |
+
)
|
59 |
+
|
60 |
+
out_channels = self.calculate_channels(gst_hp.n_mels, 3, 2, 1, K)
|
61 |
+
self.gru = nn.GRU(
|
62 |
+
input_size=gst_hp.ref_enc_filters[-1] * out_channels,
|
63 |
+
hidden_size=gst_hp.E // 2,
|
64 |
+
batch_first=True,
|
65 |
+
)
|
66 |
+
|
67 |
+
def forward(self, inputs):
|
68 |
+
N = inputs.size(0)
|
69 |
+
out = inputs.view(N, 1, -1, gst_hp.n_mels) # [N, 1, Ty, n_mels]
|
70 |
+
for conv, bn in zip(self.convs, self.bns):
|
71 |
+
out = conv(out)
|
72 |
+
out = bn(out)
|
73 |
+
out = tFunctional.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
74 |
+
|
75 |
+
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
76 |
+
T = out.size(1)
|
77 |
+
N = out.size(0)
|
78 |
+
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
79 |
+
|
80 |
+
self.gru.flatten_parameters()
|
81 |
+
memory, out = self.gru(out) # out --- [1, N, E//2]
|
82 |
+
|
83 |
+
return out.squeeze(0)
|
84 |
+
|
85 |
+
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
|
86 |
+
for i in range(n_convs):
|
87 |
+
L = (L - kernel_size + 2 * pad) // stride + 1
|
88 |
+
return L
|
89 |
+
|
90 |
+
|
91 |
+
class STL(nn.Module):
|
92 |
+
"""
|
93 |
+
inputs --- [N, E//2]
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self, speaker_embedding_dim=None):
|
97 |
+
|
98 |
+
super().__init__()
|
99 |
+
self.embed = Parameter(
|
100 |
+
torch.FloatTensor(gst_hp.token_num, gst_hp.E // gst_hp.num_heads)
|
101 |
+
)
|
102 |
+
d_q = gst_hp.E // 2
|
103 |
+
d_k = gst_hp.E // gst_hp.num_heads
|
104 |
+
# self.attention = MultiHeadAttention(gst_hp.num_heads, d_model, d_q, d_v)
|
105 |
+
if hp.use_ser_for_gst and speaker_embedding_dim is not None:
|
106 |
+
d_q += speaker_embedding_dim
|
107 |
+
self.attention = MultiHeadAttention(
|
108 |
+
query_dim=d_q, key_dim=d_k, num_units=gst_hp.E, num_heads=gst_hp.num_heads
|
109 |
+
)
|
110 |
+
|
111 |
+
init.normal_(self.embed, mean=0, std=0.5)
|
112 |
+
|
113 |
+
def forward(self, inputs):
|
114 |
+
N = inputs.size(0)
|
115 |
+
query = inputs.unsqueeze(1) # [N, 1, E//2]
|
116 |
+
keys = (
|
117 |
+
torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)
|
118 |
+
) # [N, token_num, E // num_heads]
|
119 |
+
style_embed = self.attention(query, keys)
|
120 |
+
|
121 |
+
return style_embed
|
122 |
+
|
123 |
+
|
124 |
+
class MultiHeadAttention(nn.Module):
|
125 |
+
"""
|
126 |
+
input:
|
127 |
+
query --- [N, T_q, query_dim]
|
128 |
+
key --- [N, T_k, key_dim]
|
129 |
+
output:
|
130 |
+
out --- [N, T_q, num_units]
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(self, query_dim, key_dim, num_units, num_heads):
|
134 |
+
|
135 |
+
super().__init__()
|
136 |
+
self.num_units = num_units
|
137 |
+
self.num_heads = num_heads
|
138 |
+
self.key_dim = key_dim
|
139 |
+
|
140 |
+
self.W_query = nn.Linear(
|
141 |
+
in_features=query_dim, out_features=num_units, bias=False
|
142 |
+
)
|
143 |
+
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
|
144 |
+
self.W_value = nn.Linear(
|
145 |
+
in_features=key_dim, out_features=num_units, bias=False
|
146 |
+
)
|
147 |
+
|
148 |
+
def forward(self, query, key):
|
149 |
+
querys = self.W_query(query) # [N, T_q, num_units]
|
150 |
+
keys = self.W_key(key) # [N, T_k, num_units]
|
151 |
+
values = self.W_value(key)
|
152 |
+
|
153 |
+
split_size = self.num_units // self.num_heads
|
154 |
+
querys = torch.stack(
|
155 |
+
torch.split(querys, split_size, dim=2), dim=0
|
156 |
+
) # [h, N, T_q, num_units/h]
|
157 |
+
keys = torch.stack(
|
158 |
+
torch.split(keys, split_size, dim=2), dim=0
|
159 |
+
) # [h, N, T_k, num_units/h]
|
160 |
+
values = torch.stack(
|
161 |
+
torch.split(values, split_size, dim=2), dim=0
|
162 |
+
) # [h, N, T_k, num_units/h]
|
163 |
+
|
164 |
+
# score = softmax(QK^T / (d_k ** 0.5))
|
165 |
+
scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
|
166 |
+
scores = scores / (self.key_dim**0.5)
|
167 |
+
scores = tFunctional.softmax(scores, dim=3)
|
168 |
+
|
169 |
+
# out = score * V
|
170 |
+
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
|
171 |
+
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
|
172 |
+
0
|
173 |
+
) # [N, T_q, num_units]
|
174 |
+
|
175 |
+
return out
|
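MultiHeadAttention above splits num_units across num_heads, attends per head, then concatenates the heads back. Below is a minimal shape-only sketch of that bookkeeping, using random tensors and illustrative dimensions; it is not the trained module.

import torch

N, T_q, T_k, num_units, num_heads = 2, 1, 10, 256, 8
split = num_units // num_heads  # per-head width, also used to scale the scores

querys = torch.randn(N, T_q, num_units)
keys = torch.randn(N, T_k, num_units)
values = torch.randn(N, T_k, num_units)

# Stack the per-head slices on a new leading axis: [h, N, T, num_units/h]
q = torch.stack(torch.split(querys, split, dim=2), dim=0)
k = torch.stack(torch.split(keys, split, dim=2), dim=0)
v = torch.stack(torch.split(values, split, dim=2), dim=0)

# Scaled dot-product attention per head, then merge the heads back on the last axis
scores = torch.softmax(q @ k.transpose(2, 3) / split ** 0.5, dim=3)
out = torch.cat(torch.split(scores @ v, 1, dim=0), dim=3).squeeze(0)
print(out.shape)  # torch.Size([2, 1, 256])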
mockingbirdforuse/synthesizer/models/tacotron.py
ADDED
@@ -0,0 +1,678 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from ..hparams import hparams as hp
|
7 |
+
from .global_style_token import GlobalStyleToken
|
8 |
+
from ..gst_hyperparameters import hparams as gst_hp
|
9 |
+
from ...log import logger
|
10 |
+
|
11 |
+
|
12 |
+
class HighwayNetwork(nn.Module):
|
13 |
+
def __init__(self, size):
|
14 |
+
super().__init__()
|
15 |
+
self.W1 = nn.Linear(size, size)
|
16 |
+
self.W2 = nn.Linear(size, size)
|
17 |
+
self.W1.bias.data.fill_(0.0)
|
18 |
+
|
19 |
+
def forward(self, x):
|
20 |
+
x1 = self.W1(x)
|
21 |
+
x2 = self.W2(x)
|
22 |
+
g = torch.sigmoid(x2)
|
23 |
+
y = g * F.relu(x1) + (1.0 - g) * x
|
24 |
+
return y
|
25 |
+
|
26 |
+
|
27 |
+
class Encoder(nn.Module):
|
28 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
|
29 |
+
super().__init__()
|
30 |
+
prenet_dims = (encoder_dims, encoder_dims)
|
31 |
+
cbhg_channels = encoder_dims
|
32 |
+
self.embedding = nn.Embedding(num_chars, embed_dims)
|
33 |
+
self.pre_net = PreNet(
|
34 |
+
embed_dims,
|
35 |
+
fc1_dims=prenet_dims[0],
|
36 |
+
fc2_dims=prenet_dims[1],
|
37 |
+
dropout=dropout,
|
38 |
+
)
|
39 |
+
self.cbhg = CBHG(
|
40 |
+
K=K,
|
41 |
+
in_channels=cbhg_channels,
|
42 |
+
channels=cbhg_channels,
|
43 |
+
proj_channels=[cbhg_channels, cbhg_channels],
|
44 |
+
num_highways=num_highways,
|
45 |
+
)
|
46 |
+
|
47 |
+
def forward(self, x, speaker_embedding=None):
|
48 |
+
x = self.embedding(x)
|
49 |
+
x = self.pre_net(x)
|
50 |
+
x.transpose_(1, 2)
|
51 |
+
x = self.cbhg(x)
|
52 |
+
if speaker_embedding is not None:
|
53 |
+
x = self.add_speaker_embedding(x, speaker_embedding)
|
54 |
+
return x
|
55 |
+
|
56 |
+
def add_speaker_embedding(self, x, speaker_embedding):
|
57 |
+
# SV2TTS
|
58 |
+
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
|
59 |
+
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
|
60 |
+
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
|
61 |
+
# This concats the speaker embedding for each char in the encoder output
|
62 |
+
|
63 |
+
# Save the dimensions as human-readable names
|
64 |
+
batch_size = x.size()[0]
|
65 |
+
num_chars = x.size()[1]
|
66 |
+
|
67 |
+
if speaker_embedding.dim() == 1:
|
68 |
+
idx = 0
|
69 |
+
else:
|
70 |
+
idx = 1
|
71 |
+
|
72 |
+
# Start by making a copy of each speaker embedding to match the input text length
|
73 |
+
# The output of this has size (batch_size, num_chars * speaker_embedding_size)
|
74 |
+
speaker_embedding_size = speaker_embedding.size()[idx]
|
75 |
+
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
76 |
+
|
77 |
+
# Reshape it and transpose
|
78 |
+
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
79 |
+
e = e.transpose(1, 2)
|
80 |
+
|
81 |
+
# Concatenate the tiled speaker embedding with the encoder output
|
82 |
+
x = torch.cat((x, e), 2)
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class BatchNormConv(nn.Module):
|
87 |
+
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
88 |
+
super().__init__()
|
89 |
+
self.conv = nn.Conv1d(
|
90 |
+
in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False
|
91 |
+
)
|
92 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
93 |
+
self.relu = relu
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
x = self.conv(x)
|
97 |
+
x = F.relu(x) if self.relu is True else x
|
98 |
+
return self.bnorm(x)
|
99 |
+
|
100 |
+
|
101 |
+
class CBHG(nn.Module):
|
102 |
+
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
103 |
+
super().__init__()
|
104 |
+
|
105 |
+
# List of all rnns to call `flatten_parameters()` on
|
106 |
+
self._to_flatten = []
|
107 |
+
|
108 |
+
self.bank_kernels = [i for i in range(1, K + 1)]
|
109 |
+
self.conv1d_bank = nn.ModuleList()
|
110 |
+
for k in self.bank_kernels:
|
111 |
+
conv = BatchNormConv(in_channels, channels, k)
|
112 |
+
self.conv1d_bank.append(conv)
|
113 |
+
|
114 |
+
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
115 |
+
|
116 |
+
self.conv_project1 = BatchNormConv(
|
117 |
+
len(self.bank_kernels) * channels, proj_channels[0], 3
|
118 |
+
)
|
119 |
+
self.conv_project2 = BatchNormConv(
|
120 |
+
proj_channels[0], proj_channels[1], 3, relu=False
|
121 |
+
)
|
122 |
+
|
123 |
+
# Fix the highway input if necessary
|
124 |
+
if proj_channels[-1] != channels:
|
125 |
+
self.highway_mismatch = True
|
126 |
+
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
127 |
+
else:
|
128 |
+
self.highway_mismatch = False
|
129 |
+
|
130 |
+
self.highways = nn.ModuleList()
|
131 |
+
for i in range(num_highways):
|
132 |
+
hn = HighwayNetwork(channels)
|
133 |
+
self.highways.append(hn)
|
134 |
+
|
135 |
+
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
136 |
+
self._to_flatten.append(self.rnn)
|
137 |
+
|
138 |
+
# Avoid fragmentation of RNN parameters and associated warning
|
139 |
+
self._flatten_parameters()
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
# Although we `_flatten_parameters()` on init, when using DataParallel
|
143 |
+
# the model gets replicated, making it no longer guaranteed that the
|
144 |
+
# weights are contiguous in GPU memory. Hence, we must call it again
|
145 |
+
self.rnn.flatten_parameters()
|
146 |
+
|
147 |
+
# Save these for later
|
148 |
+
residual = x
|
149 |
+
seq_len = x.size(-1)
|
150 |
+
conv_bank = []
|
151 |
+
|
152 |
+
# Convolution Bank
|
153 |
+
for conv in self.conv1d_bank:
|
154 |
+
c = conv(x) # Convolution
|
155 |
+
conv_bank.append(c[:, :, :seq_len])
|
156 |
+
|
157 |
+
# Stack along the channel axis
|
158 |
+
conv_bank = torch.cat(conv_bank, dim=1)
|
159 |
+
|
160 |
+
# dump the last padding to fit residual
|
161 |
+
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
162 |
+
|
163 |
+
# Conv1d projections
|
164 |
+
x = self.conv_project1(x)
|
165 |
+
x = self.conv_project2(x)
|
166 |
+
|
167 |
+
# Residual Connect
|
168 |
+
x = x + residual
|
169 |
+
|
170 |
+
# Through the highways
|
171 |
+
x = x.transpose(1, 2)
|
172 |
+
if self.highway_mismatch is True:
|
173 |
+
x = self.pre_highway(x)
|
174 |
+
for h in self.highways:
|
175 |
+
x = h(x)
|
176 |
+
|
177 |
+
# And then the RNN
|
178 |
+
x, _ = self.rnn(x)
|
179 |
+
return x
|
180 |
+
|
181 |
+
def _flatten_parameters(self):
|
182 |
+
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
183 |
+
to improve efficiency and avoid PyTorch yelling at us."""
|
184 |
+
[m.flatten_parameters() for m in self._to_flatten]
|
185 |
+
|
186 |
+
|
187 |
+
class PreNet(nn.Module):
|
188 |
+
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
189 |
+
super().__init__()
|
190 |
+
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
191 |
+
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
192 |
+
self.p = dropout
|
193 |
+
|
194 |
+
def forward(self, x):
|
195 |
+
x = self.fc1(x)
|
196 |
+
x = F.relu(x)
|
197 |
+
x = F.dropout(x, self.p, training=True)
|
198 |
+
x = self.fc2(x)
|
199 |
+
x = F.relu(x)
|
200 |
+
x = F.dropout(x, self.p, training=True)
|
201 |
+
return x
|
202 |
+
|
203 |
+
|
204 |
+
class Attention(nn.Module):
|
205 |
+
def __init__(self, attn_dims):
|
206 |
+
super().__init__()
|
207 |
+
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
|
208 |
+
self.v = nn.Linear(attn_dims, 1, bias=False)
|
209 |
+
|
210 |
+
def forward(self, encoder_seq_proj, query, t):
|
211 |
+
# Transform the query vector
|
212 |
+
query_proj = self.W(query).unsqueeze(1)
|
213 |
+
|
214 |
+
# Compute the scores
|
215 |
+
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
|
216 |
+
scores = F.softmax(u, dim=1)
|
217 |
+
|
218 |
+
return scores.transpose(1, 2)
|
219 |
+
|
220 |
+
|
221 |
+
class LSA(nn.Module):
|
222 |
+
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
223 |
+
super().__init__()
|
224 |
+
self.conv = nn.Conv1d(
|
225 |
+
1,
|
226 |
+
filters,
|
227 |
+
padding=(kernel_size - 1) // 2,
|
228 |
+
kernel_size=kernel_size,
|
229 |
+
bias=True,
|
230 |
+
)
|
231 |
+
self.L = nn.Linear(filters, attn_dim, bias=False)
|
232 |
+
self.W = nn.Linear(
|
233 |
+
attn_dim, attn_dim, bias=True
|
234 |
+
) # Include the attention bias in this term
|
235 |
+
self.v = nn.Linear(attn_dim, 1, bias=False)
|
236 |
+
self.cumulative = None
|
237 |
+
self.attention = None
|
238 |
+
|
239 |
+
def init_attention(self, encoder_seq_proj):
|
240 |
+
device = encoder_seq_proj.device # use same device as parameters
|
241 |
+
b, t, c = encoder_seq_proj.size()
|
242 |
+
self.cumulative = torch.zeros(b, t, device=device)
|
243 |
+
self.attention = torch.zeros(b, t, device=device)
|
244 |
+
|
245 |
+
def forward(self, encoder_seq_proj, query, t, chars):
|
246 |
+
|
247 |
+
if t == 0:
|
248 |
+
self.init_attention(encoder_seq_proj)
|
249 |
+
|
250 |
+
processed_query = self.W(query).unsqueeze(1)
|
251 |
+
|
252 |
+
location = self.cumulative.unsqueeze(1)
|
253 |
+
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
254 |
+
|
255 |
+
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
256 |
+
u = u.squeeze(-1)
|
257 |
+
|
258 |
+
# Mask zero padding chars
|
259 |
+
u = u * (chars != 0).float()
|
260 |
+
|
261 |
+
# Smooth Attention
|
262 |
+
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
263 |
+
scores = F.softmax(u, dim=1)
|
264 |
+
self.attention = scores
|
265 |
+
self.cumulative = self.cumulative + self.attention
|
266 |
+
|
267 |
+
return scores.unsqueeze(-1).transpose(1, 2)
|
268 |
+
|
269 |
+
|
270 |
+
class Decoder(nn.Module):
|
271 |
+
# Class variable because its value doesn't change between classes
|
272 |
+
# yet ought to be scoped by class because it's a property of a Decoder
|
273 |
+
max_r = 20
|
274 |
+
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
n_mels,
|
278 |
+
encoder_dims,
|
279 |
+
decoder_dims,
|
280 |
+
lstm_dims,
|
281 |
+
dropout,
|
282 |
+
speaker_embedding_size,
|
283 |
+
):
|
284 |
+
super().__init__()
|
285 |
+
self.register_buffer("r", torch.tensor(1, dtype=torch.int))
|
286 |
+
self.n_mels = n_mels
|
287 |
+
prenet_dims = (decoder_dims * 2, decoder_dims * 2)
|
288 |
+
self.prenet = PreNet(
|
289 |
+
n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], dropout=dropout
|
290 |
+
)
|
291 |
+
self.attn_net = LSA(decoder_dims)
|
292 |
+
if hp.use_gst:
|
293 |
+
speaker_embedding_size += gst_hp.E
|
294 |
+
self.attn_rnn = nn.GRUCell(
|
295 |
+
encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims
|
296 |
+
)
|
297 |
+
self.rnn_input = nn.Linear(
|
298 |
+
encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims
|
299 |
+
)
|
300 |
+
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
301 |
+
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
302 |
+
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
303 |
+
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
304 |
+
|
305 |
+
def zoneout(self, prev, current, device, p=0.1):
|
306 |
+
mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
|
307 |
+
return prev * mask + current * (1 - mask)
|
308 |
+
|
309 |
+
def forward(
|
310 |
+
self,
|
311 |
+
encoder_seq,
|
312 |
+
encoder_seq_proj,
|
313 |
+
prenet_in,
|
314 |
+
hidden_states,
|
315 |
+
cell_states,
|
316 |
+
context_vec,
|
317 |
+
t,
|
318 |
+
chars,
|
319 |
+
):
|
320 |
+
|
321 |
+
# Need this for reshaping mels
|
322 |
+
batch_size = encoder_seq.size(0)
|
323 |
+
device = encoder_seq.device
|
324 |
+
# Unpack the hidden and cell states
|
325 |
+
attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
|
326 |
+
rnn1_cell, rnn2_cell = cell_states
|
327 |
+
|
328 |
+
# PreNet for the Attention RNN
|
329 |
+
prenet_out = self.prenet(prenet_in)
|
330 |
+
|
331 |
+
# Compute the Attention RNN hidden state
|
332 |
+
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
|
333 |
+
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
|
334 |
+
|
335 |
+
# Compute the attention scores
|
336 |
+
scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
|
337 |
+
|
338 |
+
# Dot product to create the context vector
|
339 |
+
context_vec = scores @ encoder_seq
|
340 |
+
context_vec = context_vec.squeeze(1)
|
341 |
+
|
342 |
+
# Concat Attention RNN output w. Context Vector & project
|
343 |
+
x = torch.cat([context_vec, attn_hidden], dim=1)
|
344 |
+
x = self.rnn_input(x)
|
345 |
+
|
346 |
+
# Compute first Residual RNN
|
347 |
+
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
|
348 |
+
if self.training:
|
349 |
+
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
|
350 |
+
else:
|
351 |
+
rnn1_hidden = rnn1_hidden_next
|
352 |
+
x = x + rnn1_hidden
|
353 |
+
|
354 |
+
# Compute second Residual RNN
|
355 |
+
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
|
356 |
+
if self.training:
|
357 |
+
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next, device=device)
|
358 |
+
else:
|
359 |
+
rnn2_hidden = rnn2_hidden_next
|
360 |
+
x = x + rnn2_hidden
|
361 |
+
|
362 |
+
# Project Mels
|
363 |
+
mels = self.mel_proj(x)
|
364 |
+
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, : self.r]
|
365 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
366 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
367 |
+
|
368 |
+
# Stop token prediction
|
369 |
+
s = torch.cat((x, context_vec), dim=1)
|
370 |
+
s = self.stop_proj(s)
|
371 |
+
stop_tokens = torch.sigmoid(s)
|
372 |
+
|
373 |
+
return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
|
374 |
+
|
375 |
+
|
376 |
+
class Tacotron(nn.Module):
|
377 |
+
def __init__(
|
378 |
+
self,
|
379 |
+
embed_dims,
|
380 |
+
num_chars,
|
381 |
+
encoder_dims,
|
382 |
+
decoder_dims,
|
383 |
+
n_mels,
|
384 |
+
fft_bins,
|
385 |
+
postnet_dims,
|
386 |
+
encoder_K,
|
387 |
+
lstm_dims,
|
388 |
+
postnet_K,
|
389 |
+
num_highways,
|
390 |
+
dropout,
|
391 |
+
stop_threshold,
|
392 |
+
speaker_embedding_size,
|
393 |
+
):
|
394 |
+
super().__init__()
|
395 |
+
self.n_mels = n_mels
|
396 |
+
self.lstm_dims = lstm_dims
|
397 |
+
self.encoder_dims = encoder_dims
|
398 |
+
self.decoder_dims = decoder_dims
|
399 |
+
self.speaker_embedding_size = speaker_embedding_size
|
400 |
+
self.encoder = Encoder(
|
401 |
+
embed_dims, num_chars, encoder_dims, encoder_K, num_highways, dropout
|
402 |
+
)
|
403 |
+
project_dims = encoder_dims + speaker_embedding_size
|
404 |
+
if hp.use_gst:
|
405 |
+
project_dims += gst_hp.E
|
406 |
+
self.encoder_proj = nn.Linear(project_dims, decoder_dims, bias=False)
|
407 |
+
if hp.use_gst:
|
408 |
+
self.gst = GlobalStyleToken(speaker_embedding_size)
|
409 |
+
self.decoder = Decoder(
|
410 |
+
n_mels,
|
411 |
+
encoder_dims,
|
412 |
+
decoder_dims,
|
413 |
+
lstm_dims,
|
414 |
+
dropout,
|
415 |
+
speaker_embedding_size,
|
416 |
+
)
|
417 |
+
self.postnet = CBHG(
|
418 |
+
postnet_K, n_mels, postnet_dims, [postnet_dims, fft_bins], num_highways
|
419 |
+
)
|
420 |
+
self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
|
421 |
+
|
422 |
+
self.init_model()
|
423 |
+
self.num_params()
|
424 |
+
|
425 |
+
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
426 |
+
self.register_buffer(
|
427 |
+
"stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32)
|
428 |
+
)
|
429 |
+
|
430 |
+
@property
|
431 |
+
def r(self):
|
432 |
+
return self.decoder.r.item()
|
433 |
+
|
434 |
+
@r.setter
|
435 |
+
def r(self, value):
|
436 |
+
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
437 |
+
|
438 |
+
@staticmethod
|
439 |
+
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
440 |
+
speaker_embeddings_ = speaker_embeddings.expand(
|
441 |
+
outputs.size(0), outputs.size(1), -1
|
442 |
+
)
|
443 |
+
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
444 |
+
return outputs
|
445 |
+
|
446 |
+
def forward(self, texts, mels, speaker_embedding):
|
447 |
+
device = texts.device # use same device as parameters
|
448 |
+
|
449 |
+
self.step += 1
|
450 |
+
batch_size, _, steps = mels.size()
|
451 |
+
|
452 |
+
# Initialise all hidden states and pack into tuple
|
453 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
454 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
455 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
456 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
457 |
+
|
458 |
+
# Initialise all lstm cell states and pack into tuple
|
459 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
460 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
461 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
462 |
+
|
463 |
+
# <GO> Frame for start of decoder loop
|
464 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
465 |
+
|
466 |
+
# Need an initial context vector
|
467 |
+
size = self.encoder_dims + self.speaker_embedding_size
|
468 |
+
if hp.use_gst:
|
469 |
+
size += gst_hp.E
|
470 |
+
context_vec = torch.zeros(batch_size, size, device=device)
|
471 |
+
|
472 |
+
# SV2TTS: Run the encoder with the speaker embedding
|
473 |
+
# The projection avoids unnecessary matmuls in the decoder loop
|
474 |
+
encoder_seq = self.encoder(texts, speaker_embedding)
|
475 |
+
# put after encoder
|
476 |
+
if hp.use_gst and self.gst is not None:
|
477 |
+
style_embed = self.gst(
|
478 |
+
speaker_embedding, speaker_embedding
|
479 |
+
) # for training, speaker embedding can represent both style inputs and referenced
|
480 |
+
# style_embed = style_embed.expand_as(encoder_seq)
|
481 |
+
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
482 |
+
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
483 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
484 |
+
|
485 |
+
# Need a couple of lists for outputs
|
486 |
+
mel_outputs, attn_scores, stop_outputs = [], [], []
|
487 |
+
|
488 |
+
# Run the decoder loop
|
489 |
+
for t in range(0, steps, self.r):
|
490 |
+
prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
|
491 |
+
(
|
492 |
+
mel_frames,
|
493 |
+
scores,
|
494 |
+
hidden_states,
|
495 |
+
cell_states,
|
496 |
+
context_vec,
|
497 |
+
stop_tokens,
|
498 |
+
) = self.decoder(
|
499 |
+
encoder_seq,
|
500 |
+
encoder_seq_proj,
|
501 |
+
prenet_in,
|
502 |
+
hidden_states,
|
503 |
+
cell_states,
|
504 |
+
context_vec,
|
505 |
+
t,
|
506 |
+
texts,
|
507 |
+
)
|
508 |
+
mel_outputs.append(mel_frames)
|
509 |
+
attn_scores.append(scores)
|
510 |
+
stop_outputs.extend([stop_tokens] * self.r)
|
511 |
+
|
512 |
+
# Concat the mel outputs into sequence
|
513 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
514 |
+
|
515 |
+
# Post-Process for Linear Spectrograms
|
516 |
+
postnet_out = self.postnet(mel_outputs)
|
517 |
+
linear = self.post_proj(postnet_out)
|
518 |
+
linear = linear.transpose(1, 2)
|
519 |
+
|
520 |
+
# For easy visualisation
|
521 |
+
attn_scores = torch.cat(attn_scores, 1)
|
522 |
+
# attn_scores = attn_scores.cpu().data.numpy()
|
523 |
+
stop_outputs = torch.cat(stop_outputs, 1)
|
524 |
+
|
525 |
+
return mel_outputs, linear, attn_scores, stop_outputs
|
526 |
+
|
527 |
+
def generate(
|
528 |
+
self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5
|
529 |
+
):
|
530 |
+
self.eval()
|
531 |
+
device = x.device # use same device as parameters
|
532 |
+
|
533 |
+
batch_size, _ = x.size()
|
534 |
+
|
535 |
+
# Need to initialise all hidden states and pack into tuple for tidiness
|
536 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
537 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
538 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
539 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
540 |
+
|
541 |
+
# Need to initialise all lstm cell states and pack into tuple for tidiness
|
542 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
543 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
544 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
545 |
+
|
546 |
+
# Need a <GO> Frame for start of decoder loop
|
547 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
548 |
+
|
549 |
+
# Need an initial context vector
|
550 |
+
size = self.encoder_dims + self.speaker_embedding_size
|
551 |
+
if hp.use_gst:
|
552 |
+
size += gst_hp.E
|
553 |
+
context_vec = torch.zeros(batch_size, size, device=device)
|
554 |
+
|
555 |
+
# SV2TTS: Run the encoder with the speaker embedding
|
556 |
+
# The projection avoids unnecessary matmuls in the decoder loop
|
557 |
+
encoder_seq = self.encoder(x, speaker_embedding)
|
558 |
+
|
559 |
+
# put after encoder
|
560 |
+
if hp.use_gst and self.gst is not None:
|
561 |
+
if style_idx >= 0 and style_idx < 10:
|
562 |
+
query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
|
563 |
+
if device.type == "cuda":
|
564 |
+
query = query.cuda()
|
565 |
+
gst_embed = torch.tanh(self.gst.stl.embed)
|
566 |
+
key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
|
567 |
+
style_embed = self.gst.stl.attention(query, key)
|
568 |
+
else:
|
569 |
+
speaker_embedding_style = torch.zeros(
|
570 |
+
speaker_embedding.size()[0], 1, self.speaker_embedding_size
|
571 |
+
).to(device)
|
572 |
+
style_embed = self.gst(speaker_embedding_style, speaker_embedding)
|
573 |
+
encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
|
574 |
+
# style_embed = style_embed.expand_as(encoder_seq)
|
575 |
+
# encoder_seq = torch.cat((encoder_seq, style_embed), 2)
|
576 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
577 |
+
|
578 |
+
# Need a couple of lists for outputs
|
579 |
+
mel_outputs, attn_scores, stop_outputs = [], [], []
|
580 |
+
|
581 |
+
# Run the decoder loop
|
582 |
+
for t in range(0, steps, self.r):
|
583 |
+
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
|
584 |
+
(
|
585 |
+
mel_frames,
|
586 |
+
scores,
|
587 |
+
hidden_states,
|
588 |
+
cell_states,
|
589 |
+
context_vec,
|
590 |
+
stop_tokens,
|
591 |
+
) = self.decoder(
|
592 |
+
encoder_seq,
|
593 |
+
encoder_seq_proj,
|
594 |
+
prenet_in,
|
595 |
+
hidden_states,
|
596 |
+
cell_states,
|
597 |
+
context_vec,
|
598 |
+
t,
|
599 |
+
x,
|
600 |
+
)
|
601 |
+
mel_outputs.append(mel_frames)
|
602 |
+
attn_scores.append(scores)
|
603 |
+
stop_outputs.extend([stop_tokens] * self.r)
|
604 |
+
# Stop the loop when all stop tokens in batch exceed threshold
|
605 |
+
if (stop_tokens * 10 > min_stop_token).all() and t > 10:
|
606 |
+
break
|
607 |
+
|
608 |
+
# Concat the mel outputs into sequence
|
609 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
610 |
+
|
611 |
+
# Post-Process for Linear Spectrograms
|
612 |
+
postnet_out = self.postnet(mel_outputs)
|
613 |
+
linear = self.post_proj(postnet_out)
|
614 |
+
|
615 |
+
linear = linear.transpose(1, 2)
|
616 |
+
|
617 |
+
# For easy visualisation
|
618 |
+
attn_scores = torch.cat(attn_scores, 1)
|
619 |
+
stop_outputs = torch.cat(stop_outputs, 1)
|
620 |
+
|
621 |
+
self.train()
|
622 |
+
|
623 |
+
return mel_outputs, linear, attn_scores
|
624 |
+
|
625 |
+
def init_model(self):
|
626 |
+
for p in self.parameters():
|
627 |
+
if p.dim() > 1:
|
628 |
+
nn.init.xavier_uniform_(p)
|
629 |
+
|
630 |
+
def finetune_partial(self, whitelist_layers):
|
631 |
+
self.zero_grad()
|
632 |
+
for name, child in self.named_children():
|
633 |
+
if name in whitelist_layers:
|
634 |
+
logger.debug("Trainable Layer: %s" % name)
|
635 |
+
logger.debug(
|
636 |
+
"Trainable Parameters: %.3f"
|
637 |
+
% sum([np.prod(p.size()) for p in child.parameters()])
|
638 |
+
)
|
639 |
+
for param in child.parameters():
|
640 |
+
param.requires_grad = False
|
641 |
+
|
642 |
+
def get_step(self):
|
643 |
+
return self.step.data.item()
|
644 |
+
|
645 |
+
def reset_step(self):
|
646 |
+
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
647 |
+
self.step = self.step.data.new_tensor(1)
|
648 |
+
|
649 |
+
def load(self, path, device, optimizer=None):
|
650 |
+
# Use device of model params as location for loaded state
|
651 |
+
checkpoint = torch.load(str(path), map_location=device)
|
652 |
+
self.load_state_dict(checkpoint["model_state"], strict=False)
|
653 |
+
|
654 |
+
if "optimizer_state" in checkpoint and optimizer is not None:
|
655 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
656 |
+
|
657 |
+
def save(self, path, optimizer=None):
|
658 |
+
if optimizer is not None:
|
659 |
+
torch.save(
|
660 |
+
{
|
661 |
+
"model_state": self.state_dict(),
|
662 |
+
"optimizer_state": optimizer.state_dict(),
|
663 |
+
},
|
664 |
+
str(path),
|
665 |
+
)
|
666 |
+
else:
|
667 |
+
torch.save(
|
668 |
+
{
|
669 |
+
"model_state": self.state_dict(),
|
670 |
+
},
|
671 |
+
str(path),
|
672 |
+
)
|
673 |
+
|
674 |
+
def num_params(self):
|
675 |
+
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
676 |
+
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
677 |
+
logger.debug("Trainable Parameters: %.3fM" % parameters)
|
678 |
+
return parameters
|
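The Decoder above regularises its recurrent states with zoneout: during training, each element of the new hidden state is kept or swapped back to the previous one at random. A self-contained sketch of that rule follows (illustration only, not the trained model).

import torch

def zoneout(prev, current, p=0.1):
    # With probability p an element keeps its previous value, otherwise the new one is used.
    mask = torch.zeros_like(prev).bernoulli_(p)
    return prev * mask + current * (1 - mask)

prev_h = torch.zeros(4, 8)
new_h = torch.ones(4, 8)
mixed = zoneout(prev_h, new_h)
print(mixed.mean())  # close to 1 - p: most elements take the new state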
mockingbirdforuse/synthesizer/utils/__init__.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
_output_ref = None
|
5 |
+
_replicas_ref = None
|
6 |
+
|
7 |
+
|
8 |
+
def data_parallel_workaround(model, *input):
|
9 |
+
global _output_ref
|
10 |
+
global _replicas_ref
|
11 |
+
device_ids = list(range(torch.cuda.device_count()))
|
12 |
+
output_device = device_ids[0]
|
13 |
+
replicas = torch.nn.parallel.replicate(model, device_ids)
|
14 |
+
# input.shape = (num_args, batch, ...)
|
15 |
+
inputs = torch.nn.parallel.scatter(input, device_ids)
|
16 |
+
# inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
|
17 |
+
replicas = replicas[: len(inputs)]
|
18 |
+
outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
|
19 |
+
y_hat = torch.nn.parallel.gather(outputs, output_device)
|
20 |
+
_output_ref = outputs
|
21 |
+
_replicas_ref = replicas
|
22 |
+
return y_hat
|
23 |
+
|
24 |
+
|
25 |
+
class ValueWindow:
|
26 |
+
def __init__(self, window_size=100):
|
27 |
+
self._window_size = window_size
|
28 |
+
self._values = []
|
29 |
+
|
30 |
+
def append(self, x):
|
31 |
+
self._values = self._values[-(self._window_size - 1) :] + [x]
|
32 |
+
|
33 |
+
@property
|
34 |
+
def sum(self):
|
35 |
+
return sum(self._values)
|
36 |
+
|
37 |
+
@property
|
38 |
+
def count(self):
|
39 |
+
return len(self._values)
|
40 |
+
|
41 |
+
@property
|
42 |
+
def average(self):
|
43 |
+
return self.sum / max(1, self.count)
|
44 |
+
|
45 |
+
def reset(self):
|
46 |
+
self._values = []
|
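ValueWindow above is a small running window used to smooth training metrics. A usage sketch, assuming the package is importable as mockingbirdforuse.synthesizer.utils:

from mockingbirdforuse.synthesizer.utils import ValueWindow

window = ValueWindow(window_size=3)
for loss in (4.0, 2.0, 1.0, 1.0):
    window.append(loss)   # only the last three values are retained
print(window.count)       # 3
print(window.average)     # (2.0 + 1.0 + 1.0) / 3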
mockingbirdforuse/synthesizer/utils/cleaners.py
ADDED
@@ -0,0 +1,91 @@
1 |
+
"""
|
2 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
3 |
+
|
4 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
5 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
6 |
+
1. "english_cleaners" for English text
|
7 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
8 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
9 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
10 |
+
the symbols in symbols.py to match your data).
|
11 |
+
"""
|
12 |
+
|
13 |
+
import re
|
14 |
+
from unidecode import unidecode
|
15 |
+
from .numbers import normalize_numbers
|
16 |
+
|
17 |
+
# Regular expression matching whitespace:
|
18 |
+
_whitespace_re = re.compile(r"\s+")
|
19 |
+
|
20 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
21 |
+
_abbreviations = [
|
22 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
23 |
+
for x in [
|
24 |
+
("mrs", "misess"),
|
25 |
+
("mr", "mister"),
|
26 |
+
("dr", "doctor"),
|
27 |
+
("st", "saint"),
|
28 |
+
("co", "company"),
|
29 |
+
("jr", "junior"),
|
30 |
+
("maj", "major"),
|
31 |
+
("gen", "general"),
|
32 |
+
("drs", "doctors"),
|
33 |
+
("rev", "reverend"),
|
34 |
+
("lt", "lieutenant"),
|
35 |
+
("hon", "honorable"),
|
36 |
+
("sgt", "sergeant"),
|
37 |
+
("capt", "captain"),
|
38 |
+
("esq", "esquire"),
|
39 |
+
("ltd", "limited"),
|
40 |
+
("col", "colonel"),
|
41 |
+
("ft", "fort"),
|
42 |
+
]
|
43 |
+
]
|
44 |
+
|
45 |
+
|
46 |
+
def expand_abbreviations(text):
|
47 |
+
for regex, replacement in _abbreviations:
|
48 |
+
text = re.sub(regex, replacement, text)
|
49 |
+
return text
|
50 |
+
|
51 |
+
|
52 |
+
def expand_numbers(text):
|
53 |
+
return normalize_numbers(text)
|
54 |
+
|
55 |
+
|
56 |
+
def lowercase(text):
|
57 |
+
"""lowercase input tokens."""
|
58 |
+
return text.lower()
|
59 |
+
|
60 |
+
|
61 |
+
def collapse_whitespace(text):
|
62 |
+
return re.sub(_whitespace_re, " ", text)
|
63 |
+
|
64 |
+
|
65 |
+
def convert_to_ascii(text):
|
66 |
+
return unidecode(text)
|
67 |
+
|
68 |
+
|
69 |
+
def basic_cleaners(text):
|
70 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
71 |
+
text = lowercase(text)
|
72 |
+
text = collapse_whitespace(text)
|
73 |
+
return text
|
74 |
+
|
75 |
+
|
76 |
+
def transliteration_cleaners(text):
|
77 |
+
"""Pipeline for non-English text that transliterates to ASCII."""
|
78 |
+
text = convert_to_ascii(text)
|
79 |
+
text = lowercase(text)
|
80 |
+
text = collapse_whitespace(text)
|
81 |
+
return text
|
82 |
+
|
83 |
+
|
84 |
+
def english_cleaners(text):
|
85 |
+
"""Pipeline for English text, including number and abbreviation expansion."""
|
86 |
+
text = convert_to_ascii(text)
|
87 |
+
text = lowercase(text)
|
88 |
+
text = expand_numbers(text)
|
89 |
+
text = expand_abbreviations(text)
|
90 |
+
text = collapse_whitespace(text)
|
91 |
+
return text
|
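The cleaner pipelines above compose these small normalisation helpers. A usage sketch (import path assumed from this repository layout; the exact English wording of expanded numbers comes from inflect):

from mockingbirdforuse.synthesizer.utils.cleaners import basic_cleaners, english_cleaners

print(basic_cleaners("  Hello   WORLD "))          # lowercased, whitespace collapsed
print(english_cleaners("Dr. Smith paid $15.50"))   # e.g. "doctor smith paid fifteen dollars, fifty cents"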
mockingbirdforuse/synthesizer/utils/logmmse.py
ADDED
@@ -0,0 +1,245 @@
1 |
+
# The MIT License (MIT)
|
2 |
+
#
|
3 |
+
# Copyright (c) 2015 braindead
|
4 |
+
#
|
5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
7 |
+
# in the Software without restriction, including without limitation the rights
|
8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
10 |
+
# furnished to do so, subject to the following conditions:
|
11 |
+
#
|
12 |
+
# The above copyright notice and this permission notice shall be included in all
|
13 |
+
# copies or substantial portions of the Software.
|
14 |
+
#
|
15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
# SOFTWARE.
|
22 |
+
#
|
23 |
+
#
|
24 |
+
# This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
|
25 |
+
# simply modified the interface to meet my needs.
|
26 |
+
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
import math
|
30 |
+
from scipy.special import expn
|
31 |
+
from collections import namedtuple
|
32 |
+
|
33 |
+
NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
|
34 |
+
|
35 |
+
|
36 |
+
def profile_noise(noise, sampling_rate, window_size=0):
|
37 |
+
"""
|
38 |
+
Creates a profile of the noise in a given waveform.
|
39 |
+
|
40 |
+
:param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
|
41 |
+
:param sampling_rate: the sampling rate of the audio
|
42 |
+
:param window_size: the size of the window the logmmse algorithm operates on. A default value
|
43 |
+
will be picked if left as 0.
|
44 |
+
:return: a NoiseProfile object
|
45 |
+
"""
|
46 |
+
noise, dtype = to_float(noise)
|
47 |
+
noise += np.finfo(np.float64).eps
|
48 |
+
|
49 |
+
if window_size == 0:
|
50 |
+
window_size = int(math.floor(0.02 * sampling_rate))
|
51 |
+
|
52 |
+
if window_size % 2 == 1:
|
53 |
+
window_size = window_size + 1
|
54 |
+
|
55 |
+
perc = 50
|
56 |
+
len1 = int(math.floor(window_size * perc / 100))
|
57 |
+
len2 = int(window_size - len1)
|
58 |
+
|
59 |
+
win = np.hanning(window_size)
|
60 |
+
win = win * len2 / np.sum(win)
|
61 |
+
n_fft = 2 * window_size
|
62 |
+
|
63 |
+
noise_mean = np.zeros(n_fft)
|
64 |
+
n_frames = len(noise) // window_size
|
65 |
+
for j in range(0, window_size * n_frames, window_size):
|
66 |
+
noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
|
67 |
+
noise_mu2 = (noise_mean / n_frames) ** 2
|
68 |
+
|
69 |
+
return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
|
70 |
+
|
71 |
+
|
72 |
+
def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
|
73 |
+
"""
|
74 |
+
Cleans the noise from a speech waveform given a noise profile. The waveform must have the
|
75 |
+
same sampling rate as the one used to create the noise profile.
|
76 |
+
|
77 |
+
:param wav: a speech waveform as a numpy array of floats or ints.
|
78 |
+
:param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
|
79 |
+
the same) waveform.
|
80 |
+
:param eta: voice threshold for noise update. While the voice activation detection value is
|
81 |
+
below this threshold, the noise profile will be continuously updated throughout the audio.
|
82 |
+
Set to 0 to disable updating the noise profile.
|
83 |
+
:return: the clean wav as a numpy array of floats or ints of the same length.
|
84 |
+
"""
|
85 |
+
wav, dtype = to_float(wav)
|
86 |
+
wav += np.finfo(np.float64).eps
|
87 |
+
p = noise_profile
|
88 |
+
|
89 |
+
nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
|
90 |
+
x_final = np.zeros(nframes * p.len2)
|
91 |
+
|
92 |
+
aa = 0.98
|
93 |
+
mu = 0.98
|
94 |
+
ksi_min = 10 ** (-25 / 10)
|
95 |
+
|
96 |
+
x_old = np.zeros(p.len1)
|
97 |
+
xk_prev = np.zeros(p.len1)
|
98 |
+
noise_mu2 = p.noise_mu2
|
99 |
+
for k in range(0, nframes * p.len2, p.len2):
|
100 |
+
insign = p.win * wav[k:k + p.window_size]
|
101 |
+
|
102 |
+
spec = np.fft.fft(insign, p.n_fft, axis=0)
|
103 |
+
sig = np.absolute(spec)
|
104 |
+
sig2 = sig ** 2
|
105 |
+
|
106 |
+
gammak = np.minimum(sig2 / noise_mu2, 40)
|
107 |
+
|
108 |
+
if xk_prev.all() == 0:
|
109 |
+
ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
|
110 |
+
else:
|
111 |
+
ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
|
112 |
+
ksi = np.maximum(ksi_min, ksi)
|
113 |
+
|
114 |
+
log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
|
115 |
+
vad_decision = np.sum(log_sigma_k) / p.window_size
|
116 |
+
if vad_decision < eta:
|
117 |
+
noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
|
118 |
+
|
119 |
+
a = ksi / (1 + ksi)
|
120 |
+
vk = a * gammak
|
121 |
+
ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
|
122 |
+
hw = a * np.exp(ei_vk)
|
123 |
+
sig = sig * hw
|
124 |
+
xk_prev = sig ** 2
|
125 |
+
xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
|
126 |
+
xi_w = np.real(xi_w)
|
127 |
+
|
128 |
+
x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
|
129 |
+
x_old = xi_w[p.len1:p.window_size]
|
130 |
+
|
131 |
+
output = from_float(x_final, dtype)
|
132 |
+
output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
|
133 |
+
return output
|
134 |
+
|
135 |
+
|
136 |
+
## Alternative VAD algorithm to webrtcvad. It has the advantage of not requiring installation of that
|
137 |
+
## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
|
138 |
+
## webrtcvad
|
139 |
+
# def vad(wav, sampling_rate, eta=0.15, window_size=0):
|
140 |
+
# """
|
141 |
+
# TODO: fix doc
|
142 |
+
# Creates a profile of the noise in a given waveform.
|
143 |
+
#
|
144 |
+
# :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
|
145 |
+
# :param sampling_rate: the sampling rate of the audio
|
146 |
+
# :param window_size: the size of the window the logmmse algorithm operates on. A default value
|
147 |
+
# will be picked if left as 0.
|
148 |
+
# :param eta: voice threshold for noise update. While the voice activation detection value is
|
149 |
+
# below this threshold, the noise profile will be continuously updated throughout the audio.
|
150 |
+
# Set to 0 to disable updating the noise profile.
|
151 |
+
# """
|
152 |
+
# wav, dtype = to_float(wav)
|
153 |
+
# wav += np.finfo(np.float64).eps
|
154 |
+
#
|
155 |
+
# if window_size == 0:
|
156 |
+
# window_size = int(math.floor(0.02 * sampling_rate))
|
157 |
+
#
|
158 |
+
# if window_size % 2 == 1:
|
159 |
+
# window_size = window_size + 1
|
160 |
+
#
|
161 |
+
# perc = 50
|
162 |
+
# len1 = int(math.floor(window_size * perc / 100))
|
163 |
+
# len2 = int(window_size - len1)
|
164 |
+
#
|
165 |
+
# win = np.hanning(window_size)
|
166 |
+
# win = win * len2 / np.sum(win)
|
167 |
+
# n_fft = 2 * window_size
|
168 |
+
#
|
169 |
+
# wav_mean = np.zeros(n_fft)
|
170 |
+
# n_frames = len(wav) // window_size
|
171 |
+
# for j in range(0, window_size * n_frames, window_size):
|
172 |
+
# wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
|
173 |
+
# noise_mu2 = (wav_mean / n_frames) ** 2
|
174 |
+
#
|
175 |
+
# wav, dtype = to_float(wav)
|
176 |
+
# wav += np.finfo(np.float64).eps
|
177 |
+
#
|
178 |
+
# nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
|
179 |
+
# vad = np.zeros(nframes * len2, dtype=np.bool)
|
180 |
+
#
|
181 |
+
# aa = 0.98
|
182 |
+
# mu = 0.98
|
183 |
+
# ksi_min = 10 ** (-25 / 10)
|
184 |
+
#
|
185 |
+
# xk_prev = np.zeros(len1)
|
186 |
+
# noise_mu2 = noise_mu2
|
187 |
+
# for k in range(0, nframes * len2, len2):
|
188 |
+
# insign = win * wav[k:k + window_size]
|
189 |
+
#
|
190 |
+
# spec = np.fft.fft(insign, n_fft, axis=0)
|
191 |
+
# sig = np.absolute(spec)
|
192 |
+
# sig2 = sig ** 2
|
193 |
+
#
|
194 |
+
# gammak = np.minimum(sig2 / noise_mu2, 40)
|
195 |
+
#
|
196 |
+
# if xk_prev.all() == 0:
|
197 |
+
# ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
|
198 |
+
# else:
|
199 |
+
# ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
|
200 |
+
# ksi = np.maximum(ksi_min, ksi)
|
201 |
+
#
|
202 |
+
# log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
|
203 |
+
# vad_decision = np.sum(log_sigma_k) / window_size
|
204 |
+
# if vad_decision < eta:
|
205 |
+
# noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
|
206 |
+
#
|
207 |
+
# a = ksi / (1 + ksi)
|
208 |
+
# vk = a * gammak
|
209 |
+
# ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
|
210 |
+
# hw = a * np.exp(ei_vk)
|
211 |
+
# sig = sig * hw
|
212 |
+
# xk_prev = sig ** 2
|
213 |
+
#
|
214 |
+
# vad[k:k + len2] = vad_decision >= eta
|
215 |
+
#
|
216 |
+
# vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
|
217 |
+
# return vad
|
218 |
+
|
219 |
+
|
220 |
+
def to_float(_input):
|
221 |
+
if _input.dtype == np.float64:
|
222 |
+
return _input, _input.dtype
|
223 |
+
elif _input.dtype == np.float32:
|
224 |
+
return _input.astype(np.float64), _input.dtype
|
225 |
+
elif _input.dtype == np.uint8:
|
226 |
+
return (_input - 128) / 128., _input.dtype
|
227 |
+
elif _input.dtype == np.int16:
|
228 |
+
return _input / 32768., _input.dtype
|
229 |
+
elif _input.dtype == np.int32:
|
230 |
+
return _input / 2147483648., _input.dtype
|
231 |
+
raise ValueError('Unsupported wave file format')
|
232 |
+
|
233 |
+
|
234 |
+
def from_float(_input, dtype):
|
235 |
+
if dtype == np.float64:
|
236 |
+
return _input, np.float64
|
237 |
+
elif dtype == np.float32:
|
238 |
+
return _input.astype(np.float32)
|
239 |
+
elif dtype == np.uint8:
|
240 |
+
return ((_input * 128) + 128).astype(np.uint8)
|
241 |
+
elif dtype == np.int16:
|
242 |
+
return (_input * 32768).astype(np.int16)
|
243 |
+
elif dtype == np.int32:
|
244 |
+
return (_input * 2147483648).astype(np.int32)
|
245 |
+
raise ValueError('Unsupported wave file format')
|
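profile_noise and denoise above are what load_preprocess_wav in the synthesizer relies on: a noise-only segment is profiled, then the whole waveform is cleaned against that profile. A standalone sketch on synthetic audio (import path assumed; requires numpy and scipy; float32 input mirrors what librosa.load returns):

import numpy as np
from mockingbirdforuse.synthesizer.utils.logmmse import profile_noise, denoise

sr = 16000
rng = np.random.default_rng(0)
noise = rng.normal(scale=0.01, size=sr).astype(np.float32)                # 1 s of noise-only audio
tone = np.sin(2 * np.pi * 220.0 * np.arange(2 * sr) / sr)
speech = (tone + rng.normal(scale=0.01, size=2 * sr)).astype(np.float32)  # noisy 2 s signal

profile = profile_noise(noise, sr)    # estimate the noise spectrum from the noise-only segment
clean = denoise(speech, profile)      # suppress that noise in the full signal
print(clean.shape == speech.shape)    # True: the output is padded back to the input length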
mockingbirdforuse/synthesizer/utils/numbers.py
ADDED
@@ -0,0 +1,70 @@
1 |
+
import re
|
2 |
+
import inflect
|
3 |
+
|
4 |
+
_inflect = inflect.engine()
|
5 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
6 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
7 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
8 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
9 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
10 |
+
_number_re = re.compile(r"[0-9]+")
|
11 |
+
|
12 |
+
|
13 |
+
def _remove_commas(m):
|
14 |
+
return m.group(1).replace(",", "")
|
15 |
+
|
16 |
+
|
17 |
+
def _expand_decimal_point(m):
|
18 |
+
return m.group(1).replace(".", " point ")
|
19 |
+
|
20 |
+
|
21 |
+
def _expand_dollars(m):
|
22 |
+
match = m.group(1)
|
23 |
+
parts = match.split(".")
|
24 |
+
if len(parts) > 2:
|
25 |
+
return match + " dollars" # Unexpected format
|
26 |
+
dollars = int(parts[0]) if parts[0] else 0
|
27 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
28 |
+
if dollars and cents:
|
29 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
30 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
31 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
32 |
+
elif dollars:
|
33 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
34 |
+
return "%s %s" % (dollars, dollar_unit)
|
35 |
+
elif cents:
|
36 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
37 |
+
return "%s %s" % (cents, cent_unit)
|
38 |
+
else:
|
39 |
+
return "zero dollars"
|
40 |
+
|
41 |
+
|
42 |
+
def _expand_ordinal(m):
|
43 |
+
return _inflect.number_to_words(m.group(0))
|
44 |
+
|
45 |
+
|
46 |
+
def _expand_number(m):
|
47 |
+
num = int(m.group(0))
|
48 |
+
if num > 1000 and num < 3000:
|
49 |
+
if num == 2000:
|
50 |
+
return "two thousand"
|
51 |
+
elif num > 2000 and num < 2010:
|
52 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
53 |
+
elif num % 100 == 0:
|
54 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
55 |
+
else:
|
56 |
+
return _inflect.number_to_words(
|
57 |
+
num, andword="", zero="oh", group=2
|
58 |
+
).replace(", ", " ")
|
59 |
+
else:
|
60 |
+
return _inflect.number_to_words(num, andword="")
|
61 |
+
|
62 |
+
|
63 |
+
def normalize_numbers(text):
|
64 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
65 |
+
text = re.sub(_pounds_re, r"\1 pounds", text)
|
66 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
67 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
68 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
69 |
+
text = re.sub(_number_re, _expand_number, text)
|
70 |
+
return text
|
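normalize_numbers above applies the substitutions in a fixed order (commas, pounds, dollars, decimals, ordinals, plain numbers). A quick usage sketch (import path assumed; wording produced by inflect):

from mockingbirdforuse.synthesizer.utils.numbers import normalize_numbers

print(normalize_numbers("I paid $3 for 2 tickets on May 3rd, 1969."))
# e.g. "I paid three dollars for two tickets on May third, nineteen sixty-nine."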
mockingbirdforuse/synthesizer/utils/symbols.py
ADDED
@@ -0,0 +1,20 @@
1 |
+
"""
|
2 |
+
Defines the set of symbols used in text input to the model.
|
3 |
+
|
4 |
+
The default is a set of ASCII characters that works well for English or text that has been run
|
5 |
+
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
|
6 |
+
"""
|
7 |
+
# from . import cmudict
|
8 |
+
|
9 |
+
_pad = "_"
|
10 |
+
_eos = "~"
|
11 |
+
_characters = (
|
12 |
+
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!'(),-.:;? "
|
13 |
+
)
|
14 |
+
|
15 |
+
# _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz12340!\'(),-.:;? ' # use this old one if you want to train old model
|
16 |
+
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
17 |
+
# _arpabet = ["@' + s for s in cmudict.valid_symbols]
|
18 |
+
|
19 |
+
# Export all symbols:
|
20 |
+
symbols = [_pad, _eos] + list(_characters) # + _arpabet
|
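The symbol inventory above is just the padding and EOS markers followed by the character set; the position of each entry is the integer ID used by text_to_sequence in the next file. A quick check (import path assumed):

from mockingbirdforuse.synthesizer.utils.symbols import symbols

print(symbols[:5])         # ['_', '~', 'A', 'B', 'C']
print(symbols.index(" "))  # ID of the space that separates pinyin syllables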
mockingbirdforuse/synthesizer/utils/text.py
ADDED
@@ -0,0 +1,74 @@
1 |
+
from .symbols import symbols
|
2 |
+
from . import cleaners
|
3 |
+
import re
|
4 |
+
|
5 |
+
# Mappings from symbol to numeric ID and vice versa:
|
6 |
+
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
7 |
+
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
8 |
+
|
9 |
+
# Regular expression matching text enclosed in curly braces:
|
10 |
+
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
|
11 |
+
|
12 |
+
|
13 |
+
def text_to_sequence(text, cleaner_names):
|
14 |
+
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
|
15 |
+
|
16 |
+
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
|
17 |
+
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
|
18 |
+
|
19 |
+
Args:
|
20 |
+
text: string to convert to a sequence
|
21 |
+
cleaner_names: names of the cleaner functions to run the text through
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
List of integers corresponding to the symbols in the text
|
25 |
+
"""
|
26 |
+
sequence = []
|
27 |
+
|
28 |
+
# Check for curly braces and treat their contents as ARPAbet:
|
29 |
+
while len(text):
|
30 |
+
m = _curly_re.match(text)
|
31 |
+
if not m:
|
32 |
+
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
|
33 |
+
break
|
34 |
+
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
|
35 |
+
sequence += _arpabet_to_sequence(m.group(2))
|
36 |
+
text = m.group(3)
|
37 |
+
|
38 |
+
# Append EOS token
|
39 |
+
sequence.append(_symbol_to_id["~"])
|
40 |
+
return sequence
|
41 |
+
|
42 |
+
|
43 |
+
def sequence_to_text(sequence):
|
44 |
+
"""Converts a sequence of IDs back to a string"""
|
45 |
+
result = ""
|
46 |
+
for symbol_id in sequence:
|
47 |
+
if symbol_id in _id_to_symbol:
|
48 |
+
s = _id_to_symbol[symbol_id]
|
49 |
+
# Enclose ARPAbet back in curly braces:
|
50 |
+
if len(s) > 1 and s[0] == "@":
|
51 |
+
s = "{%s}" % s[1:]
|
52 |
+
result += s
|
53 |
+
return result.replace("}{", " ")
|
54 |
+
|
55 |
+
|
56 |
+
def _clean_text(text, cleaner_names):
|
57 |
+
for name in cleaner_names:
|
58 |
+
cleaner = getattr(cleaners, name)
|
59 |
+
if not cleaner:
|
60 |
+
raise Exception("Unknown cleaner: %s" % name)
|
61 |
+
text = cleaner(text)
|
62 |
+
return text
|
63 |
+
|
64 |
+
|
65 |
+
def _symbols_to_sequence(symbols):
|
66 |
+
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
67 |
+
|
68 |
+
|
69 |
+
def _arpabet_to_sequence(text):
|
70 |
+
return _symbols_to_sequence(["@" + s for s in text.split()])
|
71 |
+
|
72 |
+
|
73 |
+
def _should_keep_symbol(s):
|
74 |
+
return s in _symbol_to_id and s not in ("_", "~")
|
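A quick illustration of the text_to_sequence interface documented above (a sketch; "english_cleaners" is assumed to be one of the cleaner functions defined in cleaners.py elsewhere in this commit). Note that with the ARPAbet entries commented out in symbols.py, the "@"-prefixed tokens produced for curly-brace content are silently dropped by _should_keep_symbol:

    from mockingbirdforuse.synthesizer.utils.text import text_to_sequence, sequence_to_text

    seq = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
    print(sequence_to_text(seq))  # round-trips the plain characters; the ARPAbet chunk is filtered out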
mockingbirdforuse/vocoder/__init__.py
ADDED
File without changes
|
mockingbirdforuse/vocoder/distribution.py
ADDED
@@ -0,0 +1,136 @@
import torch
import numpy as np
import torch.nn.functional as F


def log_sum_exp(x):
    """numerically stable log_sum_exp implementation that prevents overflow"""
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))


# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(
    y_hat, y, num_classes=65536, log_scale_min=None, reduce=True
):
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    y_hat = y_hat.permute(0, 2, 1)
    assert y_hat.dim() == 3
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix : 2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(F.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - F.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

    # tf equivalent
    """
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))
    """
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * torch.log(
        torch.clamp(cdf_delta, min=1e-12)
    ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    inner_cond = (y > 0.999).float()
    inner_out = (
        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
    )
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.mean(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)


def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions
    Args:
        y (Tensor): B x C x T
        log_scale_min (float): Log scale minimum value
    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(-torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = to_one_hot(argmax, nr_mix)
    # select logistic parameters
    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.clamp(
        torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
    )
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))

    x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)

    return x


def to_one_hot(tensor, n, fill_with=1.0):
    # we perform one hot encoding with respect to the last axis
    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
    if tensor.is_cuda:
        one_hot = one_hot.cuda()
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return one_hot
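A shape-level sanity check for the sampling routine above (a sketch with 2 mixture components, so the channel dimension is 3 * 2 = 6):

    import torch
    from mockingbirdforuse.vocoder.distribution import sample_from_discretized_mix_logistic

    y = torch.randn(1, 6, 100)                   # B x C x T, with C = 3 * nr_mix
    x = sample_from_discretized_mix_logistic(y)  # -> B x T, values clipped to [-1, 1]
    print(x.shape, float(x.min()), float(x.max()))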
mockingbirdforuse/vocoder/hifigan/__init__.py
ADDED
File without changes
|
mockingbirdforuse/vocoder/hifigan/hparams.py
ADDED
@@ -0,0 +1,37 @@
from dataclasses import dataclass


@dataclass
class HParams:
    resblock = "1"
    num_gpus = 0
    batch_size = 16
    learning_rate = 0.0002
    adam_b1 = 0.8
    adam_b2 = 0.99
    lr_decay = 0.999
    seed = 1234

    upsample_rates = [5, 5, 4, 2]
    upsample_kernel_sizes = [10, 10, 8, 4]
    upsample_initial_channel = 512
    resblock_kernel_sizes = [3, 7, 11]
    resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]

    segment_size = 6400
    num_mels = 80
    num_freq = 1025
    n_fft = 1024
    hop_size = 200
    win_size = 800

    sampling_rate = 16000

    fmin = 0
    fmax = 7600
    fmax_for_loss = None

    num_workers = 4


hparams = HParams()
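One relationship worth noting in these settings: the product of upsample_rates is 5 * 5 * 4 * 2 = 200, which equals hop_size, so the generator expands each mel frame into exactly one hop of audio; at sampling_rate 16000 that corresponds to 16000 / 200 = 80 mel frames per second.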
mockingbirdforuse/vocoder/hifigan/inference.py
ADDED
@@ -0,0 +1,32 @@
import torch
from pathlib import Path

from .hparams import hparams as hp
from .models import Generator
from ...log import logger


class HifiGanVocoder:
    def __init__(self, model_path: Path):
        torch.manual_seed(hp.seed)
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.generator = Generator(hp).to(self._device)

        logger.debug("Loading '{}'".format(model_path))
        state_dict_g = torch.load(model_path, map_location=self._device)
        logger.debug("Complete.")

        self.generator.load_state_dict(state_dict_g["generator"])
        self.generator.eval()
        self.generator.remove_weight_norm()

    def infer_waveform(self, mel):
        mel = torch.FloatTensor(mel).to(self._device)
        mel = mel.unsqueeze(0)

        with torch.no_grad():
            y_g_hat = self.generator(mel)
            audio = y_g_hat.squeeze()
            audio = audio.cpu().numpy()

        return audio, hp.sampling_rate
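A minimal usage sketch for the wrapper above (the checkpoint path is illustrative; mel is assumed to be a (num_mels, frames) array produced by the synthesizer):

    from pathlib import Path
    from mockingbirdforuse.vocoder.hifigan.inference import HifiGanVocoder

    vocoder = HifiGanVocoder(Path("data/g_hifigan.pt"))
    audio, sr = vocoder.infer_waveform(mel)  # float waveform plus hp.sampling_rate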
mockingbirdforuse/vocoder/hifigan/models.py
ADDED
@@ -0,0 +1,460 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import spectral_norm
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils.weight_norm import weight_norm, remove_weight_norm
from ...log import logger

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))),
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))),
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList(
            [
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))),
                weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class InterpolationBlock(torch.nn.Module):
    def __init__(self, scale_factor, mode="nearest", align_corners=None, downsample=False):
        super(InterpolationBlock, self).__init__()
        self.downsample = downsample
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        outputs = F.interpolate(
            x,
            size=x.shape[-1] * self.scale_factor
            if not self.downsample
            else x.shape[-1] // self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            recompute_scale_factor=False,
        )
        return outputs


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        # for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
        #     self.ups.append(weight_norm(
        #         ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
        #                         k, u, padding=(k-u)//2)))
        if h.sampling_rate == 24000:
            for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
                self.ups.append(
                    torch.nn.Sequential(
                        InterpolationBlock(u),
                        weight_norm(
                            torch.nn.Conv1d(
                                h.upsample_initial_channel // (2**i),
                                h.upsample_initial_channel // (2 ** (i + 1)),
                                k,
                                padding=(k - 1) // 2,
                            )
                        ),
                    )
                )
        else:
            for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
                self.ups.append(
                    weight_norm(
                        ConvTranspose1d(
                            h.upsample_initial_channel // (2**i),
                            h.upsample_initial_channel // (2 ** (i + 1)),
                            k,
                            u,
                            padding=(u // 2 + u % 2),
                            output_padding=u % 2,
                        )
                    )
                )
        self.resblocks = nn.ModuleList()
        ch = 0
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        logger.debug("Removing weight norm...")
        for module in self.ups:
            if self.h.sampling_rate == 24000:
                remove_weight_norm(module[-1])
            else:
                remove_weight_norm(module)
        for module in self.resblocks:
            module.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2),
                DiscriminatorP(3),
                DiscriminatorP(5),
                DiscriminatorP(7),
                DiscriminatorP(11),
            ]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
mockingbirdforuse/vocoder/wavernn/__init__.py
ADDED
File without changes
|
mockingbirdforuse/vocoder/wavernn/audio.py
ADDED
@@ -0,0 +1,118 @@
import math
import librosa
import numpy as np
import soundfile as sf
from scipy.signal import lfilter

from .hparams import hparams as hp


def label_2_float(x, bits):
    return 2 * x / (2**bits - 1.0) - 1.0


def float_2_label(x, bits):
    assert abs(x).max() <= 1.0
    x = (x + 1.0) * (2**bits - 1) / 2
    return x.clip(0, 2**bits - 1)


def load_wav(path):
    return librosa.load(str(path), sr=hp.sample_rate)[0]


def save_wav(x, path):
    sf.write(path, x.astype(np.float32), hp.sample_rate)


def split_signal(x):
    unsigned = x + 2**15
    coarse = unsigned // 256
    fine = unsigned % 256
    return coarse, fine


def combine_signal(coarse, fine):
    return coarse * 256 + fine - 2**15


def encode_16bits(x):
    return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)


mel_basis = None


def linear_to_mel(spectrogram):
    global mel_basis
    if mel_basis is None:
        mel_basis = build_mel_basis()
    return np.dot(mel_basis, spectrogram)


def build_mel_basis():
    return librosa.filters.mel(
        sr=hp.sample_rate,
        n_fft=hp.n_fft,
        n_mels=hp.num_mels,
        fmin=hp.fmin,
    )


def normalize(S):
    return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)


def denormalize(S):
    return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db


def amp_to_db(x):
    return 20 * np.log10(np.maximum(1e-5, x))


def db_to_amp(x):
    return np.power(10.0, x * 0.05)


def spectrogram(y):
    D = stft(y)
    S = amp_to_db(np.abs(D)) - hp.ref_level_db
    return normalize(S)


def melspectrogram(y):
    D = stft(y)
    S = amp_to_db(linear_to_mel(np.abs(D)))
    return normalize(S)


def stft(y):
    return librosa.stft(
        y=y,
        n_fft=hp.n_fft,
        hop_length=hp.hop_length,
        win_length=hp.win_length,
    )


def pre_emphasis(x):
    return lfilter([1, -hp.preemphasis], [1], x)


def de_emphasis(x):
    return lfilter([1], [1, -hp.preemphasis], x)


def encode_mu_law(x, mu):
    mu = mu - 1
    fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
    return np.floor((fx + 1) / 2 * mu + 0.5)


def decode_mu_law(y, mu, from_labels=True):
    if from_labels:
        y = label_2_float(y, math.log2(mu))
    mu = mu - 1
    x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
    return x
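The mu-law pair above is used with 2**bits classes; a small round-trip sketch at the 9-bit setting used by the WaveRNN hparams below:

    import numpy as np
    from mockingbirdforuse.vocoder.wavernn.audio import encode_mu_law, decode_mu_law

    x = np.linspace(-1, 1, 5)
    labels = encode_mu_law(x, mu=2**9)       # float labels in [0, 511]
    x_back = decode_mu_law(labels, mu=2**9)  # from_labels=True rescales via label_2_float first
    print(np.max(np.abs(x - x_back)))        # small quantisation error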
mockingbirdforuse/vocoder/wavernn/hparams.py
ADDED
@@ -0,0 +1,53 @@
from dataclasses import dataclass
from ...synthesizer.hparams import hparams as _syn_hp


@dataclass
class HParams:
    # Audio settings ------------------------------------------------------------------------
    # Match the values of the synthesizer
    sample_rate = _syn_hp.sample_rate
    n_fft = _syn_hp.n_fft
    num_mels = _syn_hp.num_mels
    hop_length = _syn_hp.hop_size
    win_length = _syn_hp.win_size
    fmin = _syn_hp.fmin
    min_level_db = _syn_hp.min_level_db
    ref_level_db = _syn_hp.ref_level_db
    mel_max_abs_value = _syn_hp.max_abs_value
    preemphasis = _syn_hp.preemphasis
    apply_preemphasis = _syn_hp.preemphasize

    bits = 9  # bit depth of signal
    mu_law = True  # Recommended to suppress noise if using raw bits in hp.voc_mode below

    # WAVERNN / VOCODER ----------------------------------------------------------------------
    voc_mode = "RAW"  # either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics)
    voc_upsample_factors = (5, 5, 8)  # NB - this needs to correctly factorise hop_length
    voc_rnn_dims = 512
    voc_fc_dims = 512
    voc_compute_dims = 128
    voc_res_out_dims = 128
    voc_res_blocks = 10

    # Training
    voc_batch_size = 100
    voc_lr = 1e-4
    voc_gen_at_checkpoint = 5  # number of samples to generate at each checkpoint
    voc_pad = 2  # this will pad the input so that the resnet can 'see' wider than input length
    voc_seq_len = hop_length * 5  # must be a multiple of hop_length

    # Generating / Synthesizing
    voc_gen_batched = True  # very fast (realtime+) single utterance batched generation
    voc_target = 8000  # target number of samples to be generated in each batch entry
    voc_overlap = 400  # number of samples for crossfading between batches


hparams = HParams()
mockingbirdforuse/vocoder/wavernn/inference.py
ADDED
@@ -0,0 +1,56 @@
import torch
from pathlib import Path

from .hparams import hparams as hp
from .models.fatchord_version import WaveRNN
from ...log import logger


class WaveRNNVocoder:
    def __init__(self, model_path: Path):
        logger.debug("Building Wave-RNN")
        self._model = WaveRNN(
            rnn_dims=hp.voc_rnn_dims,
            fc_dims=hp.voc_fc_dims,
            bits=hp.bits,
            pad=hp.voc_pad,
            upsample_factors=hp.voc_upsample_factors,
            feat_dims=hp.num_mels,
            compute_dims=hp.voc_compute_dims,
            res_out_dims=hp.voc_res_out_dims,
            res_blocks=hp.voc_res_blocks,
            hop_length=hp.hop_length,
            sample_rate=hp.sample_rate,
            mode=hp.voc_mode,
        )

        if torch.cuda.is_available():
            self._model = self._model.cuda()
            self._device = torch.device("cuda")
        else:
            self._device = torch.device("cpu")

        logger.debug("Loading model weights at %s" % model_path)
        checkpoint = torch.load(model_path, self._device)
        self._model.load_state_dict(checkpoint["model_state"])
        self._model.eval()

    def infer_waveform(
        self, mel, normalize=True, batched=True, target=8000, overlap=800
    ):
        """
        Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
        that of the synthesizer!)

        :param normalize:
        :param batched:
        :param target:
        :param overlap:
        :return:
        """

        if normalize:
            mel = mel / hp.mel_max_abs_value
        mel = torch.from_numpy(mel[None, ...])
        wav = self._model.generate(mel, batched, target, overlap, hp.mu_law)
        return wav, hp.sample_rate
|
mockingbirdforuse/vocoder/wavernn/models/deepmind_version.py
ADDED
@@ -0,0 +1,180 @@
import time
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from ..audio import combine_signal
from ....log import logger


class WaveRNN(nn.Module):
    def __init__(self, hidden_size=896, quantisation=256):
        super(WaveRNN, self).__init__()

        self.hidden_size = hidden_size
        self.split_size = hidden_size // 2

        # The main matmul
        self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)

        # Output fc layers
        self.O1 = nn.Linear(self.split_size, self.split_size)
        self.O2 = nn.Linear(self.split_size, quantisation)
        self.O3 = nn.Linear(self.split_size, self.split_size)
        self.O4 = nn.Linear(self.split_size, quantisation)

        # Input fc layers
        self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
        self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)

        # biases for the gates
        self.bias_u = Parameter(torch.zeros(self.hidden_size))
        self.bias_r = Parameter(torch.zeros(self.hidden_size))
        self.bias_e = Parameter(torch.zeros(self.hidden_size))

        # display num params
        self.num_params()

    def forward(self, prev_y, prev_hidden, current_coarse):
        # Main matmul - the projection is split 3 ways
        R_hidden = self.R(prev_hidden)
        R_u, R_r, R_e = torch.split(R_hidden, self.hidden_size, dim=1)

        # Project the prev input
        coarse_input_proj = self.I_coarse(prev_y)
        I_coarse_u, I_coarse_r, I_coarse_e = torch.split(
            coarse_input_proj, self.split_size, dim=1
        )

        # Project the prev input and current coarse sample
        fine_input = torch.cat([prev_y, current_coarse], dim=1)
        fine_input_proj = self.I_fine(fine_input)
        I_fine_u, I_fine_r, I_fine_e = torch.split(
            fine_input_proj, self.split_size, dim=1
        )

        # concatenate for the gates
        I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
        I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
        I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)

        # Compute all gates for coarse and fine
        u = F.sigmoid(R_u + I_u + self.bias_u)
        r = F.sigmoid(R_r + I_r + self.bias_r)
        e = torch.tanh(r * R_e + I_e + self.bias_e)
        hidden = u * prev_hidden + (1.0 - u) * e

        # Split the hidden state
        hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)

        # Compute outputs
        out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
        out_fine = self.O4(F.relu(self.O3(hidden_fine)))

        return out_coarse, out_fine, hidden

    def generate(self, seq_len):
        with torch.no_grad():
            # First split up the biases for the gates
            b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
            b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
            b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)

            # Lists for the two output seqs
            c_outputs, f_outputs = [], []

            # Some initial inputs
            out_coarse = torch.LongTensor([0]).cuda()
            out_fine = torch.LongTensor([0]).cuda()

            # We'll need a hidden state
            hidden = self.init_hidden()

            # Need a clock for display
            start = time.time()

            # Loop for generation
            for i in range(seq_len):
                # Split into two hidden states
                hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)

                # Scale and concat previous predictions
                out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.0
                out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.0
                prev_outputs = torch.cat([out_coarse, out_fine], dim=1)

                # Project input
                coarse_input_proj = self.I_coarse(prev_outputs)
                I_coarse_u, I_coarse_r, I_coarse_e = torch.split(
                    coarse_input_proj, self.split_size, dim=1
                )

                # Project hidden state and split 6 ways
                R_hidden = self.R(hidden)
                (
                    R_coarse_u,
                    R_fine_u,
                    R_coarse_r,
                    R_fine_r,
                    R_coarse_e,
                    R_fine_e,
                ) = torch.split(R_hidden, self.split_size, dim=1)

                # Compute the coarse gates
                u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
                r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
                e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
                hidden_coarse = u * hidden_coarse + (1.0 - u) * e

                # Compute the coarse output
                out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
                posterior = F.softmax(out_coarse, dim=1)
                distrib = torch.distributions.Categorical(posterior)
                out_coarse = distrib.sample()
                c_outputs.append(out_coarse)

                # Project the [prev outputs and predicted coarse sample]
                coarse_pred = out_coarse.float() / 127.5 - 1.0
                fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
                fine_input_proj = self.I_fine(fine_input)
                I_fine_u, I_fine_r, I_fine_e = torch.split(
                    fine_input_proj, self.split_size, dim=1
                )

                # Compute the fine gates
                u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
                r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
                e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e)
                hidden_fine = u * hidden_fine + (1.0 - u) * e

                # Compute the fine output
                out_fine = self.O4(F.relu(self.O3(hidden_fine)))
                posterior = F.softmax(out_fine, dim=1)
                distrib = torch.distributions.Categorical(posterior)
                out_fine = distrib.sample()
                f_outputs.append(out_fine)

                # Put the hidden state back together
                hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)

            coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy()
            fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy()
            output = combine_signal(coarse, fine)

        return output, coarse, fine

    def init_hidden(self, batch_size=1):
        return torch.zeros(batch_size, self.hidden_size).cuda()

    def num_params(self):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        logger.debug("Trainable Parameters: %.3f million" % parameters)
mockingbirdforuse/vocoder/wavernn/models/fatchord_version.py
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn.parameter import Parameter
|
7 |
+
|
8 |
+
from ..audio import de_emphasis, decode_mu_law
|
9 |
+
from ..hparams import hparams as hp
|
10 |
+
from ...distribution import sample_from_discretized_mix_logistic
|
11 |
+
from ....log import logger
|
12 |
+
|
13 |
+
|
14 |
+
class ResBlock(nn.Module):
|
15 |
+
def __init__(self, dims):
|
16 |
+
super().__init__()
|
17 |
+
self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
18 |
+
self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
|
19 |
+
self.batch_norm1 = nn.BatchNorm1d(dims)
|
20 |
+
self.batch_norm2 = nn.BatchNorm1d(dims)
|
21 |
+
|
22 |
+
def forward(self, x):
|
23 |
+
residual = x
|
24 |
+
x = self.conv1(x)
|
25 |
+
x = self.batch_norm1(x)
|
26 |
+
x = F.relu(x)
|
27 |
+
x = self.conv2(x)
|
28 |
+
x = self.batch_norm2(x)
|
29 |
+
return x + residual
|
30 |
+
|
31 |
+
|
32 |
+
class MelResNet(nn.Module):
|
33 |
+
def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
|
34 |
+
super().__init__()
|
35 |
+
k_size = pad * 2 + 1
|
36 |
+
self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
|
37 |
+
self.batch_norm = nn.BatchNorm1d(compute_dims)
|
38 |
+
self.layers = nn.ModuleList()
|
39 |
+
for i in range(res_blocks):
|
40 |
+
self.layers.append(ResBlock(compute_dims))
|
41 |
+
self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
x = self.conv_in(x)
|
45 |
+
x = self.batch_norm(x)
|
46 |
+
x = F.relu(x)
|
47 |
+
for f in self.layers:
|
48 |
+
x = f(x)
|
49 |
+
x = self.conv_out(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class Stretch2d(nn.Module):
|
54 |
+
def __init__(self, x_scale, y_scale):
|
55 |
+
super().__init__()
|
56 |
+
self.x_scale = x_scale
|
57 |
+
self.y_scale = y_scale
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
b, c, h, w = x.size()
|
61 |
+
x = x.unsqueeze(-1).unsqueeze(3)
|
62 |
+
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
|
63 |
+
return x.view(b, c, h * self.y_scale, w * self.x_scale)
|
64 |
+
|
65 |
+
|
66 |
+
class UpsampleNetwork(nn.Module):
|
67 |
+
def __init__(
|
68 |
+
self, feat_dims, upsample_scales, compute_dims, res_blocks, res_out_dims, pad
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
total_scale = np.cumproduct(upsample_scales)[-1]
|
72 |
+
self.indent = pad * total_scale
|
73 |
+
self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
|
74 |
+
self.resnet_stretch = Stretch2d(total_scale, 1)
|
75 |
+
self.up_layers = nn.ModuleList()
|
76 |
+
for scale in upsample_scales:
|
77 |
+
k_size = (1, scale * 2 + 1)
|
78 |
+
padding = (0, scale)
|
79 |
+
stretch = Stretch2d(scale, 1)
|
80 |
+
conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
|
81 |
+
conv.weight.data.fill_(1.0 / k_size[1])
|
82 |
+
self.up_layers.append(stretch)
|
83 |
+
self.up_layers.append(conv)
|
84 |
+
|
85 |
+
def forward(self, m):
|
86 |
+
aux = self.resnet(m).unsqueeze(1)
|
87 |
+
aux = self.resnet_stretch(aux)
|
88 |
+
aux = aux.squeeze(1)
|
89 |
+
m = m.unsqueeze(1)
|
90 |
+
for f in self.up_layers:
|
91 |
+
m = f(m)
|
92 |
+
m = m.squeeze(1)[:, :, self.indent : -self.indent]
|
93 |
+
return m.transpose(1, 2), aux.transpose(1, 2)
|
94 |
+
|
95 |
+
|
96 |
+
class WaveRNN(nn.Module):
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
rnn_dims,
|
100 |
+
fc_dims,
|
101 |
+
bits,
|
102 |
+
pad,
|
103 |
+
upsample_factors,
|
104 |
+
feat_dims,
|
105 |
+
compute_dims,
|
106 |
+
res_out_dims,
|
107 |
+
res_blocks,
|
108 |
+
hop_length,
|
109 |
+
sample_rate,
|
110 |
+
mode="RAW",
|
111 |
+
):
|
112 |
+
super().__init__()
|
113 |
+
self.mode = mode
|
114 |
+
self.pad = pad
|
115 |
+
if self.mode == "RAW":
|
116 |
+
self.n_classes = 2**bits
|
117 |
+
elif self.mode == "MOL":
|
118 |
+
self.n_classes = 30
|
119 |
+
else:
|
120 |
+
RuntimeError("Unknown model mode value - ", self.mode)
|
121 |
+
|
122 |
+
self.rnn_dims = rnn_dims
|
123 |
+
self.aux_dims = res_out_dims // 4
|
124 |
+
self.hop_length = hop_length
|
125 |
+
self.sample_rate = sample_rate
|
126 |
+
|
127 |
+
self.upsample = UpsampleNetwork(
|
128 |
+
feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad
|
129 |
+
)
|
130 |
+
self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
|
131 |
+
self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
|
132 |
+
self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
|
133 |
+
self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
|
134 |
+
self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
|
135 |
+
self.fc3 = nn.Linear(fc_dims, self.n_classes)
|
136 |
+
|
137 |
+
self.step = Parameter(torch.zeros(1).long(), requires_grad=False)
|
138 |
+
self.num_params()
|
139 |
+
|
140 |
+
def forward(self, x, mels):
|
141 |
+
self.step += 1
|
142 |
+
bsize = x.size(0)
|
143 |
+
if torch.cuda.is_available():
|
144 |
+
h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
|
145 |
+
h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
|
146 |
+
else:
|
147 |
+
h1 = torch.zeros(1, bsize, self.rnn_dims).cpu()
|
148 |
+
h2 = torch.zeros(1, bsize, self.rnn_dims).cpu()
|
149 |
+
mels, aux = self.upsample(mels)
|
150 |
+
|
151 |
+
aux_idx = [self.aux_dims * i for i in range(5)]
|
152 |
+
a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
|
153 |
+
a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
|
154 |
+
a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
|
155 |
+
a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
|
156 |
+
|
157 |
+
x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
|
158 |
+
x = self.I(x)
|
159 |
+
res = x
|
160 |
+
x, _ = self.rnn1(x, h1)
|
161 |
+
|
162 |
+
x = x + res
|
163 |
+
res = x
|
164 |
+
x = torch.cat([x, a2], dim=2)
|
165 |
+
x, _ = self.rnn2(x, h2)
|
166 |
+
|
167 |
+
x = x + res
|
168 |
+
x = torch.cat([x, a3], dim=2)
|
169 |
+
x = F.relu(self.fc1(x))
|
170 |
+
|
171 |
+
x = torch.cat([x, a4], dim=2)
|
172 |
+
x = F.relu(self.fc2(x))
|
173 |
+
return self.fc3(x)
|
174 |
+
|
175 |
+
def generate(self, mels, batched, target, overlap, mu_law):
|
176 |
+
mu_law = mu_law if self.mode == "RAW" else False
|
177 |
+
|
178 |
+
self.eval()
|
179 |
+
output = []
|
180 |
+
start = time.time()
|
181 |
+
rnn1 = self.get_gru_cell(self.rnn1)
|
182 |
+
rnn2 = self.get_gru_cell(self.rnn2)
|
183 |
+
|
184 |
+
with torch.no_grad():
|
185 |
+
if torch.cuda.is_available():
|
186 |
+
mels = mels.cuda()
|
187 |
+
else:
|
188 |
+
mels = mels.cpu()
|
189 |
+
wave_len = (mels.size(-1) - 1) * self.hop_length
|
190 |
+
mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both")
|
191 |
+
mels, aux = self.upsample(mels.transpose(1, 2))
|
192 |
+
|
193 |
+
if batched:
|
194 |
+
mels = self.fold_with_overlap(mels, target, overlap)
|
195 |
+
aux = self.fold_with_overlap(aux, target, overlap)
|
196 |
+
|
197 |
+
b_size, seq_len, _ = mels.size()
|
198 |
+
|
199 |
+
if torch.cuda.is_available():
|
200 |
+
h1 = torch.zeros(b_size, self.rnn_dims).cuda()
|
201 |
+
h2 = torch.zeros(b_size, self.rnn_dims).cuda()
|
202 |
+
x = torch.zeros(b_size, 1).cuda()
|
203 |
+
else:
|
204 |
+
h1 = torch.zeros(b_size, self.rnn_dims).cpu()
|
205 |
+
h2 = torch.zeros(b_size, self.rnn_dims).cpu()
|
206 |
+
x = torch.zeros(b_size, 1).cpu()
|
207 |
+
|
208 |
+
d = self.aux_dims
|
209 |
+
aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
|
210 |
+
|
211 |
+
for i in range(seq_len):
|
212 |
+
|
213 |
+
m_t = mels[:, i, :]
|
214 |
+
|
215 |
+
a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
|
216 |
+
|
217 |
+
x = torch.cat([x, m_t, a1_t], dim=1)
|
218 |
+
x = self.I(x)
|
219 |
+
h1 = rnn1(x, h1)
|
220 |
+
|
221 |
+
x = x + h1
|
222 |
+
inp = torch.cat([x, a2_t], dim=1)
|
223 |
+
h2 = rnn2(inp, h2)
|
224 |
+
|
225 |
+
x = x + h2
|
226 |
+
x = torch.cat([x, a3_t], dim=1)
|
227 |
+
x = F.relu(self.fc1(x))
|
228 |
+
|
229 |
+
x = torch.cat([x, a4_t], dim=1)
|
230 |
+
x = F.relu(self.fc2(x))
|
231 |
+
|
232 |
+
logits = self.fc3(x)
|
233 |
+
|
234 |
+
if self.mode == "MOL":
|
235 |
+
sample = sample_from_discretized_mix_logistic(
|
236 |
+
logits.unsqueeze(0).transpose(1, 2)
|
237 |
+
)
|
238 |
+
output.append(sample.view(-1))
|
239 |
+
if torch.cuda.is_available():
|
240 |
+
# x = torch.FloatTensor([[sample]]).cuda()
|
241 |
+
x = sample.transpose(0, 1).cuda()
|
242 |
+
else:
|
243 |
+
x = sample.transpose(0, 1)
|
244 |
+
|
245 |
+
elif self.mode == "RAW":
|
246 |
+
posterior = F.softmax(logits, dim=1)
|
247 |
+
distrib = torch.distributions.Categorical(posterior)
|
248 |
+
|
249 |
+
+                    sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0
+                    output.append(sample)
+                    x = sample.unsqueeze(-1)
+                else:
+                    raise RuntimeError("Unknown model mode value - ", self.mode)
+
+        output = torch.stack(output).transpose(0, 1)
+        output = output.cpu().numpy()
+        output = output.astype(np.float64)
+
+        if batched:
+            output = self.xfade_and_unfold(output, target, overlap)
+        else:
+            output = output[0]
+
+        if mu_law:
+            output = decode_mu_law(output, self.n_classes, False)
+        if hp.apply_preemphasis:
+            output = de_emphasis(output)
+
+        # Fade-out at the end to avoid signal cutting out suddenly
+        fade_out = np.linspace(1, 0, 20 * self.hop_length)
+        output = output[:wave_len]
+        output[-20 * self.hop_length :] *= fade_out
+
+        self.train()
+
+        return output
+
+    def get_gru_cell(self, gru):
+        gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
+        gru_cell.weight_hh.data = gru.weight_hh_l0.data
+        gru_cell.weight_ih.data = gru.weight_ih_l0.data
+        gru_cell.bias_hh.data = gru.bias_hh_l0.data
+        gru_cell.bias_ih.data = gru.bias_ih_l0.data
+        return gru_cell
+
+    def pad_tensor(self, x, pad, side="both"):
+        # NB - this is just a quick method i need right now
+        # i.e., it won't generalise to other shapes/dims
+        b, t, c = x.size()
+        total = t + 2 * pad if side == "both" else t + pad
+        if torch.cuda.is_available():
+            padded = torch.zeros(b, total, c).cuda()
+        else:
+            padded = torch.zeros(b, total, c).cpu()
+        if side == "before" or side == "both":
+            padded[:, pad : pad + t, :] = x
+        elif side == "after":
+            padded[:, :t, :] = x
+        return padded
+
+    def fold_with_overlap(self, x, target, overlap):
+
+        """Fold the tensor with overlap for quick batched inference.
+        Overlap will be used for crossfading in xfade_and_unfold()
+
+        Args:
+            x (tensor)    : Upsampled conditioning features.
+                            shape=(1, timesteps, features)
+            target (int)  : Target timesteps for each index of batch
+            overlap (int) : Timesteps for both xfade and rnn warmup
+
+        Return:
+            (tensor) : shape=(num_folds, target + 2 * overlap, features)
+
+        Details:
+            x = [[h1, h2, ... hn]]
+
+            Where each h is a vector of conditioning features
+
+            Eg: target=2, overlap=1 with x.size(1)=10
+
+            folded = [[h1, h2, h3, h4],
+                      [h4, h5, h6, h7],
+                      [h7, h8, h9, h10]]
+        """
+
+        _, total_len, features = x.size()
+
+        # Calculate variables needed
+        num_folds = (total_len - overlap) // (target + overlap)
+        extended_len = num_folds * (overlap + target) + overlap
+        remaining = total_len - extended_len
+
+        # Pad if some time steps poking out
+        if remaining != 0:
+            num_folds += 1
+            padding = target + 2 * overlap - remaining
+            x = self.pad_tensor(x, padding, side="after")
+
+        if torch.cuda.is_available():
+            folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
+        else:
+            folded = torch.zeros(num_folds, target + 2 * overlap, features).cpu()
+
+        # Get the values for the folded tensor
+        for i in range(num_folds):
+            start = i * (target + overlap)
+            end = start + target + 2 * overlap
+            folded[i] = x[:, start:end, :]
+
+        return folded
+
+    def xfade_and_unfold(self, y, target, overlap):
+
+        """Applies a crossfade and unfolds into a 1d array.
+
+        Args:
+            y (ndarray)   : Batched sequences of audio samples
+                            shape=(num_folds, target + 2 * overlap)
+                            dtype=np.float64
+            overlap (int) : Timesteps for both xfade and rnn warmup
+
+        Return:
+            (ndarray) : audio samples in a 1d array
+                        shape=(total_len)
+                        dtype=np.float64
+
+        Details:
+            y = [[seq1],
+                 [seq2],
+                 [seq3]]
+
+            Apply a gain envelope at both ends of the sequences
+
+            y = [[seq1_in, seq1_target, seq1_out],
+                 [seq2_in, seq2_target, seq2_out],
+                 [seq3_in, seq3_target, seq3_out]]
+
+            Stagger and add up the groups of samples:
+
+            [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
+
+        """
+
+        num_folds, length = y.shape
+        target = length - 2 * overlap
+        total_len = num_folds * (target + overlap) + overlap
+
+        # Need some silence for the rnn warmup
+        silence_len = overlap // 2
+        fade_len = overlap - silence_len
+        silence = np.zeros((silence_len), dtype=np.float64)
+
+        # Equal power crossfade
+        t = np.linspace(-1, 1, fade_len, dtype=np.float64)
+        fade_in = np.sqrt(0.5 * (1 + t))
+        fade_out = np.sqrt(0.5 * (1 - t))
+
+        # Concat the silence to the fades
+        fade_in = np.concatenate([silence, fade_in])
+        fade_out = np.concatenate([fade_out, silence])
+
+        # Apply the gain to the overlap samples
+        y[:, :overlap] *= fade_in
+        y[:, -overlap:] *= fade_out
+
+        unfolded = np.zeros((total_len), dtype=np.float64)
+
+        # Loop to add up all the samples
+        for i in range(num_folds):
+            start = i * (target + overlap)
+            end = start + target + 2 * overlap
+            unfolded[start:end] += y[i]
+
+        return unfolded
+
+    def get_step(self):
+        return self.step.data.item()
+
+    def checkpoint(self, model_dir, optimizer):
+        k_steps = self.get_step() // 1000
+        self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer)
+
+    def load(self, path, optimizer):
+        checkpoint = torch.load(path)
+        if "optimizer_state" in checkpoint:
+            self.load_state_dict(checkpoint["model_state"])
+            optimizer.load_state_dict(checkpoint["optimizer_state"])
+        else:
+            # Backwards compatibility
+            self.load_state_dict(checkpoint)
+
+    def save(self, path, optimizer):
+        torch.save(
+            {
+                "model_state": self.state_dict(),
+                "optimizer_state": optimizer.state_dict(),
+            },
+            path,
+        )
+
+    def num_params(self):
+        parameters = filter(lambda p: p.requires_grad, self.parameters())
+        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+        logger.debug("Trainable Parameters: %.3fM" % parameters)
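The batched-inference scheme added above (fold the upsampled conditioning features into overlapping segments, generate each segment in parallel, then crossfade the overlaps back together) can be tried in isolation. The snippet below is a minimal NumPy-only sketch of that fold/unfold round trip; the standalone function names and the toy input are illustrative and are not part of this commit.

import numpy as np

def fold_with_overlap(x, target, overlap):
    # Split a (timesteps, features) array into overlapping segments,
    # zero-padding the tail so the last segment is full length.
    total_len, features = x.shape
    num_folds = (total_len - overlap) // (target + overlap)
    remaining = total_len - (num_folds * (overlap + target) + overlap)
    if remaining != 0:
        num_folds += 1
        x = np.pad(x, ((0, target + 2 * overlap - remaining), (0, 0)))
    folded = np.zeros((num_folds, target + 2 * overlap, features))
    for i in range(num_folds):
        start = i * (target + overlap)
        folded[i] = x[start : start + target + 2 * overlap]
    return folded

def xfade_and_unfold(y, overlap):
    # Reassemble overlapping segments with an equal-power crossfade.
    num_folds, length = y.shape
    target = length - 2 * overlap
    silence_len = overlap // 2
    t = np.linspace(-1, 1, overlap - silence_len)
    fade_in = np.concatenate([np.zeros(silence_len), np.sqrt(0.5 * (1 + t))])
    fade_out = np.concatenate([np.sqrt(0.5 * (1 - t)), np.zeros(silence_len)])
    y = y.copy()
    y[:, :overlap] *= fade_in
    y[:, -overlap:] *= fade_out
    unfolded = np.zeros(num_folds * (target + overlap) + overlap)
    for i in range(num_folds):
        start = i * (target + overlap)
        unfolded[start : start + length] += y[i]
    return unfolded

# The docstring example: 10 timesteps, target=2, overlap=1 -> 3 folds of length 4.
x = np.arange(10, dtype=np.float64).reshape(10, 1)
folded = fold_with_overlap(x, target=2, overlap=1)
print(folded[..., 0])                       # rows [0..3], [3..6], [6..9]
print(xfade_and_unfold(folded[..., 0], 1))  # back to a single 10-sample array

The fade_in and fade_out curves are chosen so that their squared sum is constant across the overlap, keeping the summed region at roughly constant energy, while the leading half-overlap of silence leaves the RNN a short warm-up stretch in each fold.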
packages.txt
ADDED
@@ -0,0 +1,3 @@
+ffmpeg
+libsm6
+libxext6
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+torch
+torchvision
+numpy
+numba
+opencv-python-headless
+scipy
+pypinyin
+librosa
+webrtcvad
+Unidecode
+inflect
+loguru
+gradio