Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +8 -0
- .gitignore +200 -0
- Colab-Inference.ipynb +184 -0
- Dockerfile +42 -0
- GPT_SoVITS/AR/__init__.py +0 -0
- GPT_SoVITS/AR/data/__init__.py +0 -0
- GPT_SoVITS/AR/data/bucket_sampler.py +149 -0
- GPT_SoVITS/AR/data/data_module.py +81 -0
- GPT_SoVITS/AR/data/dataset.py +320 -0
- GPT_SoVITS/AR/models/__init__.py +0 -0
- GPT_SoVITS/AR/models/t2s_lightning_module.py +145 -0
- GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py +110 -0
- GPT_SoVITS/AR/models/t2s_model.py +935 -0
- GPT_SoVITS/AR/models/t2s_model_onnx.py +394 -0
- GPT_SoVITS/AR/models/utils.py +282 -0
- GPT_SoVITS/AR/modules/__init__.py +0 -0
- GPT_SoVITS/AR/modules/activation.py +413 -0
- GPT_SoVITS/AR/modules/activation_onnx.py +188 -0
- GPT_SoVITS/AR/modules/embedding.py +78 -0
- GPT_SoVITS/AR/modules/embedding_onnx.py +63 -0
- GPT_SoVITS/AR/modules/lr_schedulers.py +85 -0
- GPT_SoVITS/AR/modules/optim.py +593 -0
- GPT_SoVITS/AR/modules/patched_mha_with_cache.py +428 -0
- GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py +85 -0
- GPT_SoVITS/AR/modules/scaling.py +320 -0
- GPT_SoVITS/AR/modules/transformer.py +362 -0
- GPT_SoVITS/AR/modules/transformer_onnx.py +281 -0
- GPT_SoVITS/AR/text_processing/__init__.py +0 -0
- GPT_SoVITS/AR/text_processing/phonemizer.py +72 -0
- GPT_SoVITS/AR/text_processing/symbols.py +12 -0
- GPT_SoVITS/AR/utils/__init__.py +36 -0
- GPT_SoVITS/AR/utils/initialize.py +39 -0
- GPT_SoVITS/AR/utils/io.py +30 -0
- GPT_SoVITS/download.py +13 -0
- GPT_SoVITS/export_torch_script.py +861 -0
- GPT_SoVITS/export_torch_script_v3.py +1035 -0
- GPT_SoVITS/inference_cli.py +86 -0
- GPT_SoVITS/inference_gui.py +316 -0
- GPT_SoVITS/inference_webui.py +1281 -0
- GPT_SoVITS/inference_webui_fast.py +546 -0
- GPT_SoVITS/onnx_export.py +398 -0
- GPT_SoVITS/prepare_datasets/1-get-text.py +143 -0
- GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py +134 -0
- GPT_SoVITS/prepare_datasets/3-get-semantic.py +118 -0
- GPT_SoVITS/pretrained_models/.gitignore +2 -0
- GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +3 -0
- GPT_SoVITS/pretrained_models/s1v3.ckpt +3 -0
- GPT_SoVITS/process_ckpt.py +124 -0
- GPT_SoVITS/s1_train.py +171 -0
- GPT_SoVITS/s2_train.py +680 -0
.dockerignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
docs
|
2 |
+
logs
|
3 |
+
output
|
4 |
+
reference
|
5 |
+
SoVITS_weights
|
6 |
+
GPT_weights
|
7 |
+
TEMP
|
8 |
+
.git
|
.gitignore
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
.vscode
|
3 |
+
__pycache__
|
4 |
+
*.pyc
|
5 |
+
env
|
6 |
+
runtime
|
7 |
+
.idea
|
8 |
+
output
|
9 |
+
logs
|
10 |
+
reference
|
11 |
+
GPT_weights
|
12 |
+
SoVITS_weights
|
13 |
+
GPT_weights_v2
|
14 |
+
SoVITS_weights_v2
|
15 |
+
GPT_weights_v3
|
16 |
+
SoVITS_weights_v3
|
17 |
+
TEMP
|
18 |
+
weight.json
|
19 |
+
ffmpeg*
|
20 |
+
ffprobe*
|
21 |
+
cfg.json
|
22 |
+
speakers.json
|
23 |
+
ref_audios
|
24 |
+
tools/AP_BWE_main/24kto48k/*
|
25 |
+
!tools/AP_BWE_main/24kto48k/readme.txt
|
26 |
+
|
27 |
+
# Byte-compiled / optimized / DLL files
|
28 |
+
__pycache__/
|
29 |
+
*.py[cod]
|
30 |
+
*$py.class
|
31 |
+
|
32 |
+
# C extensions
|
33 |
+
*.so
|
34 |
+
|
35 |
+
# Distribution / packaging
|
36 |
+
.Python
|
37 |
+
build/
|
38 |
+
develop-eggs/
|
39 |
+
dist/
|
40 |
+
downloads/
|
41 |
+
eggs/
|
42 |
+
.eggs/
|
43 |
+
lib/
|
44 |
+
lib64/
|
45 |
+
parts/
|
46 |
+
sdist/
|
47 |
+
var/
|
48 |
+
wheels/
|
49 |
+
share/python-wheels/
|
50 |
+
*.egg-info/
|
51 |
+
.installed.cfg
|
52 |
+
*.egg
|
53 |
+
MANIFEST
|
54 |
+
|
55 |
+
# PyInstaller
|
56 |
+
# Usually these files are written by a python script from a template
|
57 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
58 |
+
*.manifest
|
59 |
+
*.spec
|
60 |
+
|
61 |
+
# Installer logs
|
62 |
+
pip-log.txt
|
63 |
+
pip-delete-this-directory.txt
|
64 |
+
|
65 |
+
# Unit test / coverage reports
|
66 |
+
htmlcov/
|
67 |
+
.tox/
|
68 |
+
.nox/
|
69 |
+
.coverage
|
70 |
+
.coverage.*
|
71 |
+
.cache
|
72 |
+
nosetests.xml
|
73 |
+
coverage.xml
|
74 |
+
*.cover
|
75 |
+
*.py,cover
|
76 |
+
.hypothesis/
|
77 |
+
.pytest_cache/
|
78 |
+
cover/
|
79 |
+
|
80 |
+
# Translations
|
81 |
+
*.mo
|
82 |
+
*.pot
|
83 |
+
|
84 |
+
# Django stuff:
|
85 |
+
*.log
|
86 |
+
local_settings.py
|
87 |
+
db.sqlite3
|
88 |
+
db.sqlite3-journal
|
89 |
+
|
90 |
+
# Flask stuff:
|
91 |
+
instance/
|
92 |
+
.webassets-cache
|
93 |
+
|
94 |
+
# Scrapy stuff:
|
95 |
+
.scrapy
|
96 |
+
|
97 |
+
# Sphinx documentation
|
98 |
+
docs/_build/
|
99 |
+
|
100 |
+
# PyBuilder
|
101 |
+
.pybuilder/
|
102 |
+
target/
|
103 |
+
|
104 |
+
# Jupyter Notebook
|
105 |
+
.ipynb_checkpoints
|
106 |
+
|
107 |
+
# IPython
|
108 |
+
profile_default/
|
109 |
+
ipython_config.py
|
110 |
+
|
111 |
+
# pyenv
|
112 |
+
# For a library or package, you might want to ignore these files since the code is
|
113 |
+
# intended to run in multiple environments; otherwise, check them in:
|
114 |
+
# .python-version
|
115 |
+
|
116 |
+
# pipenv
|
117 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
118 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
119 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
120 |
+
# install all needed dependencies.
|
121 |
+
#Pipfile.lock
|
122 |
+
|
123 |
+
# UV
|
124 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
125 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
126 |
+
# commonly ignored for libraries.
|
127 |
+
#uv.lock
|
128 |
+
|
129 |
+
# poetry
|
130 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
131 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
132 |
+
# commonly ignored for libraries.
|
133 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
134 |
+
#poetry.lock
|
135 |
+
|
136 |
+
# pdm
|
137 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
138 |
+
#pdm.lock
|
139 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
140 |
+
# in version control.
|
141 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
142 |
+
.pdm.toml
|
143 |
+
.pdm-python
|
144 |
+
.pdm-build/
|
145 |
+
|
146 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
147 |
+
__pypackages__/
|
148 |
+
|
149 |
+
# Celery stuff
|
150 |
+
celerybeat-schedule
|
151 |
+
celerybeat.pid
|
152 |
+
|
153 |
+
# SageMath parsed files
|
154 |
+
*.sage.py
|
155 |
+
|
156 |
+
# Environments
|
157 |
+
.env
|
158 |
+
.venv
|
159 |
+
env/
|
160 |
+
venv/
|
161 |
+
ENV/
|
162 |
+
env.bak/
|
163 |
+
venv.bak/
|
164 |
+
|
165 |
+
# Spyder project settings
|
166 |
+
.spyderproject
|
167 |
+
.spyproject
|
168 |
+
|
169 |
+
# Rope project settings
|
170 |
+
.ropeproject
|
171 |
+
|
172 |
+
# mkdocs documentation
|
173 |
+
/site
|
174 |
+
|
175 |
+
# mypy
|
176 |
+
.mypy_cache/
|
177 |
+
.dmypy.json
|
178 |
+
dmypy.json
|
179 |
+
|
180 |
+
# Pyre type checker
|
181 |
+
.pyre/
|
182 |
+
|
183 |
+
# pytype static type analyzer
|
184 |
+
.pytype/
|
185 |
+
|
186 |
+
# Cython debug symbols
|
187 |
+
cython_debug/
|
188 |
+
|
189 |
+
# PyCharm
|
190 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
191 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
192 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
193 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
194 |
+
#.idea/
|
195 |
+
|
196 |
+
# Ruff stuff:
|
197 |
+
.ruff_cache/
|
198 |
+
|
199 |
+
# PyPI configuration file
|
200 |
+
.pypirc
|
Colab-Inference.ipynb
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# GPT-SoVITS Infer"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"## Env Setup (Run Once Only)\n",
|
15 |
+
"## 环境配置, 只需运行一次"
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"cell_type": "markdown",
|
20 |
+
"metadata": {},
|
21 |
+
"source": [
|
22 |
+
"### 1."
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"metadata": {
|
29 |
+
"id": "e9b7iFV3dm1f"
|
30 |
+
},
|
31 |
+
"outputs": [],
|
32 |
+
"source": [
|
33 |
+
"%%writefile /content/setup.sh\n",
|
34 |
+
"set -e\n",
|
35 |
+
"\n",
|
36 |
+
"cd /content\n",
|
37 |
+
"\n",
|
38 |
+
"git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
|
39 |
+
"\n",
|
40 |
+
"cd GPT-SoVITS\n",
|
41 |
+
"\n",
|
42 |
+
"mkdir GPT_weights\n",
|
43 |
+
"\n",
|
44 |
+
"mkdir SoVITS_weights\n",
|
45 |
+
"\n",
|
46 |
+
"if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
|
47 |
+
" :\n",
|
48 |
+
"else\n",
|
49 |
+
" conda create -n GPTSoVITS python=3.10 -y\n",
|
50 |
+
"fi\n",
|
51 |
+
"\n",
|
52 |
+
"source activate GPTSoVITS\n",
|
53 |
+
"\n",
|
54 |
+
"pip install ipykernel\n",
|
55 |
+
"\n",
|
56 |
+
"bash install.sh --source HF"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "markdown",
|
61 |
+
"metadata": {},
|
62 |
+
"source": [
|
63 |
+
"### 2."
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "code",
|
68 |
+
"execution_count": null,
|
69 |
+
"metadata": {
|
70 |
+
"cellView": "form",
|
71 |
+
"id": "0NgxXg5sjv7z"
|
72 |
+
},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"%pip install -q condacolab\n",
|
76 |
+
"import condacolab\n",
|
77 |
+
"condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
|
78 |
+
"!cd /content && bash setup.sh"
|
79 |
+
]
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"metadata": {},
|
84 |
+
"source": [
|
85 |
+
"# Download Model"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "markdown",
|
90 |
+
"metadata": {},
|
91 |
+
"source": [
|
92 |
+
"### Download From HuggingFace"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "code",
|
97 |
+
"execution_count": null,
|
98 |
+
"metadata": {
|
99 |
+
"cellView": "form",
|
100 |
+
"id": "vbZY-LnM0tzq"
|
101 |
+
},
|
102 |
+
"outputs": [],
|
103 |
+
"source": [
|
104 |
+
"# Modify These\n",
|
105 |
+
"USER_ID = \"AkitoP\"\n",
|
106 |
+
"REPO_NAME = \"GPT-SoVITS-v2-aegi\"\n",
|
107 |
+
"BRANCH = \"main\"\n",
|
108 |
+
"GPT_PATH = \"new_aegigoe-e100.ckpt\"\n",
|
109 |
+
"SOVITS_PATH = \"new_aegigoe_e60_s32220.pth\"\n",
|
110 |
+
"\n",
|
111 |
+
"# Do Not Modify\n",
|
112 |
+
"HF_BASE = \"https://huggingface.co\"\n",
|
113 |
+
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
|
114 |
+
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{GPT_PATH}\"\n",
|
115 |
+
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{SOVITS_PATH}\"\n",
|
116 |
+
"\n",
|
117 |
+
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
|
118 |
+
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\"\n"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "markdown",
|
123 |
+
"metadata": {},
|
124 |
+
"source": [
|
125 |
+
"### Download From ModelScope"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": null,
|
131 |
+
"metadata": {},
|
132 |
+
"outputs": [],
|
133 |
+
"source": [
|
134 |
+
"# Modify These\n",
|
135 |
+
"USER_ID = \"aihobbyist\"\n",
|
136 |
+
"REPO_NAME = \"GPT-SoVits-V2-models\"\n",
|
137 |
+
"BRANCH = \"master\"\n",
|
138 |
+
"GPT_PATH = \"Genshin_Impact/EN/GPT_GenshinImpact_EN_5.1.ckpt\"\n",
|
139 |
+
"SOVITS_PATH = \"Wuthering_Waves/CN/SV_WutheringWaves_CN_1.3.pth\"\n",
|
140 |
+
"\n",
|
141 |
+
"# Do Not Modify\n",
|
142 |
+
"HF_BASE = \"https://www.modelscope.cn/models\"\n",
|
143 |
+
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
|
144 |
+
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{GPT_PATH}\"\n",
|
145 |
+
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{SOVITS_PATH}\"\n",
|
146 |
+
"\n",
|
147 |
+
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
|
148 |
+
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\""
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "markdown",
|
153 |
+
"metadata": {},
|
154 |
+
"source": [
|
155 |
+
"# Launch WebUI\n",
|
156 |
+
"# 启动 WebUI"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "code",
|
161 |
+
"execution_count": null,
|
162 |
+
"metadata": {
|
163 |
+
"cellView": "form",
|
164 |
+
"id": "4oRGUzkrk8C7"
|
165 |
+
},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
|
169 |
+
]
|
170 |
+
}
|
171 |
+
],
|
172 |
+
"metadata": {
|
173 |
+
"accelerator": "GPU",
|
174 |
+
"colab": {
|
175 |
+
"provenance": []
|
176 |
+
},
|
177 |
+
"kernelspec": {
|
178 |
+
"display_name": "Python 3",
|
179 |
+
"name": "python3"
|
180 |
+
}
|
181 |
+
},
|
182 |
+
"nbformat": 4,
|
183 |
+
"nbformat_minor": 0
|
184 |
+
}
|
Dockerfile
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Base CUDA image
|
2 |
+
FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
|
3 |
+
|
4 |
+
LABEL maintainer="[email protected]"
|
5 |
+
LABEL version="dev-20240209"
|
6 |
+
LABEL description="Docker image for GPT-SoVITS"
|
7 |
+
|
8 |
+
|
9 |
+
# Install 3rd party apps
|
10 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
11 |
+
ENV TZ=Etc/UTC
|
12 |
+
RUN apt-get update && \
|
13 |
+
apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
|
14 |
+
git lfs install && \
|
15 |
+
rm -rf /var/lib/apt/lists/*
|
16 |
+
|
17 |
+
# Copy only requirements.txt initially to leverage Docker cache
|
18 |
+
WORKDIR /workspace
|
19 |
+
COPY requirements.txt /workspace/
|
20 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
21 |
+
|
22 |
+
# Define a build-time argument for image type
|
23 |
+
ARG IMAGE_TYPE=full
|
24 |
+
|
25 |
+
# Conditional logic based on the IMAGE_TYPE argument
|
26 |
+
# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
|
27 |
+
COPY ./Docker /workspace/Docker
|
28 |
+
# elite 类型的镜像里面不包含额外的模型
|
29 |
+
RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
|
30 |
+
chmod +x /workspace/Docker/download.sh && \
|
31 |
+
/workspace/Docker/download.sh && \
|
32 |
+
python /workspace/Docker/download.py && \
|
33 |
+
python -m nltk.downloader averaged_perceptron_tagger cmudict; \
|
34 |
+
fi
|
35 |
+
|
36 |
+
|
37 |
+
# Copy the rest of the application
|
38 |
+
COPY . /workspace
|
39 |
+
|
40 |
+
EXPOSE 9871 9872 9873 9874 9880
|
41 |
+
|
42 |
+
CMD ["python", "webui.py"]
|
GPT_SoVITS/AR/__init__.py
ADDED
File without changes
|
GPT_SoVITS/AR/data/__init__.py
ADDED
File without changes
|
GPT_SoVITS/AR/data/bucket_sampler.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import itertools
|
4 |
+
import math
|
5 |
+
import random
|
6 |
+
from random import shuffle
|
7 |
+
from typing import Iterator, Optional, TypeVar
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
from torch.utils.data import Dataset, Sampler
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"DistributedBucketSampler",
|
15 |
+
]
|
16 |
+
|
17 |
+
T_co = TypeVar("T_co", covariant=True)
|
18 |
+
|
19 |
+
|
20 |
+
class DistributedBucketSampler(Sampler[T_co]):
|
21 |
+
r"""
|
22 |
+
sort the dataset wrt. input length
|
23 |
+
divide samples into buckets
|
24 |
+
sort within buckets
|
25 |
+
divide buckets into batches
|
26 |
+
sort batches
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
dataset: Dataset,
|
32 |
+
num_replicas: Optional[int] = None,
|
33 |
+
rank: Optional[int] = None,
|
34 |
+
shuffle: bool = True,
|
35 |
+
seed: int = 0,
|
36 |
+
drop_last: bool = False,
|
37 |
+
batch_size: int = 32,
|
38 |
+
) -> None:
|
39 |
+
if num_replicas is None:
|
40 |
+
if not dist.is_available():
|
41 |
+
raise RuntimeError("Requires distributed package to be available")
|
42 |
+
num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
|
43 |
+
if rank is None:
|
44 |
+
if not dist.is_available():
|
45 |
+
raise RuntimeError("Requires distributed package to be available")
|
46 |
+
rank = dist.get_rank() if torch.cuda.is_available() else 0
|
47 |
+
if torch.cuda.is_available():
|
48 |
+
torch.cuda.set_device(rank)
|
49 |
+
if rank >= num_replicas or rank < 0:
|
50 |
+
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
|
51 |
+
self.dataset = dataset
|
52 |
+
self.num_replicas = num_replicas
|
53 |
+
self.rank = rank
|
54 |
+
self.epoch = 0
|
55 |
+
self.drop_last = drop_last
|
56 |
+
# If the dataset length is evenly divisible by # of replicas, then there
|
57 |
+
# is no need to drop any data, since the dataset will be split equally.
|
58 |
+
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
59 |
+
# Split to nearest available length that is evenly divisible.
|
60 |
+
# This is to ensure each rank receives the same amount of data when
|
61 |
+
# using this Sampler.
|
62 |
+
self.num_samples = math.ceil(
|
63 |
+
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
self.num_samples = math.ceil(
|
67 |
+
len(self.dataset) / self.num_replicas,
|
68 |
+
) # type: ignore[arg-type]
|
69 |
+
self.total_size = self.num_samples * self.num_replicas
|
70 |
+
self.shuffle = shuffle
|
71 |
+
self.seed = seed
|
72 |
+
self.batch_size = batch_size
|
73 |
+
self.id_with_length = self._get_sample_lengths()
|
74 |
+
self.id_buckets = self.make_buckets(bucket_width=2.0)
|
75 |
+
|
76 |
+
def _get_sample_lengths(self):
|
77 |
+
id_with_lengths = []
|
78 |
+
for i in range(len(self.dataset)):
|
79 |
+
id_with_lengths.append((i, self.dataset.get_sample_length(i)))
|
80 |
+
id_with_lengths.sort(key=lambda x: x[1])
|
81 |
+
return id_with_lengths
|
82 |
+
|
83 |
+
def make_buckets(self, bucket_width: float = 2.0):
|
84 |
+
buckets = []
|
85 |
+
cur = []
|
86 |
+
max_sec = bucket_width
|
87 |
+
for id, sec in self.id_with_length:
|
88 |
+
if sec < max_sec:
|
89 |
+
cur.append(id)
|
90 |
+
else:
|
91 |
+
buckets.append(cur)
|
92 |
+
cur = [id]
|
93 |
+
max_sec += bucket_width
|
94 |
+
if len(cur) > 0:
|
95 |
+
buckets.append(cur)
|
96 |
+
return buckets
|
97 |
+
|
98 |
+
def __iter__(self) -> Iterator[T_co]:
|
99 |
+
if self.shuffle:
|
100 |
+
# deterministically shuffle based on epoch and seed
|
101 |
+
g = torch.Generator()
|
102 |
+
g.manual_seed(self.seed + self.epoch)
|
103 |
+
random.seed(self.epoch + self.seed)
|
104 |
+
shuffled_bucket = []
|
105 |
+
for buc in self.id_buckets:
|
106 |
+
buc_copy = buc.copy()
|
107 |
+
shuffle(buc_copy)
|
108 |
+
shuffled_bucket.append(buc_copy)
|
109 |
+
grouped_batch_size = self.batch_size * self.num_replicas
|
110 |
+
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
111 |
+
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
112 |
+
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
|
113 |
+
shuffle(batches)
|
114 |
+
indices = list(itertools.chain(*batches))
|
115 |
+
else:
|
116 |
+
# type: ignore[arg-type]
|
117 |
+
indices = list(range(len(self.dataset)))
|
118 |
+
|
119 |
+
if not self.drop_last:
|
120 |
+
# add extra samples to make it evenly divisible
|
121 |
+
padding_size = self.total_size - len(indices)
|
122 |
+
if padding_size <= len(indices):
|
123 |
+
indices += indices[:padding_size]
|
124 |
+
else:
|
125 |
+
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
126 |
+
else:
|
127 |
+
# remove tail of data to make it evenly divisible.
|
128 |
+
indices = indices[: self.total_size]
|
129 |
+
assert len(indices) == self.total_size
|
130 |
+
|
131 |
+
# subsample
|
132 |
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
133 |
+
assert len(indices) == self.num_samples
|
134 |
+
|
135 |
+
return iter(indices)
|
136 |
+
|
137 |
+
def __len__(self) -> int:
|
138 |
+
return self.num_samples
|
139 |
+
|
140 |
+
def set_epoch(self, epoch: int) -> None:
|
141 |
+
r"""
|
142 |
+
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
143 |
+
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
144 |
+
sampler will yield the same ordering.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
epoch (int): Epoch number.
|
148 |
+
"""
|
149 |
+
self.epoch = epoch
|
GPT_SoVITS/AR/data/data_module.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
from pytorch_lightning import LightningDataModule
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
from AR.data.bucket_sampler import DistributedBucketSampler
|
7 |
+
from AR.data.dataset import Text2SemanticDataset
|
8 |
+
|
9 |
+
|
10 |
+
class Text2SemanticDataModule(LightningDataModule):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
config,
|
14 |
+
train_semantic_path,
|
15 |
+
train_phoneme_path,
|
16 |
+
dev_semantic_path=None,
|
17 |
+
dev_phoneme_path=None,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.config = config
|
21 |
+
self.train_semantic_path = train_semantic_path
|
22 |
+
self.train_phoneme_path = train_phoneme_path
|
23 |
+
self.dev_semantic_path = dev_semantic_path
|
24 |
+
self.dev_phoneme_path = dev_phoneme_path
|
25 |
+
self.num_workers = self.config["data"]["num_workers"]
|
26 |
+
|
27 |
+
def prepare_data(self):
|
28 |
+
pass
|
29 |
+
|
30 |
+
def setup(self, stage=None, output_logs=False):
|
31 |
+
self._train_dataset = Text2SemanticDataset(
|
32 |
+
phoneme_path=self.train_phoneme_path,
|
33 |
+
semantic_path=self.train_semantic_path,
|
34 |
+
max_sec=self.config["data"]["max_sec"],
|
35 |
+
pad_val=self.config["data"]["pad_val"],
|
36 |
+
)
|
37 |
+
self._dev_dataset = self._train_dataset
|
38 |
+
# self._dev_dataset = Text2SemanticDataset(
|
39 |
+
# phoneme_path=self.dev_phoneme_path,
|
40 |
+
# semantic_path=self.dev_semantic_path,
|
41 |
+
# max_sample=self.config['data']['max_eval_sample'],
|
42 |
+
# max_sec=self.config['data']['max_sec'],
|
43 |
+
# pad_val=self.config['data']['pad_val'])
|
44 |
+
|
45 |
+
def train_dataloader(self):
|
46 |
+
batch_size = (
|
47 |
+
self.config["train"]["batch_size"] // 2
|
48 |
+
if self.config["train"].get("if_dpo", False) is True
|
49 |
+
else self.config["train"]["batch_size"]
|
50 |
+
)
|
51 |
+
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
|
52 |
+
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
53 |
+
return DataLoader(
|
54 |
+
self._train_dataset,
|
55 |
+
batch_size=batch_size,
|
56 |
+
sampler=sampler,
|
57 |
+
collate_fn=self._train_dataset.collate,
|
58 |
+
num_workers=self.num_workers,
|
59 |
+
persistent_workers=True,
|
60 |
+
prefetch_factor=16,
|
61 |
+
)
|
62 |
+
|
63 |
+
def val_dataloader(self):
|
64 |
+
return DataLoader(
|
65 |
+
self._dev_dataset,
|
66 |
+
batch_size=1,
|
67 |
+
shuffle=False,
|
68 |
+
collate_fn=self._train_dataset.collate,
|
69 |
+
num_workers=max(self.num_workers, 12),
|
70 |
+
persistent_workers=True,
|
71 |
+
prefetch_factor=16,
|
72 |
+
)
|
73 |
+
|
74 |
+
# 这个会使用到嘛?
|
75 |
+
def test_dataloader(self):
|
76 |
+
return DataLoader(
|
77 |
+
self._dev_dataset,
|
78 |
+
batch_size=1,
|
79 |
+
shuffle=False,
|
80 |
+
collate_fn=self._train_dataset.collate,
|
81 |
+
)
|
GPT_SoVITS/AR/data/dataset.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
|
4 |
+
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
5 |
+
import os
|
6 |
+
import traceback
|
7 |
+
from typing import Dict, List
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import torch
|
12 |
+
from torch.utils.data import DataLoader, Dataset
|
13 |
+
|
14 |
+
version = os.environ.get("version", None)
|
15 |
+
|
16 |
+
from text import cleaned_text_to_sequence
|
17 |
+
|
18 |
+
# from config import exp_dir
|
19 |
+
|
20 |
+
|
21 |
+
def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
|
22 |
+
seq = sequences[0]
|
23 |
+
ndim = seq.ndim
|
24 |
+
if axis < 0:
|
25 |
+
axis += ndim
|
26 |
+
dtype = seq.dtype
|
27 |
+
pad_value = dtype.type(pad_value)
|
28 |
+
seq_lengths = [seq.shape[axis] for seq in sequences]
|
29 |
+
max_length = np.max(seq_lengths)
|
30 |
+
|
31 |
+
padded_sequences = []
|
32 |
+
for seq, length in zip(sequences, seq_lengths):
|
33 |
+
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
34 |
+
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
|
35 |
+
padded_sequences.append(padded_seq)
|
36 |
+
batch = np.stack(padded_sequences)
|
37 |
+
return batch
|
38 |
+
|
39 |
+
|
40 |
+
class Text2SemanticDataset(Dataset):
|
41 |
+
"""dataset class for text tokens to semantic model training."""
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
phoneme_path: str,
|
46 |
+
semantic_path: str,
|
47 |
+
max_sample: int = None,
|
48 |
+
max_sec: int = 100,
|
49 |
+
pad_val: int = 1024,
|
50 |
+
# min value of phoneme/sec
|
51 |
+
min_ps_ratio: int = 3,
|
52 |
+
# max value of phoneme/sec
|
53 |
+
max_ps_ratio: int = 25,
|
54 |
+
) -> None:
|
55 |
+
super().__init__()
|
56 |
+
|
57 |
+
self.semantic_data = pd.read_csv(
|
58 |
+
semantic_path,
|
59 |
+
delimiter="\t",
|
60 |
+
encoding="utf-8",
|
61 |
+
)
|
62 |
+
# get dict
|
63 |
+
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
|
64 |
+
self.path3 = "%s/3-bert" % (
|
65 |
+
os.path.dirname(
|
66 |
+
phoneme_path,
|
67 |
+
)
|
68 |
+
) # "%s/3-bert"%exp_dir#bert_dir
|
69 |
+
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
|
70 |
+
assert os.path.exists(self.path2)
|
71 |
+
assert os.path.exists(self.path6)
|
72 |
+
self.phoneme_data = {}
|
73 |
+
with open(self.path2, "r", encoding="utf8") as f:
|
74 |
+
lines = f.read().strip("\n").split("\n")
|
75 |
+
|
76 |
+
for line in lines:
|
77 |
+
tmp = line.split("\t")
|
78 |
+
if len(tmp) != 4:
|
79 |
+
continue
|
80 |
+
self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
|
81 |
+
|
82 |
+
# self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
|
83 |
+
# pad for semantic tokens
|
84 |
+
self.PAD: int = pad_val
|
85 |
+
# self.hz = 25
|
86 |
+
# with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
|
87 |
+
# data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
|
88 |
+
# self.hz=int(data[:-2])#
|
89 |
+
self.hz = int(os.environ.get("hz", "25hz")[:-2])
|
90 |
+
|
91 |
+
# max seconds of semantic token
|
92 |
+
self.max_sec = max_sec
|
93 |
+
self.min_ps_ratio = min_ps_ratio
|
94 |
+
self.max_ps_ratio = max_ps_ratio
|
95 |
+
|
96 |
+
if max_sample is not None:
|
97 |
+
self.semantic_data = self.semantic_data[:max_sample]
|
98 |
+
|
99 |
+
# {idx: (semantic, phoneme)}
|
100 |
+
# semantic list, phoneme list
|
101 |
+
self.semantic_phoneme = []
|
102 |
+
self.item_names = []
|
103 |
+
|
104 |
+
self.inited = False
|
105 |
+
|
106 |
+
if not self.inited:
|
107 |
+
# 调用初始化函数
|
108 |
+
self.init_batch()
|
109 |
+
self.inited = True
|
110 |
+
del self.semantic_data
|
111 |
+
del self.phoneme_data
|
112 |
+
# self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
|
113 |
+
# self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
|
114 |
+
|
115 |
+
def init_batch(self):
|
116 |
+
semantic_data_len = len(self.semantic_data)
|
117 |
+
phoneme_data_len = len(self.phoneme_data.keys())
|
118 |
+
print("semantic_data_len:", semantic_data_len)
|
119 |
+
print("phoneme_data_len:", phoneme_data_len)
|
120 |
+
print(self.semantic_data)
|
121 |
+
idx = 0
|
122 |
+
num_not_in = 0
|
123 |
+
num_deleted_bigger = 0
|
124 |
+
num_deleted_ps = 0
|
125 |
+
for i in range(semantic_data_len):
|
126 |
+
# 先依次遍历
|
127 |
+
# get str
|
128 |
+
item_name = self.semantic_data.iloc[i, 0]
|
129 |
+
# print(self.phoneme_data)
|
130 |
+
try:
|
131 |
+
phoneme, word2ph, text = self.phoneme_data[item_name]
|
132 |
+
except Exception:
|
133 |
+
traceback.print_exc()
|
134 |
+
# print(f"{item_name} not in self.phoneme_data !")
|
135 |
+
num_not_in += 1
|
136 |
+
continue
|
137 |
+
|
138 |
+
semantic_str = self.semantic_data.iloc[i, 1]
|
139 |
+
# get token list
|
140 |
+
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
|
141 |
+
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
|
142 |
+
# 过滤掉太长的样��
|
143 |
+
if (
|
144 |
+
len(semantic_ids) > self.max_sec * self.hz
|
145 |
+
): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k
|
146 |
+
num_deleted_bigger += 1
|
147 |
+
continue
|
148 |
+
# (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
|
149 |
+
phoneme = phoneme.split(" ")
|
150 |
+
|
151 |
+
try:
|
152 |
+
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
153 |
+
except:
|
154 |
+
traceback.print_exc()
|
155 |
+
# print(f"{item_name} not in self.phoneme_data !")
|
156 |
+
num_not_in += 1
|
157 |
+
continue
|
158 |
+
# if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
|
159 |
+
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2:改为恒定限制为semantic/2.5就行
|
160 |
+
num_deleted_ps += 1
|
161 |
+
continue
|
162 |
+
# if len(semantic_ids) > 1000:###########3
|
163 |
+
# num_deleted_bigger += 1
|
164 |
+
# continue
|
165 |
+
|
166 |
+
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
|
167 |
+
|
168 |
+
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
|
169 |
+
num_deleted_ps += 1
|
170 |
+
# print(item_name)
|
171 |
+
continue
|
172 |
+
|
173 |
+
self.semantic_phoneme.append((semantic_ids, phoneme_ids))
|
174 |
+
idx += 1
|
175 |
+
self.item_names.append(item_name)
|
176 |
+
|
177 |
+
min_num = 100 # 20直接不补#30补了也不存ckpt
|
178 |
+
leng = len(self.semantic_phoneme)
|
179 |
+
if leng < min_num:
|
180 |
+
tmp1 = self.semantic_phoneme
|
181 |
+
tmp2 = self.item_names
|
182 |
+
self.semantic_phoneme = []
|
183 |
+
self.item_names = []
|
184 |
+
for _ in range(max(2, int(min_num / leng))):
|
185 |
+
self.semantic_phoneme += tmp1
|
186 |
+
self.item_names += tmp2
|
187 |
+
if num_not_in > 0:
|
188 |
+
print(f"there are {num_not_in} semantic datas not in phoneme datas")
|
189 |
+
if num_deleted_bigger > 0:
|
190 |
+
print(
|
191 |
+
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
|
192 |
+
)
|
193 |
+
if num_deleted_ps > 0:
|
194 |
+
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
|
195 |
+
print(
|
196 |
+
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
|
197 |
+
)
|
198 |
+
"""
|
199 |
+
there are 31 semantic datas not in phoneme datas
|
200 |
+
deleted 34 audios who's duration are bigger than 54 seconds
|
201 |
+
deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
|
202 |
+
dataset.__len__(): 366463
|
203 |
+
|
204 |
+
"""
|
205 |
+
# 345410 for LibriTTS
|
206 |
+
print("dataset.__len__():", self.__len__())
|
207 |
+
|
208 |
+
def __get_item_names__(self) -> List[str]:
|
209 |
+
return self.item_names
|
210 |
+
|
211 |
+
def __len__(self) -> int:
|
212 |
+
return len(self.semantic_phoneme)
|
213 |
+
|
214 |
+
def __getitem__(self, idx: int) -> Dict:
|
215 |
+
semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
|
216 |
+
item_name = self.item_names[idx]
|
217 |
+
phoneme_ids_len = len(phoneme_ids)
|
218 |
+
# semantic tokens target
|
219 |
+
semantic_ids_len = len(semantic_ids)
|
220 |
+
|
221 |
+
flag = 0
|
222 |
+
path_bert = "%s/%s.pt" % (self.path3, item_name)
|
223 |
+
if os.path.exists(path_bert) == True:
|
224 |
+
bert_feature = torch.load(path_bert, map_location="cpu")
|
225 |
+
else:
|
226 |
+
flag = 1
|
227 |
+
if flag == 1:
|
228 |
+
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
|
229 |
+
bert_feature = None
|
230 |
+
else:
|
231 |
+
assert bert_feature.shape[-1] == len(phoneme_ids)
|
232 |
+
return {
|
233 |
+
"idx": idx,
|
234 |
+
"phoneme_ids": phoneme_ids,
|
235 |
+
"phoneme_ids_len": phoneme_ids_len,
|
236 |
+
"semantic_ids": semantic_ids,
|
237 |
+
"semantic_ids_len": semantic_ids_len,
|
238 |
+
"bert_feature": bert_feature,
|
239 |
+
}
|
240 |
+
|
241 |
+
def get_sample_length(self, idx: int):
|
242 |
+
semantic_ids = self.semantic_phoneme[idx][0]
|
243 |
+
sec = 1.0 * len(semantic_ids) / self.hz
|
244 |
+
return sec
|
245 |
+
|
246 |
+
def collate(self, examples: List[Dict]) -> Dict:
|
247 |
+
sample_index: List[int] = []
|
248 |
+
phoneme_ids: List[torch.Tensor] = []
|
249 |
+
phoneme_ids_lens: List[int] = []
|
250 |
+
semantic_ids: List[torch.Tensor] = []
|
251 |
+
semantic_ids_lens: List[int] = []
|
252 |
+
# return
|
253 |
+
|
254 |
+
for item in examples:
|
255 |
+
sample_index.append(item["idx"])
|
256 |
+
phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
|
257 |
+
semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
|
258 |
+
phoneme_ids_lens.append(item["phoneme_ids_len"])
|
259 |
+
semantic_ids_lens.append(item["semantic_ids_len"])
|
260 |
+
|
261 |
+
# pad 0
|
262 |
+
phoneme_ids = batch_sequences(phoneme_ids)
|
263 |
+
semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
|
264 |
+
|
265 |
+
# # convert each batch to torch.tensor
|
266 |
+
phoneme_ids = torch.tensor(phoneme_ids)
|
267 |
+
semantic_ids = torch.tensor(semantic_ids)
|
268 |
+
phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
|
269 |
+
semantic_ids_lens = torch.tensor(semantic_ids_lens)
|
270 |
+
bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
|
271 |
+
bert_padded.zero_()
|
272 |
+
|
273 |
+
for idx, item in enumerate(examples):
|
274 |
+
bert = item["bert_feature"]
|
275 |
+
if bert != None:
|
276 |
+
bert_padded[idx, :, : bert.shape[-1]] = bert
|
277 |
+
|
278 |
+
return {
|
279 |
+
# List[int]
|
280 |
+
"ids": sample_index,
|
281 |
+
# torch.Tensor (B, max_phoneme_length)
|
282 |
+
"phoneme_ids": phoneme_ids,
|
283 |
+
# torch.Tensor (B)
|
284 |
+
"phoneme_ids_len": phoneme_ids_lens,
|
285 |
+
# torch.Tensor (B, max_semantic_ids_length)
|
286 |
+
"semantic_ids": semantic_ids,
|
287 |
+
# torch.Tensor (B)
|
288 |
+
"semantic_ids_len": semantic_ids_lens,
|
289 |
+
# torch.Tensor (B, 1024, max_phoneme_length)
|
290 |
+
"bert_feature": bert_padded,
|
291 |
+
}
|
292 |
+
|
293 |
+
|
294 |
+
if __name__ == "__main__":
|
295 |
+
root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
|
296 |
+
dataset = Text2SemanticDataset(
|
297 |
+
phoneme_path=root_dir + "phoneme_train.npy",
|
298 |
+
semantic_path=root_dir + "semantic_train.tsv",
|
299 |
+
)
|
300 |
+
|
301 |
+
batch_size = 12
|
302 |
+
dataloader = DataLoader(
|
303 |
+
dataset,
|
304 |
+
batch_size=batch_size,
|
305 |
+
collate_fn=dataset.collate,
|
306 |
+
shuffle=False,
|
307 |
+
)
|
308 |
+
for i, batch in enumerate(dataloader):
|
309 |
+
if i % 1000 == 0:
|
310 |
+
print(i)
|
311 |
+
# if i == 0:
|
312 |
+
# print('batch["ids"]:', batch["ids"])
|
313 |
+
# print('batch["phoneme_ids"]:', batch["phoneme_ids"],
|
314 |
+
# batch["phoneme_ids"].shape)
|
315 |
+
# print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
|
316 |
+
# batch["phoneme_ids_len"].shape)
|
317 |
+
# print('batch["semantic_ids"]:', batch["semantic_ids"],
|
318 |
+
# batch["semantic_ids"].shape)
|
319 |
+
# print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
|
320 |
+
# batch["semantic_ids_len"].shape)
|
GPT_SoVITS/AR/models/__init__.py
ADDED
File without changes
|
GPT_SoVITS/AR/models/t2s_lightning_module.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
now_dir = os.getcwd()
|
7 |
+
sys.path.append(now_dir)
|
8 |
+
from typing import Dict
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from pytorch_lightning import LightningModule
|
12 |
+
|
13 |
+
from AR.models.t2s_model import Text2SemanticDecoder
|
14 |
+
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
15 |
+
from AR.modules.optim import ScaledAdam
|
16 |
+
|
17 |
+
|
18 |
+
class Text2SemanticLightningModule(LightningModule):
|
19 |
+
def __init__(self, config, output_dir, is_train=True):
|
20 |
+
super().__init__()
|
21 |
+
self.config = config
|
22 |
+
self.top_k = 3
|
23 |
+
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
|
24 |
+
pretrained_s1 = config.get("pretrained_s1")
|
25 |
+
if pretrained_s1 and is_train:
|
26 |
+
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
27 |
+
print(
|
28 |
+
self.load_state_dict(
|
29 |
+
torch.load(
|
30 |
+
pretrained_s1,
|
31 |
+
map_location="cpu",
|
32 |
+
)["weight"],
|
33 |
+
)
|
34 |
+
)
|
35 |
+
if is_train:
|
36 |
+
self.automatic_optimization = False
|
37 |
+
self.save_hyperparameters()
|
38 |
+
self.eval_dir = output_dir / "eval"
|
39 |
+
self.eval_dir.mkdir(parents=True, exist_ok=True)
|
40 |
+
|
41 |
+
def training_step(self, batch: Dict, batch_idx: int):
|
42 |
+
opt = self.optimizers()
|
43 |
+
scheduler = self.lr_schedulers()
|
44 |
+
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
|
45 |
+
loss, acc = forward(
|
46 |
+
batch["phoneme_ids"],
|
47 |
+
batch["phoneme_ids_len"],
|
48 |
+
batch["semantic_ids"],
|
49 |
+
batch["semantic_ids_len"],
|
50 |
+
batch["bert_feature"],
|
51 |
+
)
|
52 |
+
self.manual_backward(loss)
|
53 |
+
if batch_idx > 0 and batch_idx % 4 == 0:
|
54 |
+
opt.step()
|
55 |
+
opt.zero_grad()
|
56 |
+
scheduler.step()
|
57 |
+
|
58 |
+
self.log(
|
59 |
+
"total_loss",
|
60 |
+
loss,
|
61 |
+
on_step=True,
|
62 |
+
on_epoch=True,
|
63 |
+
prog_bar=True,
|
64 |
+
sync_dist=True,
|
65 |
+
)
|
66 |
+
self.log(
|
67 |
+
"lr",
|
68 |
+
scheduler.get_last_lr()[0],
|
69 |
+
on_epoch=True,
|
70 |
+
prog_bar=True,
|
71 |
+
sync_dist=True,
|
72 |
+
)
|
73 |
+
self.log(
|
74 |
+
f"top_{self.top_k}_acc",
|
75 |
+
acc,
|
76 |
+
on_step=True,
|
77 |
+
on_epoch=True,
|
78 |
+
prog_bar=True,
|
79 |
+
sync_dist=True,
|
80 |
+
)
|
81 |
+
|
82 |
+
def validation_step(self, batch: Dict, batch_idx: int):
|
83 |
+
return
|
84 |
+
|
85 |
+
# # get loss
|
86 |
+
# loss, acc = self.model.forward(
|
87 |
+
# batch['phoneme_ids'], batch['phoneme_ids_len'],
|
88 |
+
# batch['semantic_ids'], batch['semantic_ids_len'],
|
89 |
+
# batch['bert_feature']
|
90 |
+
# )
|
91 |
+
#
|
92 |
+
# self.log(
|
93 |
+
# "val_total_loss",
|
94 |
+
# loss,
|
95 |
+
# on_step=True,
|
96 |
+
# on_epoch=True,
|
97 |
+
# prog_bar=True,
|
98 |
+
# sync_dist=True)
|
99 |
+
# self.log(
|
100 |
+
# f"val_top_{self.top_k}_acc",
|
101 |
+
# acc,
|
102 |
+
# on_step=True,
|
103 |
+
# on_epoch=True,
|
104 |
+
# prog_bar=True,
|
105 |
+
# sync_dist=True)
|
106 |
+
#
|
107 |
+
# # get infer output
|
108 |
+
# semantic_len = batch['semantic_ids'].size(1)
|
109 |
+
# prompt_len = min(int(semantic_len * 0.5), 150)
|
110 |
+
# prompt = batch['semantic_ids'][:, :prompt_len]
|
111 |
+
# pred_semantic = self.model.infer(batch['phoneme_ids'],
|
112 |
+
# batch['phoneme_ids_len'], prompt,
|
113 |
+
# batch['bert_feature']
|
114 |
+
# )
|
115 |
+
# save_name = f'semantic_toks_{batch_idx}.pt'
|
116 |
+
# save_path = os.path.join(self.eval_dir, save_name)
|
117 |
+
# torch.save(pred_semantic.detach().cpu(), save_path)
|
118 |
+
|
119 |
+
def configure_optimizers(self):
|
120 |
+
model_parameters = self.model.parameters()
|
121 |
+
parameters_names = []
|
122 |
+
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
123 |
+
lm_opt = ScaledAdam(
|
124 |
+
model_parameters,
|
125 |
+
lr=0.01,
|
126 |
+
betas=(0.9, 0.95),
|
127 |
+
clipping_scale=2.0,
|
128 |
+
parameters_names=parameters_names,
|
129 |
+
show_dominant_parameters=False,
|
130 |
+
clipping_update_period=1000,
|
131 |
+
)
|
132 |
+
|
133 |
+
return {
|
134 |
+
"optimizer": lm_opt,
|
135 |
+
"lr_scheduler": {
|
136 |
+
"scheduler": WarmupCosineLRSchedule(
|
137 |
+
lm_opt,
|
138 |
+
init_lr=self.config["optimizer"]["lr_init"],
|
139 |
+
peak_lr=self.config["optimizer"]["lr"],
|
140 |
+
end_lr=self.config["optimizer"]["lr_end"],
|
141 |
+
warmup_steps=self.config["optimizer"]["warmup_steps"],
|
142 |
+
total_steps=self.config["optimizer"]["decay_steps"],
|
143 |
+
)
|
144 |
+
},
|
145 |
+
}
|
GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
now_dir = os.getcwd()
|
7 |
+
sys.path.append(now_dir)
|
8 |
+
from typing import Dict
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from pytorch_lightning import LightningModule
|
12 |
+
|
13 |
+
from AR.models.t2s_model_onnx import Text2SemanticDecoder
|
14 |
+
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
15 |
+
from AR.modules.optim import ScaledAdam
|
16 |
+
|
17 |
+
|
18 |
+
class Text2SemanticLightningModule(LightningModule):
|
19 |
+
def __init__(self, config, output_dir, is_train=True):
|
20 |
+
super().__init__()
|
21 |
+
self.config = config
|
22 |
+
self.top_k = 3
|
23 |
+
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
|
24 |
+
pretrained_s1 = config.get("pretrained_s1")
|
25 |
+
if pretrained_s1 and is_train:
|
26 |
+
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
27 |
+
print(
|
28 |
+
self.load_state_dict(
|
29 |
+
torch.load(
|
30 |
+
pretrained_s1,
|
31 |
+
map_location="cpu",
|
32 |
+
)["weight"],
|
33 |
+
),
|
34 |
+
)
|
35 |
+
if is_train:
|
36 |
+
self.automatic_optimization = False
|
37 |
+
self.save_hyperparameters()
|
38 |
+
self.eval_dir = output_dir / "eval"
|
39 |
+
self.eval_dir.mkdir(parents=True, exist_ok=True)
|
40 |
+
|
41 |
+
def training_step(self, batch: Dict, batch_idx: int):
|
42 |
+
opt = self.optimizers()
|
43 |
+
scheduler = self.lr_schedulers()
|
44 |
+
loss, acc = self.model.forward(
|
45 |
+
batch["phoneme_ids"],
|
46 |
+
batch["phoneme_ids_len"],
|
47 |
+
batch["semantic_ids"],
|
48 |
+
batch["semantic_ids_len"],
|
49 |
+
batch["bert_feature"],
|
50 |
+
)
|
51 |
+
self.manual_backward(loss)
|
52 |
+
if batch_idx > 0 and batch_idx % 4 == 0:
|
53 |
+
opt.step()
|
54 |
+
opt.zero_grad()
|
55 |
+
scheduler.step()
|
56 |
+
|
57 |
+
self.log(
|
58 |
+
"total_loss",
|
59 |
+
loss,
|
60 |
+
on_step=True,
|
61 |
+
on_epoch=True,
|
62 |
+
prog_bar=True,
|
63 |
+
sync_dist=True,
|
64 |
+
)
|
65 |
+
self.log(
|
66 |
+
"lr",
|
67 |
+
scheduler.get_last_lr()[0],
|
68 |
+
on_epoch=True,
|
69 |
+
prog_bar=True,
|
70 |
+
sync_dist=True,
|
71 |
+
)
|
72 |
+
self.log(
|
73 |
+
f"top_{self.top_k}_acc",
|
74 |
+
acc,
|
75 |
+
on_step=True,
|
76 |
+
on_epoch=True,
|
77 |
+
prog_bar=True,
|
78 |
+
sync_dist=True,
|
79 |
+
)
|
80 |
+
|
81 |
+
def validation_step(self, batch: Dict, batch_idx: int):
|
82 |
+
return
|
83 |
+
|
84 |
+
def configure_optimizers(self):
|
85 |
+
model_parameters = self.model.parameters()
|
86 |
+
parameters_names = []
|
87 |
+
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
88 |
+
lm_opt = ScaledAdam(
|
89 |
+
model_parameters,
|
90 |
+
lr=0.01,
|
91 |
+
betas=(0.9, 0.95),
|
92 |
+
clipping_scale=2.0,
|
93 |
+
parameters_names=parameters_names,
|
94 |
+
show_dominant_parameters=False,
|
95 |
+
clipping_update_period=1000,
|
96 |
+
)
|
97 |
+
|
98 |
+
return {
|
99 |
+
"optimizer": lm_opt,
|
100 |
+
"lr_scheduler": {
|
101 |
+
"scheduler": WarmupCosineLRSchedule(
|
102 |
+
lm_opt,
|
103 |
+
init_lr=self.config["optimizer"]["lr_init"],
|
104 |
+
peak_lr=self.config["optimizer"]["lr"],
|
105 |
+
end_lr=self.config["optimizer"]["lr_end"],
|
106 |
+
warmup_steps=self.config["optimizer"]["warmup_steps"],
|
107 |
+
total_steps=self.config["optimizer"]["decay_steps"],
|
108 |
+
)
|
109 |
+
},
|
110 |
+
}
|
GPT_SoVITS/AR/models/t2s_model.py
ADDED
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import math
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from torchmetrics.classification import MulticlassAccuracy
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from AR.models.utils import (
|
13 |
+
dpo_loss,
|
14 |
+
get_batch_logps,
|
15 |
+
make_pad_mask,
|
16 |
+
make_pad_mask_left,
|
17 |
+
make_reject_y,
|
18 |
+
sample,
|
19 |
+
topk_sampling,
|
20 |
+
)
|
21 |
+
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
22 |
+
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
23 |
+
|
24 |
+
default_config = {
|
25 |
+
"embedding_dim": 512,
|
26 |
+
"hidden_dim": 512,
|
27 |
+
"num_head": 8,
|
28 |
+
"num_layers": 12,
|
29 |
+
"num_codebook": 8,
|
30 |
+
"p_dropout": 0.0,
|
31 |
+
"vocab_size": 1024 + 1,
|
32 |
+
"phoneme_vocab_size": 512,
|
33 |
+
"EOS": 1024,
|
34 |
+
}
|
35 |
+
|
36 |
+
|
37 |
+
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
|
38 |
+
# Efficient implementation equivalent to the following:
|
39 |
+
def scaled_dot_product_attention(
|
40 |
+
query: torch.Tensor,
|
41 |
+
key: torch.Tensor,
|
42 |
+
value: torch.Tensor,
|
43 |
+
attn_mask: Optional[torch.Tensor] = None,
|
44 |
+
scale: Optional[torch.Tensor] = None,
|
45 |
+
) -> torch.Tensor:
|
46 |
+
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
|
47 |
+
if scale is None:
|
48 |
+
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
|
49 |
+
else:
|
50 |
+
scale_factor = scale
|
51 |
+
attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
|
52 |
+
|
53 |
+
if attn_mask is not None:
|
54 |
+
if attn_mask.dtype == torch.bool:
|
55 |
+
attn_bias.masked_fill_(attn_mask, float("-inf"))
|
56 |
+
else:
|
57 |
+
attn_bias += attn_mask
|
58 |
+
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
59 |
+
attn_weight += attn_bias
|
60 |
+
attn_weight = torch.softmax(attn_weight, dim=-1)
|
61 |
+
|
62 |
+
if attn_mask is not None:
|
63 |
+
if attn_mask.dtype == torch.bool:
|
64 |
+
attn_weight.masked_fill_(attn_mask, 0)
|
65 |
+
else:
|
66 |
+
attn_mask[attn_mask != float("-inf")] = 0
|
67 |
+
attn_mask[attn_mask == float("-inf")] = 1
|
68 |
+
attn_weight.masked_fill_(attn_mask, 0)
|
69 |
+
|
70 |
+
return attn_weight @ value
|
71 |
+
|
72 |
+
|
73 |
+
@torch.jit.script
|
74 |
+
class T2SMLP:
|
75 |
+
def __init__(self, w1, b1, w2, b2):
|
76 |
+
self.w1 = w1
|
77 |
+
self.b1 = b1
|
78 |
+
self.w2 = w2
|
79 |
+
self.b2 = b2
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
x = F.relu(F.linear(x, self.w1, self.b1))
|
83 |
+
x = F.linear(x, self.w2, self.b2)
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
@torch.jit.script
|
88 |
+
class T2SBlock:
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
num_heads,
|
92 |
+
hidden_dim: int,
|
93 |
+
mlp: T2SMLP,
|
94 |
+
qkv_w,
|
95 |
+
qkv_b,
|
96 |
+
out_w,
|
97 |
+
out_b,
|
98 |
+
norm_w1,
|
99 |
+
norm_b1,
|
100 |
+
norm_eps1,
|
101 |
+
norm_w2,
|
102 |
+
norm_b2,
|
103 |
+
norm_eps2,
|
104 |
+
):
|
105 |
+
self.num_heads = num_heads
|
106 |
+
self.mlp = mlp
|
107 |
+
self.hidden_dim: int = hidden_dim
|
108 |
+
self.qkv_w = qkv_w
|
109 |
+
self.qkv_b = qkv_b
|
110 |
+
self.out_w = out_w
|
111 |
+
self.out_b = out_b
|
112 |
+
self.norm_w1 = norm_w1
|
113 |
+
self.norm_b1 = norm_b1
|
114 |
+
self.norm_eps1 = norm_eps1
|
115 |
+
self.norm_w2 = norm_w2
|
116 |
+
self.norm_b2 = norm_b2
|
117 |
+
self.norm_eps2 = norm_eps2
|
118 |
+
|
119 |
+
self.false = torch.tensor(False, dtype=torch.bool)
|
120 |
+
|
121 |
+
@torch.jit.ignore
|
122 |
+
def to_mask(
|
123 |
+
self,
|
124 |
+
x: torch.Tensor,
|
125 |
+
padding_mask: Optional[torch.Tensor],
|
126 |
+
):
|
127 |
+
if padding_mask is None:
|
128 |
+
return x
|
129 |
+
|
130 |
+
if padding_mask.dtype == torch.bool:
|
131 |
+
return x.masked_fill(padding_mask, 0)
|
132 |
+
else:
|
133 |
+
return x * padding_mask
|
134 |
+
|
135 |
+
def process_prompt(
|
136 |
+
self,
|
137 |
+
x: torch.Tensor,
|
138 |
+
attn_mask: torch.Tensor,
|
139 |
+
padding_mask: Optional[torch.Tensor] = None,
|
140 |
+
torch_sdpa: bool = True,
|
141 |
+
):
|
142 |
+
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
143 |
+
|
144 |
+
batch_size = q.shape[0]
|
145 |
+
q_len = q.shape[1]
|
146 |
+
kv_len = k.shape[1]
|
147 |
+
|
148 |
+
q = self.to_mask(q, padding_mask)
|
149 |
+
k_cache = self.to_mask(k, padding_mask)
|
150 |
+
v_cache = self.to_mask(v, padding_mask)
|
151 |
+
|
152 |
+
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
153 |
+
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
154 |
+
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
155 |
+
|
156 |
+
if torch_sdpa:
|
157 |
+
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
158 |
+
else:
|
159 |
+
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
160 |
+
|
161 |
+
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
162 |
+
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
163 |
+
|
164 |
+
x = x + attn
|
165 |
+
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
166 |
+
x = x + self.mlp.forward(x)
|
167 |
+
x = F.layer_norm(
|
168 |
+
x,
|
169 |
+
[self.hidden_dim],
|
170 |
+
self.norm_w2,
|
171 |
+
self.norm_b2,
|
172 |
+
self.norm_eps2,
|
173 |
+
)
|
174 |
+
return x, k_cache, v_cache
|
175 |
+
|
176 |
+
def decode_next_token(
|
177 |
+
self,
|
178 |
+
x: torch.Tensor,
|
179 |
+
k_cache: torch.Tensor,
|
180 |
+
v_cache: torch.Tensor,
|
181 |
+
attn_mask: torch.Tensor = None,
|
182 |
+
torch_sdpa: bool = True,
|
183 |
+
):
|
184 |
+
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
185 |
+
|
186 |
+
k_cache = torch.cat([k_cache, k], dim=1)
|
187 |
+
v_cache = torch.cat([v_cache, v], dim=1)
|
188 |
+
|
189 |
+
batch_size = q.shape[0]
|
190 |
+
q_len = q.shape[1]
|
191 |
+
kv_len = k_cache.shape[1]
|
192 |
+
|
193 |
+
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
194 |
+
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
195 |
+
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
196 |
+
|
197 |
+
if torch_sdpa:
|
198 |
+
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
|
199 |
+
else:
|
200 |
+
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
201 |
+
|
202 |
+
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
203 |
+
attn = F.linear(attn, self.out_w, self.out_b)
|
204 |
+
|
205 |
+
x = x + attn
|
206 |
+
x = F.layer_norm(
|
207 |
+
x,
|
208 |
+
[self.hidden_dim],
|
209 |
+
self.norm_w1,
|
210 |
+
self.norm_b1,
|
211 |
+
self.norm_eps1,
|
212 |
+
)
|
213 |
+
x = x + self.mlp.forward(x)
|
214 |
+
x = F.layer_norm(
|
215 |
+
x,
|
216 |
+
[self.hidden_dim],
|
217 |
+
self.norm_w2,
|
218 |
+
self.norm_b2,
|
219 |
+
self.norm_eps2,
|
220 |
+
)
|
221 |
+
return x, k_cache, v_cache
|
222 |
+
|
223 |
+
|
224 |
+
@torch.jit.script
|
225 |
+
class T2STransformer:
|
226 |
+
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
|
227 |
+
self.num_blocks: int = num_blocks
|
228 |
+
self.blocks = blocks
|
229 |
+
|
230 |
+
def process_prompt(
|
231 |
+
self,
|
232 |
+
x: torch.Tensor,
|
233 |
+
attn_mask: torch.Tensor,
|
234 |
+
padding_mask: Optional[torch.Tensor] = None,
|
235 |
+
torch_sdpa: bool = True,
|
236 |
+
):
|
237 |
+
k_cache: List[torch.Tensor] = []
|
238 |
+
v_cache: List[torch.Tensor] = []
|
239 |
+
for i in range(self.num_blocks):
|
240 |
+
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
|
241 |
+
k_cache.append(k_cache_)
|
242 |
+
v_cache.append(v_cache_)
|
243 |
+
return x, k_cache, v_cache
|
244 |
+
|
245 |
+
def decode_next_token(
|
246 |
+
self,
|
247 |
+
x: torch.Tensor,
|
248 |
+
k_cache: List[torch.Tensor],
|
249 |
+
v_cache: List[torch.Tensor],
|
250 |
+
attn_mask: torch.Tensor = None,
|
251 |
+
torch_sdpa: bool = True,
|
252 |
+
):
|
253 |
+
for i in range(self.num_blocks):
|
254 |
+
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
|
255 |
+
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
|
256 |
+
)
|
257 |
+
return x, k_cache, v_cache
|
258 |
+
|
259 |
+
|
260 |
+
class Text2SemanticDecoder(nn.Module):
|
261 |
+
def __init__(self, config, norm_first=False, top_k=3):
|
262 |
+
super(Text2SemanticDecoder, self).__init__()
|
263 |
+
self.model_dim = config["model"]["hidden_dim"]
|
264 |
+
self.embedding_dim = config["model"]["embedding_dim"]
|
265 |
+
self.num_head = config["model"]["head"]
|
266 |
+
self.num_layers = config["model"]["n_layer"]
|
267 |
+
self.norm_first = norm_first
|
268 |
+
self.vocab_size = config["model"]["vocab_size"]
|
269 |
+
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
|
270 |
+
self.p_dropout = config["model"]["dropout"]
|
271 |
+
self.EOS = config["model"]["EOS"]
|
272 |
+
self.norm_first = norm_first
|
273 |
+
assert self.EOS == self.vocab_size - 1
|
274 |
+
# should be same as num of kmeans bin
|
275 |
+
# assert self.EOS == 1024
|
276 |
+
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
277 |
+
self.ar_text_embedding = TokenEmbedding(
|
278 |
+
self.embedding_dim,
|
279 |
+
self.phoneme_vocab_size,
|
280 |
+
self.p_dropout,
|
281 |
+
)
|
282 |
+
self.ar_text_position = SinePositionalEmbedding(
|
283 |
+
self.embedding_dim,
|
284 |
+
dropout=0.1,
|
285 |
+
scale=False,
|
286 |
+
alpha=True,
|
287 |
+
)
|
288 |
+
self.ar_audio_embedding = TokenEmbedding(
|
289 |
+
self.embedding_dim,
|
290 |
+
self.vocab_size,
|
291 |
+
self.p_dropout,
|
292 |
+
)
|
293 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
294 |
+
self.embedding_dim,
|
295 |
+
dropout=0.1,
|
296 |
+
scale=False,
|
297 |
+
alpha=True,
|
298 |
+
)
|
299 |
+
|
300 |
+
self.h = TransformerEncoder(
|
301 |
+
TransformerEncoderLayer(
|
302 |
+
d_model=self.model_dim,
|
303 |
+
nhead=self.num_head,
|
304 |
+
dim_feedforward=self.model_dim * 4,
|
305 |
+
dropout=0.1,
|
306 |
+
batch_first=True,
|
307 |
+
norm_first=norm_first,
|
308 |
+
),
|
309 |
+
num_layers=self.num_layers,
|
310 |
+
norm=LayerNorm(self.model_dim) if norm_first else None,
|
311 |
+
)
|
312 |
+
|
313 |
+
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
|
314 |
+
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
|
315 |
+
|
316 |
+
self.ar_accuracy_metric = MulticlassAccuracy(
|
317 |
+
self.vocab_size,
|
318 |
+
top_k=top_k,
|
319 |
+
average="micro",
|
320 |
+
multidim_average="global",
|
321 |
+
ignore_index=self.EOS,
|
322 |
+
)
|
323 |
+
|
324 |
+
blocks = []
|
325 |
+
|
326 |
+
for i in range(self.num_layers):
|
327 |
+
layer = self.h.layers[i]
|
328 |
+
t2smlp = T2SMLP(
|
329 |
+
layer.linear1.weight,
|
330 |
+
layer.linear1.bias,
|
331 |
+
layer.linear2.weight,
|
332 |
+
layer.linear2.bias,
|
333 |
+
)
|
334 |
+
|
335 |
+
block = T2SBlock(
|
336 |
+
self.num_head,
|
337 |
+
self.model_dim,
|
338 |
+
t2smlp,
|
339 |
+
layer.self_attn.in_proj_weight,
|
340 |
+
layer.self_attn.in_proj_bias,
|
341 |
+
layer.self_attn.out_proj.weight,
|
342 |
+
layer.self_attn.out_proj.bias,
|
343 |
+
layer.norm1.weight,
|
344 |
+
layer.norm1.bias,
|
345 |
+
layer.norm1.eps,
|
346 |
+
layer.norm2.weight,
|
347 |
+
layer.norm2.bias,
|
348 |
+
layer.norm2.eps,
|
349 |
+
)
|
350 |
+
|
351 |
+
blocks.append(block)
|
352 |
+
|
353 |
+
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
354 |
+
|
355 |
+
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
356 |
+
x = self.ar_text_embedding(x)
|
357 |
+
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
358 |
+
x = self.ar_text_position(x)
|
359 |
+
x_mask = make_pad_mask(x_lens)
|
360 |
+
|
361 |
+
y_mask = make_pad_mask(y_lens)
|
362 |
+
y_mask_int = y_mask.type(torch.int64)
|
363 |
+
codes = y.type(torch.int64) * (1 - y_mask_int)
|
364 |
+
|
365 |
+
# Training
|
366 |
+
# AR Decoder
|
367 |
+
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
|
368 |
+
x_len = x_lens.max()
|
369 |
+
y_len = y_lens.max()
|
370 |
+
y_emb = self.ar_audio_embedding(y)
|
371 |
+
y_pos = self.ar_audio_position(y_emb)
|
372 |
+
|
373 |
+
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
|
374 |
+
|
375 |
+
ar_xy_padding_mask = xy_padding_mask
|
376 |
+
|
377 |
+
x_attn_mask = F.pad(
|
378 |
+
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
|
379 |
+
(0, y_len),
|
380 |
+
value=True,
|
381 |
+
)
|
382 |
+
# x_attn_mask[:, x_len]=False
|
383 |
+
y_attn_mask = F.pad(
|
384 |
+
torch.triu(
|
385 |
+
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
386 |
+
diagonal=1,
|
387 |
+
),
|
388 |
+
(x_len, 0),
|
389 |
+
value=False,
|
390 |
+
)
|
391 |
+
|
392 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
393 |
+
bsz, src_len = x.shape[0], x_len + y_len
|
394 |
+
_xy_padding_mask = (
|
395 |
+
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
|
396 |
+
.expand(-1, self.num_head, -1, -1)
|
397 |
+
.reshape(bsz * self.num_head, 1, src_len)
|
398 |
+
)
|
399 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
400 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
401 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
402 |
+
xy_attn_mask = new_attn_mask
|
403 |
+
# x 和完整的 y 一次性输入模型
|
404 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
405 |
+
|
406 |
+
return xy_pos, xy_attn_mask, targets
|
407 |
+
|
408 |
+
def forward(self, x, x_lens, y, y_lens, bert_feature):
|
409 |
+
"""
|
410 |
+
x: phoneme_ids
|
411 |
+
y: semantic_ids
|
412 |
+
"""
|
413 |
+
|
414 |
+
reject_y, reject_y_lens = make_reject_y(y, y_lens)
|
415 |
+
|
416 |
+
xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
|
417 |
+
|
418 |
+
xy_dec, _ = self.h(
|
419 |
+
(xy_pos, None),
|
420 |
+
mask=xy_attn_mask,
|
421 |
+
)
|
422 |
+
x_len = x_lens.max()
|
423 |
+
logits = self.ar_predict_layer(xy_dec[:, x_len:])
|
424 |
+
|
425 |
+
###### DPO #############
|
426 |
+
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
|
427 |
+
x, x_lens, reject_y, reject_y_lens, bert_feature
|
428 |
+
)
|
429 |
+
|
430 |
+
reject_xy_dec, _ = self.h(
|
431 |
+
(reject_xy_pos, None),
|
432 |
+
mask=reject_xy_attn_mask,
|
433 |
+
)
|
434 |
+
x_len = x_lens.max()
|
435 |
+
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
|
436 |
+
|
437 |
+
# loss
|
438 |
+
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
|
439 |
+
|
440 |
+
loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
|
441 |
+
acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
|
442 |
+
|
443 |
+
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
444 |
+
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
445 |
+
|
446 |
+
loss = loss_1 + loss_2
|
447 |
+
|
448 |
+
return loss, acc
|
449 |
+
|
450 |
+
def forward_old(self, x, x_lens, y, y_lens, bert_feature):
|
451 |
+
"""
|
452 |
+
x: phoneme_ids
|
453 |
+
y: semantic_ids
|
454 |
+
"""
|
455 |
+
x = self.ar_text_embedding(x)
|
456 |
+
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
457 |
+
x = self.ar_text_position(x)
|
458 |
+
x_mask = make_pad_mask(x_lens)
|
459 |
+
|
460 |
+
y_mask = make_pad_mask(y_lens)
|
461 |
+
y_mask_int = y_mask.type(torch.int64)
|
462 |
+
codes = y.type(torch.int64) * (1 - y_mask_int)
|
463 |
+
|
464 |
+
# Training
|
465 |
+
# AR Decoder
|
466 |
+
y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
|
467 |
+
x_len = x_lens.max()
|
468 |
+
y_len = y_lens.max()
|
469 |
+
y_emb = self.ar_audio_embedding(y)
|
470 |
+
y_pos = self.ar_audio_position(y_emb)
|
471 |
+
|
472 |
+
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
|
473 |
+
ar_xy_padding_mask = xy_padding_mask
|
474 |
+
|
475 |
+
x_attn_mask = F.pad(
|
476 |
+
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
|
477 |
+
(0, y_len),
|
478 |
+
value=True,
|
479 |
+
)
|
480 |
+
y_attn_mask = F.pad(
|
481 |
+
torch.triu(
|
482 |
+
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
483 |
+
diagonal=1,
|
484 |
+
),
|
485 |
+
(x_len, 0),
|
486 |
+
value=False,
|
487 |
+
)
|
488 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
489 |
+
bsz, src_len = x.shape[0], x_len + y_len
|
490 |
+
_xy_padding_mask = (
|
491 |
+
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
|
492 |
+
.expand(-1, self.num_head, -1, -1)
|
493 |
+
.reshape(bsz * self.num_head, 1, src_len)
|
494 |
+
)
|
495 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
496 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
497 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
498 |
+
xy_attn_mask = new_attn_mask
|
499 |
+
# x 和完整的 y 一次性输入模型
|
500 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
501 |
+
xy_dec, _ = self.h(
|
502 |
+
(xy_pos, None),
|
503 |
+
mask=xy_attn_mask,
|
504 |
+
)
|
505 |
+
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
|
506 |
+
# loss
|
507 |
+
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
|
508 |
+
loss = F.cross_entropy(logits, targets, reduction="sum")
|
509 |
+
acc = self.ar_accuracy_metric(logits.detach(), targets).item()
|
510 |
+
return loss, acc
|
511 |
+
|
512 |
+
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
513 |
+
def infer(
|
514 |
+
self,
|
515 |
+
x,
|
516 |
+
x_lens,
|
517 |
+
prompts,
|
518 |
+
bert_feature,
|
519 |
+
top_k: int = -100,
|
520 |
+
early_stop_num: int = -1,
|
521 |
+
temperature: float = 1.0,
|
522 |
+
):
|
523 |
+
x = self.ar_text_embedding(x)
|
524 |
+
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
525 |
+
x = self.ar_text_position(x)
|
526 |
+
|
527 |
+
# AR Decoder
|
528 |
+
y = prompts
|
529 |
+
prefix_len = y.shape[1]
|
530 |
+
x_len = x.shape[1]
|
531 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
532 |
+
stop = False
|
533 |
+
for _ in tqdm(range(1500)):
|
534 |
+
y_emb = self.ar_audio_embedding(y)
|
535 |
+
y_pos = self.ar_audio_position(y_emb)
|
536 |
+
# x 和逐渐增长的 y 一起输入给模型
|
537 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
538 |
+
y_len = y.shape[1]
|
539 |
+
x_attn_mask_pad = F.pad(
|
540 |
+
x_attn_mask,
|
541 |
+
(0, y_len),
|
542 |
+
value=True,
|
543 |
+
)
|
544 |
+
y_attn_mask = F.pad(
|
545 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
546 |
+
(x_len, 0),
|
547 |
+
value=False,
|
548 |
+
)
|
549 |
+
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
|
550 |
+
|
551 |
+
xy_dec, _ = self.h(
|
552 |
+
(xy_pos, None),
|
553 |
+
mask=xy_attn_mask,
|
554 |
+
)
|
555 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
556 |
+
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
557 |
+
|
558 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
559 |
+
print("use early stop num:", early_stop_num)
|
560 |
+
stop = True
|
561 |
+
|
562 |
+
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
563 |
+
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
|
564 |
+
stop = True
|
565 |
+
if stop:
|
566 |
+
if prompts.shape[1] == y.shape[1]:
|
567 |
+
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
568 |
+
print("bad zero prediction")
|
569 |
+
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
570 |
+
break
|
571 |
+
# 本次生成的 semantic_ids 和之前的 y 构成新的 y
|
572 |
+
# print(samples.shape)#[1,1]#第一个1是bs
|
573 |
+
# import os
|
574 |
+
# os._exit(2333)
|
575 |
+
y = torch.concat([y, samples], dim=1)
|
576 |
+
return y
|
577 |
+
|
578 |
+
def pad_y_eos(self, y, y_mask_int, eos_id):
|
579 |
+
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
|
580 |
+
# 错位
|
581 |
+
return targets[:, :-1], targets[:, 1:]
|
582 |
+
|
583 |
+
def infer_panel_batch_infer(
|
584 |
+
self,
|
585 |
+
x: List[torch.LongTensor], #####全部文本token
|
586 |
+
x_lens: torch.LongTensor,
|
587 |
+
prompts: torch.LongTensor, ####参考音频token
|
588 |
+
bert_feature: List[torch.LongTensor],
|
589 |
+
top_k: int = -100,
|
590 |
+
top_p: int = 100,
|
591 |
+
early_stop_num: int = -1,
|
592 |
+
temperature: float = 1.0,
|
593 |
+
repetition_penalty: float = 1.35,
|
594 |
+
**kwargs,
|
595 |
+
):
|
596 |
+
if prompts is None:
|
597 |
+
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
598 |
+
return self.infer_panel_naive_batched(
|
599 |
+
x,
|
600 |
+
x_lens,
|
601 |
+
prompts,
|
602 |
+
bert_feature,
|
603 |
+
top_k=top_k,
|
604 |
+
top_p=top_p,
|
605 |
+
early_stop_num=early_stop_num,
|
606 |
+
temperature=temperature,
|
607 |
+
**kwargs,
|
608 |
+
)
|
609 |
+
|
610 |
+
max_len = kwargs.get("max_len", x_lens.max())
|
611 |
+
x_list = []
|
612 |
+
for x_item, bert_item in zip(x, bert_feature):
|
613 |
+
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
614 |
+
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
|
615 |
+
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
616 |
+
x_item = self.ar_text_position(x_item).squeeze(0)
|
617 |
+
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
|
618 |
+
x_item = (
|
619 |
+
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
|
620 |
+
) ### padding left
|
621 |
+
x_list.append(x_item)
|
622 |
+
x: torch.Tensor = torch.stack(x_list, dim=0)
|
623 |
+
|
624 |
+
# AR Decoder
|
625 |
+
y = prompts
|
626 |
+
|
627 |
+
x_len = x.shape[1]
|
628 |
+
stop = False
|
629 |
+
|
630 |
+
k_cache = None
|
631 |
+
v_cache = None
|
632 |
+
################### first step ##########################
|
633 |
+
assert y is not None, "Error: Prompt free is not supported batch_infer!"
|
634 |
+
ref_free = False
|
635 |
+
|
636 |
+
y_emb = self.ar_audio_embedding(y)
|
637 |
+
y_len = y_emb.shape[1]
|
638 |
+
prefix_len = y.shape[1]
|
639 |
+
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
|
640 |
+
y_pos = self.ar_audio_position(y_emb)
|
641 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
642 |
+
|
643 |
+
##### create mask #####
|
644 |
+
bsz = x.shape[0]
|
645 |
+
src_len = x_len + y_len
|
646 |
+
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
|
647 |
+
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
|
648 |
+
|
649 |
+
# (bsz, x_len + y_len)
|
650 |
+
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
651 |
+
|
652 |
+
x_mask = F.pad(
|
653 |
+
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
654 |
+
(0, y_len),
|
655 |
+
value=True,
|
656 |
+
)
|
657 |
+
|
658 |
+
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
659 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
|
660 |
+
(x_len, 0),
|
661 |
+
value=False,
|
662 |
+
)
|
663 |
+
|
664 |
+
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
665 |
+
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
|
666 |
+
### 上面是错误的,会导致padding的token被"看见"
|
667 |
+
|
668 |
+
# 正确的padding_mask应该是:
|
669 |
+
# | pad_len | x_len | y_len |
|
670 |
+
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
671 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
672 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
673 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
674 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
675 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
676 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
677 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
678 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
679 |
+
|
680 |
+
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
681 |
+
|
682 |
+
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
|
683 |
+
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
|
684 |
+
|
685 |
+
# 正确的attn_mask应该是这样的:
|
686 |
+
# | pad_len | x_len | y_len |
|
687 |
+
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
688 |
+
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
689 |
+
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
690 |
+
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
691 |
+
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
692 |
+
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
693 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
|
694 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
|
695 |
+
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
696 |
+
|
697 |
+
###### decode #####
|
698 |
+
y_list = [None] * y.shape[0]
|
699 |
+
batch_idx_map = list(range(y.shape[0]))
|
700 |
+
idx_list = [None] * y.shape[0]
|
701 |
+
for idx in tqdm(range(1500)):
|
702 |
+
if idx == 0:
|
703 |
+
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
704 |
+
else:
|
705 |
+
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
706 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
707 |
+
|
708 |
+
if idx == 0:
|
709 |
+
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
710 |
+
logits = logits[:, :-1]
|
711 |
+
else:
|
712 |
+
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
713 |
+
|
714 |
+
samples = sample(
|
715 |
+
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
716 |
+
)[0]
|
717 |
+
|
718 |
+
y = torch.concat([y, samples], dim=1)
|
719 |
+
|
720 |
+
####### 移除batch中已经生成完毕的序列,进一步优化计算量
|
721 |
+
tokens = torch.argmax(logits, dim=-1)
|
722 |
+
reserved_idx_of_batch_for_y = None
|
723 |
+
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS,则停止
|
724 |
+
l1 = samples[:, 0] == self.EOS
|
725 |
+
l2 = tokens == self.EOS
|
726 |
+
l = l1.logical_or(l2)
|
727 |
+
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
|
728 |
+
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
|
729 |
+
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
730 |
+
for i in removed_idx_of_batch_for_y:
|
731 |
+
batch_index = batch_idx_map[i]
|
732 |
+
idx_list[batch_index] = idx
|
733 |
+
y_list[batch_index] = y[i, :-1]
|
734 |
+
|
735 |
+
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
736 |
+
|
737 |
+
# 只保留batch中未生成完毕的序列
|
738 |
+
if reserved_idx_of_batch_for_y is not None:
|
739 |
+
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
740 |
+
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
741 |
+
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
742 |
+
if k_cache is not None:
|
743 |
+
for i in range(len(k_cache)):
|
744 |
+
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
745 |
+
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
746 |
+
|
747 |
+
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
|
748 |
+
print("use early stop num:", early_stop_num)
|
749 |
+
stop = True
|
750 |
+
for i, batch_index in enumerate(batch_idx_map):
|
751 |
+
batch_index = batch_idx_map[i]
|
752 |
+
idx_list[batch_index] = idx
|
753 |
+
y_list[batch_index] = y[i, :-1]
|
754 |
+
|
755 |
+
if None not in idx_list:
|
756 |
+
stop = True
|
757 |
+
|
758 |
+
if stop:
|
759 |
+
if y.shape[1] == 0:
|
760 |
+
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
761 |
+
print("bad zero prediction")
|
762 |
+
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
763 |
+
break
|
764 |
+
|
765 |
+
####################### update next step ###################################
|
766 |
+
y_emb = self.ar_audio_embedding(y[:, -1:])
|
767 |
+
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
768 |
+
:, y_len + idx
|
769 |
+
].to(dtype=y_emb.dtype, device=y_emb.device)
|
770 |
+
|
771 |
+
if None in idx_list:
|
772 |
+
for i in range(x.shape[0]):
|
773 |
+
if idx_list[i] is None:
|
774 |
+
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
775 |
+
|
776 |
+
if ref_free:
|
777 |
+
return y_list, [0] * x.shape[0]
|
778 |
+
# print(idx_list)
|
779 |
+
return y_list, idx_list
|
780 |
+
|
781 |
+
def infer_panel_naive_batched(
|
782 |
+
self,
|
783 |
+
x: List[torch.LongTensor], #####全部文本token
|
784 |
+
x_lens: torch.LongTensor,
|
785 |
+
prompts: torch.LongTensor, ####参考音频token
|
786 |
+
bert_feature: List[torch.LongTensor],
|
787 |
+
top_k: int = -100,
|
788 |
+
top_p: int = 100,
|
789 |
+
early_stop_num: int = -1,
|
790 |
+
temperature: float = 1.0,
|
791 |
+
repetition_penalty: float = 1.35,
|
792 |
+
**kwargs,
|
793 |
+
):
|
794 |
+
y_list = []
|
795 |
+
idx_list = []
|
796 |
+
for i in range(len(x)):
|
797 |
+
y, idx = self.infer_panel_naive(
|
798 |
+
x[i].unsqueeze(0),
|
799 |
+
x_lens[i],
|
800 |
+
prompts[i].unsqueeze(0) if prompts is not None else None,
|
801 |
+
bert_feature[i].unsqueeze(0),
|
802 |
+
top_k,
|
803 |
+
top_p,
|
804 |
+
early_stop_num,
|
805 |
+
temperature,
|
806 |
+
repetition_penalty,
|
807 |
+
**kwargs,
|
808 |
+
)
|
809 |
+
y_list.append(y[0])
|
810 |
+
idx_list.append(idx)
|
811 |
+
|
812 |
+
return y_list, idx_list
|
813 |
+
|
814 |
+
def infer_panel_naive(
|
815 |
+
self,
|
816 |
+
x: torch.LongTensor, #####全部文本token
|
817 |
+
x_lens: torch.LongTensor,
|
818 |
+
prompts: torch.LongTensor, ####参考音频token
|
819 |
+
bert_feature: torch.LongTensor,
|
820 |
+
top_k: int = -100,
|
821 |
+
top_p: int = 100,
|
822 |
+
early_stop_num: int = -1,
|
823 |
+
temperature: float = 1.0,
|
824 |
+
repetition_penalty: float = 1.35,
|
825 |
+
**kwargs,
|
826 |
+
):
|
827 |
+
x = self.ar_text_embedding(x)
|
828 |
+
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
829 |
+
x = self.ar_text_position(x)
|
830 |
+
|
831 |
+
# AR Decoder
|
832 |
+
y = prompts
|
833 |
+
|
834 |
+
x_len = x.shape[1]
|
835 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
836 |
+
stop = False
|
837 |
+
# print(1111111,self.num_layers)
|
838 |
+
|
839 |
+
k_cache = None
|
840 |
+
v_cache = None
|
841 |
+
################### first step ##########################
|
842 |
+
if y is not None:
|
843 |
+
y_emb = self.ar_audio_embedding(y)
|
844 |
+
y_len = y_emb.shape[1]
|
845 |
+
prefix_len = y.shape[1]
|
846 |
+
y_pos = self.ar_audio_position(y_emb)
|
847 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
848 |
+
ref_free = False
|
849 |
+
else:
|
850 |
+
y_emb = None
|
851 |
+
y_len = 0
|
852 |
+
prefix_len = 0
|
853 |
+
y_pos = None
|
854 |
+
xy_pos = x
|
855 |
+
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
|
856 |
+
ref_free = True
|
857 |
+
|
858 |
+
bsz = x.shape[0]
|
859 |
+
src_len = x_len + y_len
|
860 |
+
x_attn_mask_pad = F.pad(
|
861 |
+
x_attn_mask,
|
862 |
+
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
863 |
+
value=True,
|
864 |
+
)
|
865 |
+
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
866 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
867 |
+
(x_len, 0),
|
868 |
+
value=False,
|
869 |
+
)
|
870 |
+
xy_attn_mask = (
|
871 |
+
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
872 |
+
.unsqueeze(0)
|
873 |
+
.expand(bsz * self.num_head, -1, -1)
|
874 |
+
.view(bsz, self.num_head, src_len, src_len)
|
875 |
+
.to(device=x.device, dtype=torch.bool)
|
876 |
+
)
|
877 |
+
|
878 |
+
for idx in tqdm(range(1500)):
|
879 |
+
if xy_attn_mask is not None:
|
880 |
+
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
881 |
+
else:
|
882 |
+
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
883 |
+
|
884 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
885 |
+
|
886 |
+
if idx == 0:
|
887 |
+
xy_attn_mask = None
|
888 |
+
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
889 |
+
logits = logits[:, :-1]
|
890 |
+
|
891 |
+
samples = sample(
|
892 |
+
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
893 |
+
)[0]
|
894 |
+
|
895 |
+
y = torch.concat([y, samples], dim=1)
|
896 |
+
|
897 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
898 |
+
print("use early stop num:", early_stop_num)
|
899 |
+
stop = True
|
900 |
+
|
901 |
+
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
902 |
+
stop = True
|
903 |
+
if stop:
|
904 |
+
if y.shape[1] == 0:
|
905 |
+
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
906 |
+
print("bad zero prediction")
|
907 |
+
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
908 |
+
break
|
909 |
+
|
910 |
+
####################### update next step ###################################
|
911 |
+
y_emb = self.ar_audio_embedding(y[:, -1:])
|
912 |
+
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
913 |
+
:, y_len + idx
|
914 |
+
].to(dtype=y_emb.dtype, device=y_emb.device)
|
915 |
+
|
916 |
+
if ref_free:
|
917 |
+
return y[:, :-1], 0
|
918 |
+
return y[:, :-1], idx
|
919 |
+
|
920 |
+
def infer_panel(
|
921 |
+
self,
|
922 |
+
x: torch.LongTensor, #####全部文本token
|
923 |
+
x_lens: torch.LongTensor,
|
924 |
+
prompts: torch.LongTensor, ####参考音频token
|
925 |
+
bert_feature: torch.LongTensor,
|
926 |
+
top_k: int = -100,
|
927 |
+
top_p: int = 100,
|
928 |
+
early_stop_num: int = -1,
|
929 |
+
temperature: float = 1.0,
|
930 |
+
repetition_penalty: float = 1.35,
|
931 |
+
**kwargs,
|
932 |
+
):
|
933 |
+
return self.infer_panel_naive(
|
934 |
+
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
935 |
+
)
|
GPT_SoVITS/AR/models/t2s_model_onnx.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torchmetrics.classification import MulticlassAccuracy
|
7 |
+
|
8 |
+
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
|
9 |
+
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
10 |
+
|
11 |
+
default_config = {
|
12 |
+
"embedding_dim": 512,
|
13 |
+
"hidden_dim": 512,
|
14 |
+
"num_head": 8,
|
15 |
+
"num_layers": 12,
|
16 |
+
"num_codebook": 8,
|
17 |
+
"p_dropout": 0.0,
|
18 |
+
"vocab_size": 1024 + 1,
|
19 |
+
"phoneme_vocab_size": 512,
|
20 |
+
"EOS": 1024,
|
21 |
+
}
|
22 |
+
|
23 |
+
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
|
24 |
+
|
25 |
+
|
26 |
+
def logits_to_probs(
|
27 |
+
logits,
|
28 |
+
previous_tokens=None,
|
29 |
+
temperature: float = 1.0,
|
30 |
+
top_k=None,
|
31 |
+
top_p=None,
|
32 |
+
repetition_penalty: float = 1.0,
|
33 |
+
):
|
34 |
+
previous_tokens = previous_tokens.squeeze()
|
35 |
+
if previous_tokens is not None and repetition_penalty != 1.0:
|
36 |
+
previous_tokens = previous_tokens.long()
|
37 |
+
score = torch.gather(logits, dim=0, index=previous_tokens)
|
38 |
+
score = torch.where(
|
39 |
+
score < 0,
|
40 |
+
score * repetition_penalty,
|
41 |
+
score / repetition_penalty,
|
42 |
+
)
|
43 |
+
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
44 |
+
|
45 |
+
if top_p is not None and top_p < 1.0:
|
46 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
47 |
+
cum_probs = torch.cumsum(
|
48 |
+
torch.nn.functional.softmax(
|
49 |
+
sorted_logits,
|
50 |
+
dim=-1,
|
51 |
+
),
|
52 |
+
dim=-1,
|
53 |
+
)
|
54 |
+
sorted_indices_to_remove = cum_probs > top_p
|
55 |
+
sorted_indices_to_remove[0] = False # keep at least one option
|
56 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
57 |
+
dim=0,
|
58 |
+
index=sorted_indices,
|
59 |
+
src=sorted_indices_to_remove,
|
60 |
+
)
|
61 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
62 |
+
|
63 |
+
logits = logits / max(temperature, 1e-5)
|
64 |
+
|
65 |
+
if top_k is not None:
|
66 |
+
v, _ = torch.topk(logits, top_k)
|
67 |
+
pivot = v.select(-1, -1).unsqueeze(-1)
|
68 |
+
logits = torch.where(logits < pivot, inf_tensor_value, logits)
|
69 |
+
|
70 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
71 |
+
return probs
|
72 |
+
|
73 |
+
|
74 |
+
def multinomial_sample_one_no_sync(
|
75 |
+
probs_sort,
|
76 |
+
): # Does multinomial sampling without a cuda synchronization
|
77 |
+
q = torch.randn_like(probs_sort)
|
78 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
79 |
+
|
80 |
+
|
81 |
+
def sample(
|
82 |
+
logits,
|
83 |
+
previous_tokens,
|
84 |
+
**sampling_kwargs,
|
85 |
+
):
|
86 |
+
probs = logits_to_probs(
|
87 |
+
logits=logits,
|
88 |
+
previous_tokens=previous_tokens,
|
89 |
+
**sampling_kwargs,
|
90 |
+
)
|
91 |
+
idx_next = multinomial_sample_one_no_sync(probs)
|
92 |
+
return idx_next, probs
|
93 |
+
|
94 |
+
|
95 |
+
class OnnxEncoder(nn.Module):
|
96 |
+
def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
|
97 |
+
super().__init__()
|
98 |
+
self.ar_text_embedding = ar_text_embedding
|
99 |
+
self.bert_proj = bert_proj
|
100 |
+
self.ar_text_position = ar_text_position
|
101 |
+
|
102 |
+
def forward(self, x, bert_feature):
|
103 |
+
x = self.ar_text_embedding(x)
|
104 |
+
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
105 |
+
return self.ar_text_position(x)
|
106 |
+
|
107 |
+
|
108 |
+
class T2SFirstStageDecoder(nn.Module):
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
ar_audio_embedding,
|
112 |
+
ar_audio_position,
|
113 |
+
h,
|
114 |
+
ar_predict_layer,
|
115 |
+
loss_fct,
|
116 |
+
ar_accuracy_metric,
|
117 |
+
top_k,
|
118 |
+
early_stop_num,
|
119 |
+
num_layers,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.ar_audio_embedding = ar_audio_embedding
|
123 |
+
self.ar_audio_position = ar_audio_position
|
124 |
+
self.h = h
|
125 |
+
self.ar_predict_layer = ar_predict_layer
|
126 |
+
self.loss_fct = loss_fct
|
127 |
+
self.ar_accuracy_metric = ar_accuracy_metric
|
128 |
+
self.top_k = top_k
|
129 |
+
self.early_stop_num = early_stop_num
|
130 |
+
self.num_layers = num_layers
|
131 |
+
|
132 |
+
def forward(self, x, prompt):
|
133 |
+
y = prompt
|
134 |
+
x_example = x[:, :, 0] * 0.0
|
135 |
+
# N, 1, 512
|
136 |
+
cache = {
|
137 |
+
"all_stage": self.num_layers,
|
138 |
+
"k": None,
|
139 |
+
"v": None,
|
140 |
+
"y_emb": None,
|
141 |
+
"first_infer": 1,
|
142 |
+
"stage": 0,
|
143 |
+
}
|
144 |
+
|
145 |
+
y_emb = self.ar_audio_embedding(y)
|
146 |
+
|
147 |
+
cache["y_emb"] = y_emb
|
148 |
+
y_pos = self.ar_audio_position(y_emb)
|
149 |
+
|
150 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
151 |
+
|
152 |
+
y_example = y_pos[:, :, 0] * 0.0
|
153 |
+
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
|
154 |
+
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
|
155 |
+
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
|
156 |
+
torch.ones_like(
|
157 |
+
y_example.transpose(0, 1),
|
158 |
+
dtype=torch.int64,
|
159 |
+
),
|
160 |
+
dim=0,
|
161 |
+
)
|
162 |
+
y_attn_mask = y_attn_mask > 0
|
163 |
+
|
164 |
+
x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
|
165 |
+
y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
|
166 |
+
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
|
167 |
+
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
|
168 |
+
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
169 |
+
cache["k"] = (
|
170 |
+
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
171 |
+
.unsqueeze(1)
|
172 |
+
.repeat(self.num_layers, 1, 1, 1)
|
173 |
+
)
|
174 |
+
cache["v"] = (
|
175 |
+
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
176 |
+
.unsqueeze(1)
|
177 |
+
.repeat(self.num_layers, 1, 1, 1)
|
178 |
+
)
|
179 |
+
|
180 |
+
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
181 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
182 |
+
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
|
183 |
+
|
184 |
+
y = torch.concat([y, samples], dim=1)
|
185 |
+
|
186 |
+
return y, cache["k"], cache["v"], cache["y_emb"], x_example
|
187 |
+
|
188 |
+
|
189 |
+
class T2SStageDecoder(nn.Module):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
ar_audio_embedding,
|
193 |
+
ar_audio_position,
|
194 |
+
h,
|
195 |
+
ar_predict_layer,
|
196 |
+
loss_fct,
|
197 |
+
ar_accuracy_metric,
|
198 |
+
top_k,
|
199 |
+
early_stop_num,
|
200 |
+
num_layers,
|
201 |
+
):
|
202 |
+
super().__init__()
|
203 |
+
self.ar_audio_embedding = ar_audio_embedding
|
204 |
+
self.ar_audio_position = ar_audio_position
|
205 |
+
self.h = h
|
206 |
+
self.ar_predict_layer = ar_predict_layer
|
207 |
+
self.loss_fct = loss_fct
|
208 |
+
self.ar_accuracy_metric = ar_accuracy_metric
|
209 |
+
self.top_k = top_k
|
210 |
+
self.early_stop_num = early_stop_num
|
211 |
+
self.num_layers = num_layers
|
212 |
+
|
213 |
+
def forward(self, y, k, v, y_emb, x_example):
|
214 |
+
cache = {
|
215 |
+
"all_stage": self.num_layers,
|
216 |
+
"k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
|
217 |
+
"v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
|
218 |
+
"y_emb": y_emb,
|
219 |
+
"first_infer": 0,
|
220 |
+
"stage": 0,
|
221 |
+
}
|
222 |
+
|
223 |
+
y_emb = torch.cat(
|
224 |
+
[
|
225 |
+
cache["y_emb"],
|
226 |
+
self.ar_audio_embedding(y[:, -1:]),
|
227 |
+
],
|
228 |
+
1,
|
229 |
+
)
|
230 |
+
cache["y_emb"] = y_emb
|
231 |
+
y_pos = self.ar_audio_position(y_emb)
|
232 |
+
|
233 |
+
xy_pos = y_pos[:, -1:]
|
234 |
+
|
235 |
+
y_example = y_pos[:, :, 0] * 0.0
|
236 |
+
|
237 |
+
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
|
238 |
+
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
|
239 |
+
|
240 |
+
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
241 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
242 |
+
samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
|
243 |
+
|
244 |
+
y = torch.concat([y, samples], dim=1)
|
245 |
+
|
246 |
+
return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
|
247 |
+
|
248 |
+
|
249 |
+
class Text2SemanticDecoder(nn.Module):
|
250 |
+
def __init__(self, config, norm_first=False, top_k=3):
|
251 |
+
super(Text2SemanticDecoder, self).__init__()
|
252 |
+
self.model_dim = config["model"]["hidden_dim"]
|
253 |
+
self.embedding_dim = config["model"]["embedding_dim"]
|
254 |
+
self.num_head = config["model"]["head"]
|
255 |
+
self.num_layers = config["model"]["n_layer"]
|
256 |
+
self.norm_first = norm_first
|
257 |
+
self.vocab_size = config["model"]["vocab_size"]
|
258 |
+
self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
|
259 |
+
self.p_dropout = float(config["model"]["dropout"])
|
260 |
+
self.EOS = config["model"]["EOS"]
|
261 |
+
self.norm_first = norm_first
|
262 |
+
assert self.EOS == self.vocab_size - 1
|
263 |
+
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
264 |
+
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
|
265 |
+
self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
|
266 |
+
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
|
267 |
+
self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
|
268 |
+
self.h = TransformerEncoder(
|
269 |
+
TransformerEncoderLayer(
|
270 |
+
d_model=self.model_dim,
|
271 |
+
nhead=self.num_head,
|
272 |
+
dim_feedforward=self.model_dim * 4,
|
273 |
+
dropout=0.1,
|
274 |
+
batch_first=True,
|
275 |
+
norm_first=norm_first,
|
276 |
+
),
|
277 |
+
num_layers=self.num_layers,
|
278 |
+
norm=LayerNorm(self.model_dim) if norm_first else None,
|
279 |
+
)
|
280 |
+
self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
|
281 |
+
self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
|
282 |
+
self.ar_accuracy_metric = MulticlassAccuracy(
|
283 |
+
self.vocab_size,
|
284 |
+
top_k=top_k,
|
285 |
+
average="micro",
|
286 |
+
multidim_average="global",
|
287 |
+
ignore_index=self.EOS,
|
288 |
+
)
|
289 |
+
self.top_k = torch.LongTensor([1])
|
290 |
+
self.early_stop_num = torch.LongTensor([-1])
|
291 |
+
|
292 |
+
def init_onnx(self):
|
293 |
+
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
|
294 |
+
self.first_stage_decoder = T2SFirstStageDecoder(
|
295 |
+
self.ar_audio_embedding,
|
296 |
+
self.ar_audio_position,
|
297 |
+
self.h,
|
298 |
+
self.ar_predict_layer,
|
299 |
+
self.loss_fct,
|
300 |
+
self.ar_accuracy_metric,
|
301 |
+
self.top_k,
|
302 |
+
self.early_stop_num,
|
303 |
+
self.num_layers,
|
304 |
+
)
|
305 |
+
self.stage_decoder = T2SStageDecoder(
|
306 |
+
self.ar_audio_embedding,
|
307 |
+
self.ar_audio_position,
|
308 |
+
self.h,
|
309 |
+
self.ar_predict_layer,
|
310 |
+
self.loss_fct,
|
311 |
+
self.ar_accuracy_metric,
|
312 |
+
self.top_k,
|
313 |
+
self.early_stop_num,
|
314 |
+
self.num_layers,
|
315 |
+
)
|
316 |
+
|
317 |
+
def forward(self, x, prompts, bert_feature):
|
318 |
+
early_stop_num = self.early_stop_num
|
319 |
+
prefix_len = prompts.shape[1]
|
320 |
+
|
321 |
+
x = self.onnx_encoder(x, bert_feature)
|
322 |
+
y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
|
323 |
+
|
324 |
+
stop = False
|
325 |
+
for idx in range(1, 1500):
|
326 |
+
enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
|
327 |
+
y, k, v, y_emb, stage, logits, samples = enco
|
328 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
329 |
+
stop = True
|
330 |
+
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
331 |
+
stop = True
|
332 |
+
if stop:
|
333 |
+
break
|
334 |
+
y[0, -1] = 0
|
335 |
+
return y, idx
|
336 |
+
|
337 |
+
def infer(self, x, prompts, bert_feature):
|
338 |
+
top_k = self.top_k
|
339 |
+
early_stop_num = self.early_stop_num
|
340 |
+
|
341 |
+
x = self.onnx_encoder(x, bert_feature)
|
342 |
+
|
343 |
+
y = prompts
|
344 |
+
prefix_len = y.shape[1]
|
345 |
+
x_len = x.shape[1]
|
346 |
+
x_example = x[:, :, 0] * 0.0
|
347 |
+
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
|
348 |
+
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
|
349 |
+
|
350 |
+
stop = False
|
351 |
+
cache = {
|
352 |
+
"all_stage": self.num_layers,
|
353 |
+
"k": [None] * self.num_layers,
|
354 |
+
"v": [None] * self.num_layers,
|
355 |
+
"y_emb": None,
|
356 |
+
"first_infer": 1,
|
357 |
+
"stage": 0,
|
358 |
+
}
|
359 |
+
for idx in range(1500):
|
360 |
+
if cache["first_infer"] == 1:
|
361 |
+
y_emb = self.ar_audio_embedding(y)
|
362 |
+
else:
|
363 |
+
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
|
364 |
+
cache["y_emb"] = y_emb
|
365 |
+
y_pos = self.ar_audio_position(y_emb)
|
366 |
+
if cache["first_infer"] == 1:
|
367 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
368 |
+
else:
|
369 |
+
xy_pos = y_pos[:, -1:]
|
370 |
+
y_len = y_pos.shape[1]
|
371 |
+
if cache["first_infer"] == 1:
|
372 |
+
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
|
373 |
+
y_attn_mask = F.pad(
|
374 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
375 |
+
(x_len, 0),
|
376 |
+
value=False,
|
377 |
+
)
|
378 |
+
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
379 |
+
else:
|
380 |
+
xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
|
381 |
+
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
382 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
383 |
+
samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
|
384 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
385 |
+
stop = True
|
386 |
+
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
387 |
+
stop = True
|
388 |
+
if stop:
|
389 |
+
if prompts.shape[1] == y.shape[1]:
|
390 |
+
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
391 |
+
break
|
392 |
+
y = torch.concat([y, samples], dim=1)
|
393 |
+
cache["first_infer"] = 0
|
394 |
+
return y, idx
|
GPT_SoVITS/AR/models/utils.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
def sequence_mask(length, max_length=None):
|
10 |
+
if max_length is None:
|
11 |
+
max_length = length.max()
|
12 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
13 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
14 |
+
|
15 |
+
|
16 |
+
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
17 |
+
"""
|
18 |
+
Args:
|
19 |
+
lengths:
|
20 |
+
A 1-D tensor containing sentence lengths.
|
21 |
+
max_len:
|
22 |
+
The length of masks.
|
23 |
+
Returns:
|
24 |
+
Return a 2-D bool tensor, where masked positions
|
25 |
+
are filled with `True` and non-masked positions are
|
26 |
+
filled with `False`.
|
27 |
+
|
28 |
+
#>>> lengths = torch.tensor([1, 3, 2, 5])
|
29 |
+
#>>> make_pad_mask(lengths)
|
30 |
+
tensor([[False, True, True, True, True],
|
31 |
+
[False, False, False, True, True],
|
32 |
+
[False, False, True, True, True],
|
33 |
+
[False, False, False, False, False]])
|
34 |
+
"""
|
35 |
+
assert lengths.ndim == 1, lengths.ndim
|
36 |
+
max_len = max(max_len, lengths.max())
|
37 |
+
n = lengths.size(0)
|
38 |
+
seq_range = torch.arange(0, max_len, device=lengths.device)
|
39 |
+
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
|
40 |
+
|
41 |
+
return expaned_lengths >= lengths.unsqueeze(-1)
|
42 |
+
|
43 |
+
|
44 |
+
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
45 |
+
"""
|
46 |
+
Args:
|
47 |
+
lengths:
|
48 |
+
A 1-D tensor containing sentence lengths.
|
49 |
+
max_len:
|
50 |
+
The length of masks.
|
51 |
+
Returns:
|
52 |
+
Return a 2-D bool tensor, where masked positions
|
53 |
+
are filled with `True` and non-masked positions are
|
54 |
+
filled with `False`.
|
55 |
+
|
56 |
+
#>>> lengths = torch.tensor([1, 3, 2, 5])
|
57 |
+
#>>> make_pad_mask(lengths)
|
58 |
+
tensor(
|
59 |
+
[
|
60 |
+
[True, True, False],
|
61 |
+
[True, False, False],
|
62 |
+
[True, True, False],
|
63 |
+
...
|
64 |
+
]
|
65 |
+
)
|
66 |
+
"""
|
67 |
+
assert lengths.ndim == 1, lengths.ndim
|
68 |
+
max_len = max(max_len, lengths.max())
|
69 |
+
n = lengths.size(0)
|
70 |
+
seq_range = torch.arange(0, max_len, device=lengths.device)
|
71 |
+
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
|
72 |
+
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
|
73 |
+
|
74 |
+
return expaned_lengths < 0
|
75 |
+
|
76 |
+
|
77 |
+
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
78 |
+
def top_k_top_p_filtering(
|
79 |
+
logits,
|
80 |
+
top_k=0,
|
81 |
+
top_p=1.0,
|
82 |
+
filter_value=-float("Inf"),
|
83 |
+
min_tokens_to_keep=1,
|
84 |
+
):
|
85 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
86 |
+
Args:
|
87 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
88 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
89 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
90 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
91 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
92 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
93 |
+
"""
|
94 |
+
if top_k > 0:
|
95 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
96 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
97 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
98 |
+
logits[indices_to_remove] = filter_value
|
99 |
+
|
100 |
+
if top_p < 1.0:
|
101 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
102 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
103 |
+
|
104 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
105 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
106 |
+
if min_tokens_to_keep > 1:
|
107 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
108 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
109 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
110 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
111 |
+
sorted_indices_to_remove[..., 0] = 0
|
112 |
+
|
113 |
+
# scatter sorted tensors to original indexing
|
114 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
115 |
+
logits[indices_to_remove] = filter_value
|
116 |
+
return logits
|
117 |
+
|
118 |
+
|
119 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
120 |
+
# temperature: (`optional`) float
|
121 |
+
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
122 |
+
# top_k: (`optional`) int
|
123 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
124 |
+
# top_p: (`optional`) float
|
125 |
+
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
126 |
+
|
127 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
128 |
+
if temperature != 1.0:
|
129 |
+
logits = logits / temperature
|
130 |
+
# Top-p/top-k filtering
|
131 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
132 |
+
# Sample
|
133 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
134 |
+
return token
|
135 |
+
|
136 |
+
|
137 |
+
from typing import Optional
|
138 |
+
|
139 |
+
|
140 |
+
def multinomial_sample_one_no_sync(
|
141 |
+
probs_sort,
|
142 |
+
): # Does multinomial sampling without a cuda synchronization
|
143 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
144 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
145 |
+
|
146 |
+
|
147 |
+
def logits_to_probs(
|
148 |
+
logits,
|
149 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
150 |
+
temperature: float = 1.0,
|
151 |
+
top_k: Optional[int] = None,
|
152 |
+
top_p: Optional[int] = None,
|
153 |
+
repetition_penalty: float = 1.0,
|
154 |
+
):
|
155 |
+
# if previous_tokens is not None:
|
156 |
+
# previous_tokens = previous_tokens.squeeze()
|
157 |
+
# print(logits.shape,previous_tokens.shape)
|
158 |
+
# pdb.set_trace()
|
159 |
+
if previous_tokens is not None and repetition_penalty != 1.0:
|
160 |
+
previous_tokens = previous_tokens.long()
|
161 |
+
score = torch.gather(logits, dim=1, index=previous_tokens)
|
162 |
+
score = torch.where(
|
163 |
+
score < 0,
|
164 |
+
score * repetition_penalty,
|
165 |
+
score / repetition_penalty,
|
166 |
+
)
|
167 |
+
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
168 |
+
|
169 |
+
if top_p is not None and top_p < 1.0:
|
170 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
171 |
+
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
172 |
+
sorted_indices_to_remove = cum_probs > top_p
|
173 |
+
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
174 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
175 |
+
dim=1,
|
176 |
+
index=sorted_indices,
|
177 |
+
src=sorted_indices_to_remove,
|
178 |
+
)
|
179 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
180 |
+
|
181 |
+
logits = logits / max(temperature, 1e-5)
|
182 |
+
|
183 |
+
if top_k is not None:
|
184 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
185 |
+
pivot = v[:, -1].unsqueeze(-1)
|
186 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
187 |
+
|
188 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
189 |
+
return probs
|
190 |
+
|
191 |
+
|
192 |
+
def sample(
|
193 |
+
logits,
|
194 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
195 |
+
**sampling_kwargs,
|
196 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
197 |
+
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
|
198 |
+
idx_next = multinomial_sample_one_no_sync(probs)
|
199 |
+
return idx_next, probs
|
200 |
+
|
201 |
+
|
202 |
+
def dpo_loss(
|
203 |
+
policy_chosen_logps: torch.FloatTensor,
|
204 |
+
policy_rejected_logps: torch.FloatTensor,
|
205 |
+
reference_chosen_logps: torch.FloatTensor,
|
206 |
+
reference_rejected_logps: torch.FloatTensor,
|
207 |
+
beta: float,
|
208 |
+
reference_free: bool = False,
|
209 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
210 |
+
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
211 |
+
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
212 |
+
|
213 |
+
if reference_free:
|
214 |
+
ref_logratios = 0
|
215 |
+
|
216 |
+
logits = pi_logratios - ref_logratios
|
217 |
+
|
218 |
+
losses = -F.logsigmoid(beta * logits)
|
219 |
+
chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
|
220 |
+
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
|
221 |
+
|
222 |
+
return losses.mean(), chosen_rewards, rejected_rewards
|
223 |
+
|
224 |
+
|
225 |
+
def get_batch_logps(
|
226 |
+
logits_target: torch.FloatTensor,
|
227 |
+
logits_reject: torch.FloatTensor,
|
228 |
+
labels_target: torch.LongTensor,
|
229 |
+
labels_reject: torch.LongTensor,
|
230 |
+
average_log_prob: bool = False,
|
231 |
+
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
232 |
+
# dummy token; we'll ignore the losses on these tokens later
|
233 |
+
|
234 |
+
per_token_logps_target = torch.gather(
|
235 |
+
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
|
236 |
+
).squeeze(2)
|
237 |
+
per_token_logps_reject = torch.gather(
|
238 |
+
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
|
239 |
+
).squeeze(2)
|
240 |
+
|
241 |
+
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
|
242 |
+
|
243 |
+
|
244 |
+
def make_reject_y(y_o, y_lens):
|
245 |
+
def repeat_P(y):
|
246 |
+
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
247 |
+
pre = y[: range_idx[0]]
|
248 |
+
shf = y[range_idx[1] :]
|
249 |
+
range_text = y[range_idx[0] : range_idx[1]]
|
250 |
+
new_y = torch.cat([pre, range_text, range_text, shf])
|
251 |
+
return new_y
|
252 |
+
|
253 |
+
def lost_P(y):
|
254 |
+
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
255 |
+
pre = y[: range_idx[0]]
|
256 |
+
shf = y[range_idx[1] :]
|
257 |
+
range_text = y[range_idx[0] : range_idx[1]]
|
258 |
+
new_y = torch.cat([pre, shf])
|
259 |
+
return new_y
|
260 |
+
|
261 |
+
bs = len(y_lens)
|
262 |
+
reject_y = []
|
263 |
+
reject_y_lens = []
|
264 |
+
for b in range(bs):
|
265 |
+
process_item_idx = torch.randint(0, 1, size=(1,))[0]
|
266 |
+
if process_item_idx == 0:
|
267 |
+
new_y = repeat_P(y_o[b])
|
268 |
+
reject_y.append(new_y)
|
269 |
+
reject_y_lens.append(len(new_y))
|
270 |
+
elif process_item_idx == 1:
|
271 |
+
new_y = lost_P(y_o[b])
|
272 |
+
reject_y.append(new_y)
|
273 |
+
reject_y_lens.append(len(new_y))
|
274 |
+
max_length = max(reject_y_lens)
|
275 |
+
for b in range(bs):
|
276 |
+
pad_length = max_length - reject_y_lens[b]
|
277 |
+
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
|
278 |
+
|
279 |
+
reject_y = torch.stack(reject_y, dim=0)
|
280 |
+
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
|
281 |
+
|
282 |
+
return reject_y, reject_y_lens
|
GPT_SoVITS/AR/modules/__init__.py
ADDED
File without changes
|
GPT_SoVITS/AR/modules/activation.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
from torch.nn import Linear, Module
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
9 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
10 |
+
from torch.nn.parameter import Parameter
|
11 |
+
|
12 |
+
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
13 |
+
|
14 |
+
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
15 |
+
|
16 |
+
|
17 |
+
class MultiheadAttention(Module):
|
18 |
+
r"""Allows the model to jointly attend to information
|
19 |
+
from different representation subspaces as described in the paper:
|
20 |
+
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
21 |
+
|
22 |
+
Multi-Head Attention is defined as:
|
23 |
+
|
24 |
+
.. math::
|
25 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
26 |
+
|
27 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
28 |
+
|
29 |
+
``forward()`` will use a special optimized implementation if all of the following
|
30 |
+
conditions are met:
|
31 |
+
|
32 |
+
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
|
33 |
+
restriction will be loosened in the future.)
|
34 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
35 |
+
- training is disabled (using ``.eval()``)
|
36 |
+
- dropout is 0
|
37 |
+
- ``add_bias_kv`` is ``False``
|
38 |
+
- ``add_zero_attn`` is ``False``
|
39 |
+
- ``batch_first`` is ``True`` and the input is batched
|
40 |
+
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
41 |
+
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
|
42 |
+
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
43 |
+
nor ``attn_mask`` is passed
|
44 |
+
|
45 |
+
If the optimized implementation is in use, a
|
46 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
47 |
+
``query``/``key``/``value`` to represent padding more efficiently than using a
|
48 |
+
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
49 |
+
will be returned, and an additional speedup proportional to the fraction of the input
|
50 |
+
that is padding can be expected.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
embed_dim: Total dimension of the model.
|
54 |
+
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
55 |
+
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
56 |
+
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
57 |
+
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
58 |
+
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
59 |
+
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
60 |
+
Default: ``False``.
|
61 |
+
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
62 |
+
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
63 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
64 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
65 |
+
|
66 |
+
Examples::
|
67 |
+
|
68 |
+
>>> # xdoctest: +SKIP
|
69 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
70 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
71 |
+
|
72 |
+
"""
|
73 |
+
|
74 |
+
__constants__ = ["batch_first"]
|
75 |
+
bias_k: Optional[torch.Tensor]
|
76 |
+
bias_v: Optional[torch.Tensor]
|
77 |
+
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
embed_dim,
|
81 |
+
num_heads,
|
82 |
+
dropout=0.0,
|
83 |
+
bias=True,
|
84 |
+
add_bias_kv=False,
|
85 |
+
add_zero_attn=False,
|
86 |
+
kdim=None,
|
87 |
+
vdim=None,
|
88 |
+
batch_first=False,
|
89 |
+
linear1_cls=Linear,
|
90 |
+
linear2_cls=Linear,
|
91 |
+
device=None,
|
92 |
+
dtype=None,
|
93 |
+
) -> None:
|
94 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
95 |
+
super(MultiheadAttention, self).__init__()
|
96 |
+
self.embed_dim = embed_dim
|
97 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
98 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
99 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
100 |
+
|
101 |
+
self.num_heads = num_heads
|
102 |
+
self.dropout = dropout
|
103 |
+
self.batch_first = batch_first
|
104 |
+
self.head_dim = embed_dim // num_heads
|
105 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
106 |
+
|
107 |
+
if add_bias_kv:
|
108 |
+
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
109 |
+
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
110 |
+
else:
|
111 |
+
self.bias_k = self.bias_v = None
|
112 |
+
|
113 |
+
if linear1_cls == Linear:
|
114 |
+
if not self._qkv_same_embed_dim:
|
115 |
+
self.q_proj_weight = Parameter(
|
116 |
+
torch.empty((embed_dim, embed_dim), **factory_kwargs),
|
117 |
+
)
|
118 |
+
self.k_proj_weight = Parameter(
|
119 |
+
torch.empty((embed_dim, self.kdim), **factory_kwargs),
|
120 |
+
)
|
121 |
+
self.v_proj_weight = Parameter(
|
122 |
+
torch.empty((embed_dim, self.vdim), **factory_kwargs),
|
123 |
+
)
|
124 |
+
self.register_parameter("in_proj_weight", None)
|
125 |
+
else:
|
126 |
+
self.in_proj_weight = Parameter(
|
127 |
+
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
|
128 |
+
)
|
129 |
+
self.register_parameter("q_proj_weight", None)
|
130 |
+
self.register_parameter("k_proj_weight", None)
|
131 |
+
self.register_parameter("v_proj_weight", None)
|
132 |
+
|
133 |
+
if bias:
|
134 |
+
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
135 |
+
else:
|
136 |
+
self.register_parameter("in_proj_bias", None)
|
137 |
+
self.out_proj = NonDynamicallyQuantizableLinear(
|
138 |
+
embed_dim,
|
139 |
+
embed_dim,
|
140 |
+
bias=bias,
|
141 |
+
**factory_kwargs,
|
142 |
+
)
|
143 |
+
|
144 |
+
self._reset_parameters()
|
145 |
+
else:
|
146 |
+
if not self._qkv_same_embed_dim:
|
147 |
+
raise NotImplementedError
|
148 |
+
else:
|
149 |
+
self.in_proj_linear = linear1_cls(
|
150 |
+
embed_dim,
|
151 |
+
3 * embed_dim,
|
152 |
+
bias=bias,
|
153 |
+
**factory_kwargs,
|
154 |
+
)
|
155 |
+
self.in_proj_weight = self.in_proj_linear.weight
|
156 |
+
|
157 |
+
self.register_parameter("q_proj_weight", None)
|
158 |
+
self.register_parameter("k_proj_weight", None)
|
159 |
+
self.register_parameter("v_proj_weight", None)
|
160 |
+
|
161 |
+
if bias:
|
162 |
+
self.in_proj_bias = self.in_proj_linear.bias
|
163 |
+
else:
|
164 |
+
self.register_parameter("in_proj_bias", None)
|
165 |
+
|
166 |
+
self.out_proj = linear2_cls(
|
167 |
+
embed_dim,
|
168 |
+
embed_dim,
|
169 |
+
bias=bias,
|
170 |
+
**factory_kwargs,
|
171 |
+
)
|
172 |
+
|
173 |
+
if self.bias_k is not None:
|
174 |
+
xavier_normal_(self.bias_k)
|
175 |
+
if self.bias_v is not None:
|
176 |
+
xavier_normal_(self.bias_v)
|
177 |
+
|
178 |
+
self.add_zero_attn = add_zero_attn
|
179 |
+
|
180 |
+
def _reset_parameters(self):
|
181 |
+
if self._qkv_same_embed_dim:
|
182 |
+
xavier_uniform_(self.in_proj_weight)
|
183 |
+
else:
|
184 |
+
xavier_uniform_(self.q_proj_weight)
|
185 |
+
xavier_uniform_(self.k_proj_weight)
|
186 |
+
xavier_uniform_(self.v_proj_weight)
|
187 |
+
|
188 |
+
if self.in_proj_bias is not None:
|
189 |
+
constant_(self.in_proj_bias, 0.0)
|
190 |
+
constant_(self.out_proj.bias, 0.0)
|
191 |
+
|
192 |
+
if self.bias_k is not None:
|
193 |
+
xavier_normal_(self.bias_k)
|
194 |
+
if self.bias_v is not None:
|
195 |
+
xavier_normal_(self.bias_v)
|
196 |
+
|
197 |
+
def __setstate__(self, state):
|
198 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
199 |
+
if "_qkv_same_embed_dim" not in state:
|
200 |
+
state["_qkv_same_embed_dim"] = True
|
201 |
+
|
202 |
+
super(MultiheadAttention, self).__setstate__(state)
|
203 |
+
|
204 |
+
def forward(
|
205 |
+
self,
|
206 |
+
query: Tensor,
|
207 |
+
key: Tensor,
|
208 |
+
value: Tensor,
|
209 |
+
key_padding_mask: Optional[Tensor] = None,
|
210 |
+
need_weights: bool = True,
|
211 |
+
attn_mask: Optional[Tensor] = None,
|
212 |
+
average_attn_weights: bool = True,
|
213 |
+
cache=None,
|
214 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
215 |
+
r"""
|
216 |
+
Args:
|
217 |
+
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
218 |
+
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
219 |
+
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
220 |
+
Queries are compared against key-value pairs to produce the output.
|
221 |
+
See "Attention Is All You Need" for more details.
|
222 |
+
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
223 |
+
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
224 |
+
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
225 |
+
See "Attention Is All You Need" for more details.
|
226 |
+
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
227 |
+
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
228 |
+
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
229 |
+
See "Attention Is All You Need" for more details.
|
230 |
+
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
231 |
+
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
232 |
+
Binary and byte masks are supported.
|
233 |
+
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
234 |
+
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
235 |
+
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
236 |
+
Default: ``True``.
|
237 |
+
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
238 |
+
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
239 |
+
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
240 |
+
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
241 |
+
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
242 |
+
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
|
243 |
+
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
244 |
+
the attention weight.
|
245 |
+
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
246 |
+
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
247 |
+
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
248 |
+
|
249 |
+
Outputs:
|
250 |
+
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
251 |
+
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
252 |
+
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
253 |
+
embedding dimension ``embed_dim``.
|
254 |
+
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
255 |
+
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
256 |
+
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
257 |
+
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
258 |
+
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
259 |
+
|
260 |
+
.. note::
|
261 |
+
`batch_first` argument is ignored for unbatched inputs.
|
262 |
+
"""
|
263 |
+
is_batched = query.dim() == 3
|
264 |
+
if key_padding_mask is not None:
|
265 |
+
_kpm_dtype = key_padding_mask.dtype
|
266 |
+
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
267 |
+
key_padding_mask,
|
268 |
+
):
|
269 |
+
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
270 |
+
why_not_fast_path = ""
|
271 |
+
if not is_batched:
|
272 |
+
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
273 |
+
elif query is not key or key is not value:
|
274 |
+
# When lifting this restriction, don't forget to either
|
275 |
+
# enforce that the dtypes all match or test cases where
|
276 |
+
# they don't!
|
277 |
+
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
278 |
+
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
279 |
+
why_not_fast_path = (
|
280 |
+
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
281 |
+
)
|
282 |
+
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
|
283 |
+
# this case will fail anyway, but at least they'll get a useful error message.
|
284 |
+
why_not_fast_path = (
|
285 |
+
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
286 |
+
)
|
287 |
+
elif self.training:
|
288 |
+
why_not_fast_path = "training is enabled"
|
289 |
+
elif not self.batch_first:
|
290 |
+
why_not_fast_path = "batch_first was not True"
|
291 |
+
elif self.bias_k is not None:
|
292 |
+
why_not_fast_path = "self.bias_k was not None"
|
293 |
+
elif self.bias_v is not None:
|
294 |
+
why_not_fast_path = "self.bias_v was not None"
|
295 |
+
elif self.dropout:
|
296 |
+
why_not_fast_path = f"dropout was {self.dropout}, required zero"
|
297 |
+
elif self.add_zero_attn:
|
298 |
+
why_not_fast_path = "add_zero_attn was enabled"
|
299 |
+
elif not self._qkv_same_embed_dim:
|
300 |
+
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
301 |
+
elif attn_mask is not None:
|
302 |
+
why_not_fast_path = "attn_mask was not None"
|
303 |
+
elif query.is_nested and key_padding_mask is not None:
|
304 |
+
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
|
305 |
+
elif self.num_heads % 2 == 1:
|
306 |
+
why_not_fast_path = "num_heads is odd"
|
307 |
+
elif torch.is_autocast_enabled():
|
308 |
+
why_not_fast_path = "autocast is enabled"
|
309 |
+
|
310 |
+
if not why_not_fast_path:
|
311 |
+
tensor_args = (
|
312 |
+
query,
|
313 |
+
key,
|
314 |
+
value,
|
315 |
+
self.in_proj_weight,
|
316 |
+
self.in_proj_bias,
|
317 |
+
self.out_proj.weight,
|
318 |
+
self.out_proj.bias,
|
319 |
+
)
|
320 |
+
# We have to use list comprehensions below because TorchScript does not support
|
321 |
+
# generator expressions.
|
322 |
+
if torch.overrides.has_torch_function(tensor_args):
|
323 |
+
why_not_fast_path = "some Tensor argument has_torch_function"
|
324 |
+
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
|
325 |
+
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
|
326 |
+
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
|
327 |
+
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
|
328 |
+
if not why_not_fast_path:
|
329 |
+
return torch._native_multi_head_attention(
|
330 |
+
query,
|
331 |
+
key,
|
332 |
+
value,
|
333 |
+
self.embed_dim,
|
334 |
+
self.num_heads,
|
335 |
+
self.in_proj_weight,
|
336 |
+
self.in_proj_bias,
|
337 |
+
self.out_proj.weight,
|
338 |
+
self.out_proj.bias,
|
339 |
+
key_padding_mask if key_padding_mask is not None else attn_mask,
|
340 |
+
need_weights,
|
341 |
+
average_attn_weights,
|
342 |
+
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
|
343 |
+
)
|
344 |
+
|
345 |
+
any_nested = query.is_nested or key.is_nested or value.is_nested
|
346 |
+
assert not any_nested, (
|
347 |
+
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
348 |
+
+ f"The fast path was not hit because {why_not_fast_path}"
|
349 |
+
)
|
350 |
+
|
351 |
+
if self.batch_first and is_batched:
|
352 |
+
# make sure that the transpose op does not affect the "is" property
|
353 |
+
if key is value:
|
354 |
+
if query is key:
|
355 |
+
query = key = value = query.transpose(1, 0)
|
356 |
+
else:
|
357 |
+
query, key = [x.transpose(1, 0) for x in (query, key)]
|
358 |
+
value = key
|
359 |
+
else:
|
360 |
+
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
|
361 |
+
|
362 |
+
if not self._qkv_same_embed_dim:
|
363 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
364 |
+
query,
|
365 |
+
key,
|
366 |
+
value,
|
367 |
+
self.embed_dim,
|
368 |
+
self.num_heads,
|
369 |
+
self.in_proj_weight,
|
370 |
+
self.in_proj_bias,
|
371 |
+
self.bias_k,
|
372 |
+
self.bias_v,
|
373 |
+
self.add_zero_attn,
|
374 |
+
self.dropout,
|
375 |
+
self.out_proj.weight,
|
376 |
+
self.out_proj.bias,
|
377 |
+
training=self.training,
|
378 |
+
key_padding_mask=key_padding_mask,
|
379 |
+
need_weights=need_weights,
|
380 |
+
attn_mask=attn_mask,
|
381 |
+
use_separate_proj_weight=True,
|
382 |
+
q_proj_weight=self.q_proj_weight,
|
383 |
+
k_proj_weight=self.k_proj_weight,
|
384 |
+
v_proj_weight=self.v_proj_weight,
|
385 |
+
average_attn_weights=average_attn_weights,
|
386 |
+
cache=cache,
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
390 |
+
query,
|
391 |
+
key,
|
392 |
+
value,
|
393 |
+
self.embed_dim,
|
394 |
+
self.num_heads,
|
395 |
+
self.in_proj_weight,
|
396 |
+
self.in_proj_bias,
|
397 |
+
self.bias_k,
|
398 |
+
self.bias_v,
|
399 |
+
self.add_zero_attn,
|
400 |
+
self.dropout,
|
401 |
+
self.out_proj.weight,
|
402 |
+
self.out_proj.bias,
|
403 |
+
training=self.training,
|
404 |
+
key_padding_mask=key_padding_mask,
|
405 |
+
need_weights=need_weights,
|
406 |
+
attn_mask=attn_mask,
|
407 |
+
average_attn_weights=average_attn_weights,
|
408 |
+
cache=cache,
|
409 |
+
)
|
410 |
+
if self.batch_first and is_batched:
|
411 |
+
return attn_output.transpose(1, 0), attn_output_weights
|
412 |
+
else:
|
413 |
+
return attn_output, attn_output_weights
|
GPT_SoVITS/AR/modules/activation_onnx.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
from torch.nn import Linear, Module
|
7 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
8 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
9 |
+
from torch.nn.parameter import Parameter
|
10 |
+
|
11 |
+
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
12 |
+
|
13 |
+
|
14 |
+
class MultiheadAttention(Module):
|
15 |
+
__constants__ = ["batch_first"]
|
16 |
+
bias_k: Optional[torch.Tensor]
|
17 |
+
bias_v: Optional[torch.Tensor]
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
embed_dim,
|
22 |
+
num_heads,
|
23 |
+
dropout=0.0,
|
24 |
+
bias=True,
|
25 |
+
add_bias_kv=False,
|
26 |
+
add_zero_attn=False,
|
27 |
+
kdim=None,
|
28 |
+
vdim=None,
|
29 |
+
batch_first=False,
|
30 |
+
linear1_cls=Linear,
|
31 |
+
linear2_cls=Linear,
|
32 |
+
device=None,
|
33 |
+
dtype=None,
|
34 |
+
) -> None:
|
35 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
36 |
+
super(MultiheadAttention, self).__init__()
|
37 |
+
self.embed_dim = embed_dim
|
38 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
39 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
40 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
41 |
+
|
42 |
+
self.num_heads = num_heads
|
43 |
+
self.dropout = dropout
|
44 |
+
self.batch_first = batch_first
|
45 |
+
self.head_dim = embed_dim // num_heads
|
46 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
47 |
+
|
48 |
+
if add_bias_kv:
|
49 |
+
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
50 |
+
self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
51 |
+
else:
|
52 |
+
self.bias_k = self.bias_v = None
|
53 |
+
|
54 |
+
if linear1_cls == Linear:
|
55 |
+
if not self._qkv_same_embed_dim:
|
56 |
+
self.q_proj_weight = Parameter(
|
57 |
+
torch.empty(
|
58 |
+
(embed_dim, embed_dim),
|
59 |
+
**factory_kwargs,
|
60 |
+
)
|
61 |
+
)
|
62 |
+
self.k_proj_weight = Parameter(
|
63 |
+
torch.empty(
|
64 |
+
(embed_dim, self.kdim),
|
65 |
+
**factory_kwargs,
|
66 |
+
)
|
67 |
+
)
|
68 |
+
self.v_proj_weight = Parameter(
|
69 |
+
torch.empty(
|
70 |
+
(embed_dim, self.vdim),
|
71 |
+
**factory_kwargs,
|
72 |
+
)
|
73 |
+
)
|
74 |
+
self.register_parameter("in_proj_weight", None)
|
75 |
+
else:
|
76 |
+
self.in_proj_weight = Parameter(
|
77 |
+
torch.empty(
|
78 |
+
(3 * embed_dim, embed_dim),
|
79 |
+
**factory_kwargs,
|
80 |
+
)
|
81 |
+
)
|
82 |
+
self.register_parameter("q_proj_weight", None)
|
83 |
+
self.register_parameter("k_proj_weight", None)
|
84 |
+
self.register_parameter("v_proj_weight", None)
|
85 |
+
|
86 |
+
if bias:
|
87 |
+
self.in_proj_bias = Parameter(
|
88 |
+
torch.empty(3 * embed_dim, **factory_kwargs),
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
self.register_parameter("in_proj_bias", None)
|
92 |
+
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
93 |
+
|
94 |
+
self._reset_parameters()
|
95 |
+
else:
|
96 |
+
if not self._qkv_same_embed_dim:
|
97 |
+
raise NotImplementedError
|
98 |
+
else:
|
99 |
+
self.in_proj_linear = linear1_cls(
|
100 |
+
embed_dim,
|
101 |
+
3 * embed_dim,
|
102 |
+
bias=bias,
|
103 |
+
**factory_kwargs,
|
104 |
+
)
|
105 |
+
self.in_proj_weight = self.in_proj_linear.weight
|
106 |
+
|
107 |
+
self.register_parameter("q_proj_weight", None)
|
108 |
+
self.register_parameter("k_proj_weight", None)
|
109 |
+
self.register_parameter("v_proj_weight", None)
|
110 |
+
|
111 |
+
if bias:
|
112 |
+
self.in_proj_bias = self.in_proj_linear.bias
|
113 |
+
else:
|
114 |
+
self.register_parameter("in_proj_bias", None)
|
115 |
+
|
116 |
+
self.out_proj = linear2_cls(
|
117 |
+
embed_dim,
|
118 |
+
embed_dim,
|
119 |
+
bias=bias,
|
120 |
+
**factory_kwargs,
|
121 |
+
)
|
122 |
+
|
123 |
+
if self.bias_k is not None:
|
124 |
+
xavier_normal_(self.bias_k)
|
125 |
+
if self.bias_v is not None:
|
126 |
+
xavier_normal_(self.bias_v)
|
127 |
+
|
128 |
+
self.add_zero_attn = add_zero_attn
|
129 |
+
|
130 |
+
def _reset_parameters(self):
|
131 |
+
if self._qkv_same_embed_dim:
|
132 |
+
xavier_uniform_(self.in_proj_weight)
|
133 |
+
else:
|
134 |
+
xavier_uniform_(self.q_proj_weight)
|
135 |
+
xavier_uniform_(self.k_proj_weight)
|
136 |
+
xavier_uniform_(self.v_proj_weight)
|
137 |
+
|
138 |
+
if self.in_proj_bias is not None:
|
139 |
+
constant_(self.in_proj_bias, 0.0)
|
140 |
+
constant_(self.out_proj.bias, 0.0)
|
141 |
+
|
142 |
+
if self.bias_k is not None:
|
143 |
+
xavier_normal_(self.bias_k)
|
144 |
+
if self.bias_v is not None:
|
145 |
+
xavier_normal_(self.bias_v)
|
146 |
+
|
147 |
+
def __setstate__(self, state):
|
148 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
149 |
+
if "_qkv_same_embed_dim" not in state:
|
150 |
+
state["_qkv_same_embed_dim"] = True
|
151 |
+
|
152 |
+
super(MultiheadAttention, self).__setstate__(state)
|
153 |
+
|
154 |
+
def forward(
|
155 |
+
self,
|
156 |
+
query: Tensor,
|
157 |
+
key: Tensor,
|
158 |
+
value: Tensor,
|
159 |
+
key_padding_mask: Optional[Tensor] = None,
|
160 |
+
need_weights: bool = True,
|
161 |
+
attn_mask: Optional[Tensor] = None,
|
162 |
+
average_attn_weights: bool = True,
|
163 |
+
cache=None,
|
164 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
165 |
+
any_nested = query.is_nested or key.is_nested or value.is_nested
|
166 |
+
query = key = value = query.transpose(1, 0)
|
167 |
+
attn_output = multi_head_attention_forward_patched(
|
168 |
+
query,
|
169 |
+
key,
|
170 |
+
value,
|
171 |
+
self.embed_dim,
|
172 |
+
self.num_heads,
|
173 |
+
self.in_proj_weight,
|
174 |
+
self.in_proj_bias,
|
175 |
+
self.bias_k,
|
176 |
+
self.bias_v,
|
177 |
+
self.add_zero_attn,
|
178 |
+
self.dropout,
|
179 |
+
self.out_proj.weight,
|
180 |
+
self.out_proj.bias,
|
181 |
+
training=self.training,
|
182 |
+
key_padding_mask=key_padding_mask,
|
183 |
+
need_weights=need_weights,
|
184 |
+
attn_mask=attn_mask,
|
185 |
+
average_attn_weights=average_attn_weights,
|
186 |
+
cache=cache,
|
187 |
+
)
|
188 |
+
return attn_output.transpose(1, 0)
|
GPT_SoVITS/AR/modules/embedding.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
class TokenEmbedding(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
embedding_dim: int,
|
12 |
+
vocab_size: int,
|
13 |
+
dropout: float = 0.0,
|
14 |
+
):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
self.vocab_size = vocab_size
|
18 |
+
self.embedding_dim = embedding_dim
|
19 |
+
|
20 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
21 |
+
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
22 |
+
|
23 |
+
@property
|
24 |
+
def weight(self) -> torch.Tensor:
|
25 |
+
return self.word_embeddings.weight
|
26 |
+
|
27 |
+
def embedding(self, index: int) -> torch.Tensor:
|
28 |
+
return self.word_embeddings.weight[index : index + 1]
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor):
|
31 |
+
x = self.word_embeddings(x)
|
32 |
+
x = self.dropout(x)
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
class SinePositionalEmbedding(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
embedding_dim: int,
|
40 |
+
dropout: float = 0.0,
|
41 |
+
scale: bool = False,
|
42 |
+
alpha: bool = False,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self.embedding_dim = embedding_dim
|
46 |
+
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
47 |
+
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
48 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
49 |
+
|
50 |
+
self.reverse = False
|
51 |
+
self.pe = None
|
52 |
+
self.extend_pe(torch.tensor(0.0).expand(1, 4000))
|
53 |
+
|
54 |
+
def extend_pe(self, x):
|
55 |
+
"""Reset the positional encodings."""
|
56 |
+
if self.pe is not None:
|
57 |
+
if self.pe.size(1) >= x.size(1):
|
58 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
59 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
60 |
+
return
|
61 |
+
pe = torch.zeros(x.size(1), self.embedding_dim)
|
62 |
+
if self.reverse:
|
63 |
+
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
64 |
+
else:
|
65 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
66 |
+
div_term = torch.exp(
|
67 |
+
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
|
68 |
+
)
|
69 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
70 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
71 |
+
pe = pe.unsqueeze(0)
|
72 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
|
73 |
+
|
74 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
75 |
+
self.extend_pe(x)
|
76 |
+
output = x.unsqueeze(-1) if x.ndim == 2 else x
|
77 |
+
output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
|
78 |
+
return self.dropout(output)
|
GPT_SoVITS/AR/modules/embedding_onnx.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
class TokenEmbedding(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
embedding_dim: int,
|
12 |
+
vocab_size: int,
|
13 |
+
dropout: float = 0.0,
|
14 |
+
):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
self.vocab_size = vocab_size
|
18 |
+
self.embedding_dim = embedding_dim
|
19 |
+
|
20 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
21 |
+
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
22 |
+
|
23 |
+
@property
|
24 |
+
def weight(self) -> torch.Tensor:
|
25 |
+
return self.word_embeddings.weight
|
26 |
+
|
27 |
+
def embedding(self, index: int) -> torch.Tensor:
|
28 |
+
return self.word_embeddings.weight[index : index + 1]
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor):
|
31 |
+
x = self.word_embeddings(x)
|
32 |
+
x = self.dropout(x)
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
class SinePositionalEmbedding(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
embedding_dim: int,
|
40 |
+
dropout: float = 0.0,
|
41 |
+
scale: bool = False,
|
42 |
+
alpha: bool = False,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
self.embedding_dim = embedding_dim
|
46 |
+
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
47 |
+
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
48 |
+
self.dropout = torch.nn.Dropout(p=dropout)
|
49 |
+
self.reverse = False
|
50 |
+
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
|
51 |
+
|
52 |
+
def extend_pe(self, x):
|
53 |
+
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
|
54 |
+
scpe = (position * self.div_term).unsqueeze(0)
|
55 |
+
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
|
56 |
+
pe = pe.contiguous().view(1, -1, self.embedding_dim)
|
57 |
+
return pe
|
58 |
+
|
59 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
60 |
+
pe = self.extend_pe(x)
|
61 |
+
output = x.unsqueeze(-1) if x.ndim == 2 else x
|
62 |
+
output = output * self.x_scale + self.alpha * pe
|
63 |
+
return self.dropout(output)
|
GPT_SoVITS/AR/modules/lr_schedulers.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from matplotlib import pyplot as plt
|
7 |
+
from torch import nn
|
8 |
+
from torch.optim import Adam
|
9 |
+
|
10 |
+
|
11 |
+
class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
|
12 |
+
"""
|
13 |
+
Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
optimizer,
|
19 |
+
init_lr,
|
20 |
+
peak_lr,
|
21 |
+
end_lr,
|
22 |
+
warmup_steps=10000,
|
23 |
+
total_steps=400000,
|
24 |
+
current_step=0,
|
25 |
+
):
|
26 |
+
self.init_lr = init_lr
|
27 |
+
self.peak_lr = peak_lr
|
28 |
+
self.end_lr = end_lr
|
29 |
+
self.optimizer = optimizer
|
30 |
+
self._warmup_rate = (peak_lr - init_lr) / warmup_steps
|
31 |
+
self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
|
32 |
+
self._current_step = current_step
|
33 |
+
self.lr = init_lr
|
34 |
+
self.warmup_steps = warmup_steps
|
35 |
+
self.total_steps = total_steps
|
36 |
+
self._last_lr = [self.lr]
|
37 |
+
|
38 |
+
def set_lr(self, lr):
|
39 |
+
self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
|
40 |
+
for g in self.optimizer.param_groups:
|
41 |
+
# g['lr'] = lr
|
42 |
+
g["lr"] = self.end_lr ###锁定用线性
|
43 |
+
|
44 |
+
def step(self):
|
45 |
+
if self._current_step < self.warmup_steps:
|
46 |
+
lr = self.init_lr + self._warmup_rate * self._current_step
|
47 |
+
|
48 |
+
elif self._current_step > self.total_steps:
|
49 |
+
lr = self.end_lr
|
50 |
+
|
51 |
+
else:
|
52 |
+
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
53 |
+
if decay_ratio < 0.0 or decay_ratio > 1.0:
|
54 |
+
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
|
55 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
56 |
+
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
|
57 |
+
|
58 |
+
self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定!
|
59 |
+
self.set_lr(lr)
|
60 |
+
self.lr = lr
|
61 |
+
self._current_step += 1
|
62 |
+
return self.lr
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
m = nn.Linear(10, 10)
|
67 |
+
opt = Adam(m.parameters(), lr=1e-4)
|
68 |
+
s = WarmupCosineLRSchedule(
|
69 |
+
opt,
|
70 |
+
1e-6,
|
71 |
+
2e-4,
|
72 |
+
1e-6,
|
73 |
+
warmup_steps=2000,
|
74 |
+
total_steps=20000,
|
75 |
+
current_step=0,
|
76 |
+
)
|
77 |
+
lrs = []
|
78 |
+
for i in range(25000):
|
79 |
+
s.step()
|
80 |
+
lrs.append(s.lr)
|
81 |
+
print(s.lr)
|
82 |
+
|
83 |
+
plt.plot(lrs)
|
84 |
+
plt.plot(range(0, 25000), lrs)
|
85 |
+
plt.show()
|
GPT_SoVITS/AR/modules/optim.py
ADDED
@@ -0,0 +1,593 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import contextlib
|
17 |
+
import logging
|
18 |
+
from collections import defaultdict
|
19 |
+
from typing import List, Tuple
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from torch import Tensor
|
23 |
+
from torch.optim import Optimizer
|
24 |
+
|
25 |
+
|
26 |
+
class BatchedOptimizer(Optimizer):
|
27 |
+
"""
|
28 |
+
This class adds to class Optimizer the capability to optimize parameters in batches:
|
29 |
+
it will stack the parameters and their grads for you so the optimizer can work
|
30 |
+
on tensors with an extra leading dimension. This is intended for speed with GPUs,
|
31 |
+
as it reduces the number of kernels launched in the optimizer.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
params:
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, params, defaults):
|
38 |
+
super(BatchedOptimizer, self).__init__(params, defaults)
|
39 |
+
|
40 |
+
@contextlib.contextmanager
|
41 |
+
def batched_params(self, param_group, group_params_names):
|
42 |
+
"""
|
43 |
+
This function returns (technically, yields) a list of
|
44 |
+
of tuples (p, state), where
|
45 |
+
p is a `fake` parameter that is stacked (over axis 0) from real parameters
|
46 |
+
that share the same shape, and its gradient is also stacked;
|
47 |
+
`state` is the state corresponding to this batch of parameters
|
48 |
+
(it will be physically located in the "state" for one of the real
|
49 |
+
parameters, the last one that has any particular shape and dtype).
|
50 |
+
|
51 |
+
This function is decorated as a context manager so that it can
|
52 |
+
write parameters back to their "real" locations.
|
53 |
+
|
54 |
+
The idea is, instead of doing:
|
55 |
+
<code>
|
56 |
+
for p in group["params"]:
|
57 |
+
state = self.state[p]
|
58 |
+
...
|
59 |
+
</code>
|
60 |
+
you can do:
|
61 |
+
<code>
|
62 |
+
with self.batched_params(group["params"]) as batches:
|
63 |
+
for p, state, p_names in batches:
|
64 |
+
...
|
65 |
+
</code>
|
66 |
+
|
67 |
+
Args:
|
68 |
+
group: a parameter group, which is a list of parameters; should be
|
69 |
+
one of self.param_groups.
|
70 |
+
group_params_names: name for each parameter in group,
|
71 |
+
which is List[str].
|
72 |
+
"""
|
73 |
+
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
74 |
+
batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
75 |
+
|
76 |
+
assert len(param_group) == len(group_params_names)
|
77 |
+
for p, named_p in zip(param_group, group_params_names):
|
78 |
+
key = (str(p.dtype), *p.shape)
|
79 |
+
batches[key].append(p)
|
80 |
+
batches_names[key].append(named_p)
|
81 |
+
|
82 |
+
batches_names_keys = list(batches_names.keys())
|
83 |
+
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
84 |
+
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
|
85 |
+
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
86 |
+
|
87 |
+
stacked_params_dict = dict()
|
88 |
+
|
89 |
+
# turn batches into a list, in deterministic order.
|
90 |
+
# tuples will contain tuples of (stacked_param, state, stacked_params_names),
|
91 |
+
# one for each batch in `batches`.
|
92 |
+
tuples = []
|
93 |
+
|
94 |
+
for batch, batch_names in zip(batches, batches_names):
|
95 |
+
p = batch[0]
|
96 |
+
# we arbitrarily store the state in the
|
97 |
+
# state corresponding to the 1st parameter in the
|
98 |
+
# group. class Optimizer will take care of saving/loading state.
|
99 |
+
state = self.state[p]
|
100 |
+
p_stacked = torch.stack(batch)
|
101 |
+
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
|
102 |
+
p_stacked.grad = grad
|
103 |
+
stacked_params_dict[key] = p_stacked
|
104 |
+
tuples.append((p_stacked, state, batch_names))
|
105 |
+
|
106 |
+
yield tuples # <-- calling code will do the actual optimization here!
|
107 |
+
|
108 |
+
for (stacked_params, _state, _names), batch in zip(tuples, batches):
|
109 |
+
for i, p in enumerate(batch): # batch is list of Parameter
|
110 |
+
p.copy_(stacked_params[i])
|
111 |
+
|
112 |
+
|
113 |
+
class ScaledAdam(BatchedOptimizer):
|
114 |
+
"""
|
115 |
+
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
116 |
+
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
117 |
+
in log space, subject to upper and lower limits (as if we had factored each parameter as
|
118 |
+
param = underlying_param * log_scale.exp())
|
119 |
+
|
120 |
+
|
121 |
+
Args:
|
122 |
+
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
123 |
+
lr: The learning rate. We will typically use a learning rate schedule that starts
|
124 |
+
at 0.03 and decreases over time, i.e. much higher than other common
|
125 |
+
optimizers.
|
126 |
+
clipping_scale: (e.g. 2.0)
|
127 |
+
A scale for gradient-clipping: if specified, the normalized gradients
|
128 |
+
over the whole model will be clipped to have 2-norm equal to
|
129 |
+
`clipping_scale` times the median 2-norm over the most recent period
|
130 |
+
of `clipping_update_period` minibatches. By "normalized gradients",
|
131 |
+
we mean after multiplying by the rms parameter value for this tensor
|
132 |
+
[for non-scalars]; this is appropriate because our update is scaled
|
133 |
+
by this quantity.
|
134 |
+
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
|
135 |
+
Must satisfy 0 < beta <= beta2 < 1.
|
136 |
+
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
|
137 |
+
scale of each parameter tensor and scalar parameters of the mode..
|
138 |
+
If each parameter were decomposed
|
139 |
+
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
|
140 |
+
would be a the scaling factor on the learning rate of p_scale.
|
141 |
+
eps: A general-purpose epsilon to prevent division by zero
|
142 |
+
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
|
143 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
144 |
+
parameter tensor to be >= this value)
|
145 |
+
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
146 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
147 |
+
parameter tensor to be <= this value)
|
148 |
+
scalar_max: Maximum absolute value for scalar parameters (applicable if your
|
149 |
+
model has any parameters with numel() == 1).
|
150 |
+
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
151 |
+
of the parameter tensor. This is provided to save a little time
|
152 |
+
in the update.
|
153 |
+
clipping_update_period: if clipping_scale is specified, this is the period
|
154 |
+
"""
|
155 |
+
|
156 |
+
def __init__(
|
157 |
+
self,
|
158 |
+
params,
|
159 |
+
lr=3e-02,
|
160 |
+
clipping_scale=None,
|
161 |
+
betas=(0.9, 0.98),
|
162 |
+
scalar_lr_scale=0.1,
|
163 |
+
eps=1.0e-08,
|
164 |
+
param_min_rms=1.0e-05,
|
165 |
+
param_max_rms=3.0,
|
166 |
+
scalar_max=10.0,
|
167 |
+
size_update_period=4,
|
168 |
+
clipping_update_period=100,
|
169 |
+
parameters_names=None,
|
170 |
+
show_dominant_parameters=True,
|
171 |
+
):
|
172 |
+
assert parameters_names is not None, (
|
173 |
+
"Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
|
174 |
+
)
|
175 |
+
defaults = dict(
|
176 |
+
lr=lr,
|
177 |
+
clipping_scale=clipping_scale,
|
178 |
+
betas=betas,
|
179 |
+
scalar_lr_scale=scalar_lr_scale,
|
180 |
+
eps=eps,
|
181 |
+
param_min_rms=param_min_rms,
|
182 |
+
param_max_rms=param_max_rms,
|
183 |
+
scalar_max=scalar_max,
|
184 |
+
size_update_period=size_update_period,
|
185 |
+
clipping_update_period=clipping_update_period,
|
186 |
+
)
|
187 |
+
|
188 |
+
super(ScaledAdam, self).__init__(params, defaults)
|
189 |
+
assert len(self.param_groups) == len(parameters_names)
|
190 |
+
self.parameters_names = parameters_names
|
191 |
+
self.show_dominant_parameters = show_dominant_parameters
|
192 |
+
|
193 |
+
def __setstate__(self, state):
|
194 |
+
super(ScaledAdam, self).__setstate__(state)
|
195 |
+
|
196 |
+
@torch.no_grad()
|
197 |
+
def step(self, closure=None):
|
198 |
+
"""Performs a single optimization step.
|
199 |
+
|
200 |
+
Arguments:
|
201 |
+
closure (callable, optional): A closure that reevaluates the model
|
202 |
+
and returns the loss.
|
203 |
+
"""
|
204 |
+
loss = None
|
205 |
+
if closure is not None:
|
206 |
+
with torch.enable_grad():
|
207 |
+
loss = closure()
|
208 |
+
|
209 |
+
batch = True
|
210 |
+
|
211 |
+
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
212 |
+
with self.batched_params(group["params"], group_params_names) as batches:
|
213 |
+
# batches is list of pairs (stacked_param, state). stacked_param is like
|
214 |
+
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
215 |
+
# a stacking dim, it is not a real dim.
|
216 |
+
|
217 |
+
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
218 |
+
clipping_scale = 1
|
219 |
+
else:
|
220 |
+
clipping_scale = self._get_clipping_scale(group, batches)
|
221 |
+
|
222 |
+
for p, state, _ in batches:
|
223 |
+
# Perform optimization step.
|
224 |
+
# grad is not going to be None, we handled that when creating the batches.
|
225 |
+
grad = p.grad
|
226 |
+
if grad.is_sparse:
|
227 |
+
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
228 |
+
# State initialization
|
229 |
+
if len(state) == 0:
|
230 |
+
self._init_state(group, p, state)
|
231 |
+
|
232 |
+
self._step_one_batch(group, p, state, clipping_scale)
|
233 |
+
|
234 |
+
return loss
|
235 |
+
|
236 |
+
def _init_state(self, group: dict, p: Tensor, state: dict):
|
237 |
+
"""
|
238 |
+
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
239 |
+
is actually the batch dimension, corresponding to batched-together
|
240 |
+
parameters of a given shape.
|
241 |
+
|
242 |
+
|
243 |
+
Args:
|
244 |
+
group: Dict to look up configuration values.
|
245 |
+
p: The parameter that we are initializing the state for
|
246 |
+
state: Dict from string to whatever state we are initializing
|
247 |
+
"""
|
248 |
+
size_update_period = group["size_update_period"]
|
249 |
+
|
250 |
+
state["step"] = 0
|
251 |
+
|
252 |
+
kwargs = {"device": p.device, "dtype": p.dtype}
|
253 |
+
|
254 |
+
# 'delta' implements conventional momentum. There are
|
255 |
+
# several different kinds of update going on, so rather than
|
256 |
+
# compute "exp_avg" like in Adam, we store and decay a
|
257 |
+
# parameter-change "delta", which combines all forms of
|
258 |
+
# update. this is equivalent to how it's done in Adam,
|
259 |
+
# except for the first few steps.
|
260 |
+
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
261 |
+
|
262 |
+
batch_size = p.shape[0]
|
263 |
+
numel = p.numel() // batch_size
|
264 |
+
numel = p.numel()
|
265 |
+
|
266 |
+
if numel > 1:
|
267 |
+
# "param_rms" just periodically records the scalar root-mean-square value of
|
268 |
+
# the parameter tensor.
|
269 |
+
# it has a shape like (batch_size, 1, 1, 1, 1)
|
270 |
+
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
271 |
+
state["param_rms"] = param_rms
|
272 |
+
|
273 |
+
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
274 |
+
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
|
275 |
+
|
276 |
+
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
277 |
+
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
278 |
+
|
279 |
+
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
|
280 |
+
"""
|
281 |
+
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
282 |
+
by this amount before applying the rest of the update.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
group: the parameter group, an item in self.param_groups
|
286 |
+
tuples: a list of tuples of (param, state, param_names)
|
287 |
+
where param is a batched set of parameters,
|
288 |
+
with a .grad (1st dim is batch dim)
|
289 |
+
and state is the state-dict where optimization parameters are kept.
|
290 |
+
param_names is a List[str] while each str is name for a parameter
|
291 |
+
in batched set of parameters "param".
|
292 |
+
"""
|
293 |
+
assert len(tuples) >= 1
|
294 |
+
clipping_scale = group["clipping_scale"]
|
295 |
+
(first_p, first_state, _) = tuples[0]
|
296 |
+
step = first_state["step"]
|
297 |
+
if clipping_scale is None or step == 0:
|
298 |
+
# no clipping. return early on step == 0 because the other
|
299 |
+
# parameters' state won't have been initialized yet.
|
300 |
+
return 1.0
|
301 |
+
clipping_update_period = group["clipping_update_period"]
|
302 |
+
|
303 |
+
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
304 |
+
for p, state, param_names in tuples:
|
305 |
+
grad = p.grad
|
306 |
+
if grad.is_sparse:
|
307 |
+
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
308 |
+
if p.numel() == p.shape[0]: # a batch of scalars
|
309 |
+
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
310 |
+
else:
|
311 |
+
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
312 |
+
|
313 |
+
tot_norm = tot_sumsq.sqrt()
|
314 |
+
if "model_norms" not in first_state:
|
315 |
+
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
|
316 |
+
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
317 |
+
|
318 |
+
if step % clipping_update_period == 0:
|
319 |
+
# Print some stats.
|
320 |
+
# We don't reach here if step == 0 because we would have returned
|
321 |
+
# above.
|
322 |
+
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
|
323 |
+
quartiles = []
|
324 |
+
for n in range(0, 5):
|
325 |
+
index = min(
|
326 |
+
clipping_update_period - 1,
|
327 |
+
(clipping_update_period // 4) * n,
|
328 |
+
)
|
329 |
+
quartiles.append(sorted_norms[index].item())
|
330 |
+
|
331 |
+
median = quartiles[2]
|
332 |
+
threshold = clipping_scale * median
|
333 |
+
first_state["model_norm_threshold"] = threshold
|
334 |
+
percent_clipped = (
|
335 |
+
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
|
336 |
+
)
|
337 |
+
first_state["num_clipped"] = 0
|
338 |
+
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
339 |
+
logging.info(
|
340 |
+
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
341 |
+
)
|
342 |
+
|
343 |
+
if step < clipping_update_period:
|
344 |
+
return 1.0 # We have not yet estimated a norm to clip to.
|
345 |
+
else:
|
346 |
+
try:
|
347 |
+
model_norm_threshold = first_state["model_norm_threshold"]
|
348 |
+
except KeyError:
|
349 |
+
logging.info(
|
350 |
+
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
|
351 |
+
)
|
352 |
+
return 1.0
|
353 |
+
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
354 |
+
if ans < 1.0:
|
355 |
+
first_state["num_clipped"] += 1
|
356 |
+
if ans < 0.1:
|
357 |
+
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
358 |
+
if self.show_dominant_parameters:
|
359 |
+
assert p.shape[0] == len(param_names)
|
360 |
+
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
361 |
+
return ans
|
362 |
+
|
363 |
+
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
|
364 |
+
"""
|
365 |
+
Show information of parameter wihch dominanting tot_sumsq.
|
366 |
+
|
367 |
+
Args:
|
368 |
+
tuples: a list of tuples of (param, state, param_names)
|
369 |
+
where param is a batched set of parameters,
|
370 |
+
with a .grad (1st dim is batch dim)
|
371 |
+
and state is the state-dict where optimization parameters are kept.
|
372 |
+
param_names is a List[str] while each str is name for a parameter
|
373 |
+
in batched set of parameters "param".
|
374 |
+
tot_sumsq: sumsq of all parameters. Though it's could be calculated
|
375 |
+
from tuples, we still pass it to save some time.
|
376 |
+
"""
|
377 |
+
all_sumsq_orig = {}
|
378 |
+
for p, state, batch_param_names in tuples:
|
379 |
+
# p is a stacked batch parameters.
|
380 |
+
batch_grad = p.grad
|
381 |
+
if p.numel() == p.shape[0]: # a batch of scalars
|
382 |
+
batch_sumsq_orig = batch_grad**2
|
383 |
+
# Dummpy values used by following `zip` statement.
|
384 |
+
batch_rms_orig = torch.ones(p.shape[0])
|
385 |
+
else:
|
386 |
+
batch_rms_orig = state["param_rms"]
|
387 |
+
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
|
388 |
+
|
389 |
+
for name, sumsq_orig, rms, grad in zip(
|
390 |
+
batch_param_names,
|
391 |
+
batch_sumsq_orig,
|
392 |
+
batch_rms_orig,
|
393 |
+
batch_grad,
|
394 |
+
):
|
395 |
+
proportion_orig = sumsq_orig / tot_sumsq
|
396 |
+
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
397 |
+
|
398 |
+
assert torch.isclose(
|
399 |
+
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
400 |
+
torch.tensor(1.0),
|
401 |
+
)
|
402 |
+
sorted_by_proportion = {
|
403 |
+
k: v
|
404 |
+
for k, v in sorted(
|
405 |
+
all_sumsq_orig.items(),
|
406 |
+
key=lambda item: item[1][0],
|
407 |
+
reverse=True,
|
408 |
+
)
|
409 |
+
}
|
410 |
+
dominant_param_name = next(iter(sorted_by_proportion))
|
411 |
+
(
|
412 |
+
dominant_proportion,
|
413 |
+
dominant_sumsq,
|
414 |
+
dominant_rms,
|
415 |
+
dominant_grad,
|
416 |
+
) = sorted_by_proportion[dominant_param_name]
|
417 |
+
logging.info(
|
418 |
+
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
419 |
+
f" with proportion {dominant_proportion:.2f},"
|
420 |
+
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
421 |
+
f"={dominant_sumsq:.3e},"
|
422 |
+
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
423 |
+
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
424 |
+
)
|
425 |
+
|
426 |
+
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
|
427 |
+
"""
|
428 |
+
Do the step for one parameter, which is actually going to be a batch of
|
429 |
+
`real` parameters, with dim 0 as the batch dim.
|
430 |
+
Args:
|
431 |
+
group: dict to look up configuration values
|
432 |
+
p: parameter to update (actually multiple parameters stacked together
|
433 |
+
as a batch)
|
434 |
+
state: state-dict for p, to look up the optimizer state
|
435 |
+
"""
|
436 |
+
lr = group["lr"]
|
437 |
+
size_update_period = group["size_update_period"]
|
438 |
+
beta1 = group["betas"][0]
|
439 |
+
|
440 |
+
grad = p.grad
|
441 |
+
if clipping_scale != 1.0:
|
442 |
+
grad = grad * clipping_scale
|
443 |
+
step = state["step"]
|
444 |
+
delta = state["delta"]
|
445 |
+
|
446 |
+
delta.mul_(beta1)
|
447 |
+
batch_size = p.shape[0]
|
448 |
+
numel = p.numel() // batch_size
|
449 |
+
if numel > 1:
|
450 |
+
# Update the size/scale of p, and set param_rms
|
451 |
+
scale_grads = state["scale_grads"]
|
452 |
+
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
|
453 |
+
if step % size_update_period == size_update_period - 1:
|
454 |
+
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
455 |
+
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
456 |
+
if step > 0:
|
457 |
+
# self._size_update() learns the overall scale on the
|
458 |
+
# parameter, by shrinking or expanding it.
|
459 |
+
self._size_update(group, scale_grads, p, state)
|
460 |
+
|
461 |
+
if numel == 1:
|
462 |
+
# For parameters with 1 element we just use regular Adam.
|
463 |
+
# Updates delta.
|
464 |
+
self._step_scalar(group, p, state)
|
465 |
+
else:
|
466 |
+
self._step(group, p, state)
|
467 |
+
|
468 |
+
state["step"] = step + 1
|
469 |
+
|
470 |
+
def _size_update(
|
471 |
+
self,
|
472 |
+
group: dict,
|
473 |
+
scale_grads: Tensor,
|
474 |
+
p: Tensor,
|
475 |
+
state: dict,
|
476 |
+
) -> None:
|
477 |
+
"""
|
478 |
+
Called only where p.numel() > 1, this updates the scale of the parameter.
|
479 |
+
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
480 |
+
gradient descent on underlying param and on scale, this function does the update
|
481 |
+
on `scale`.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
group: dict to look up configuration values
|
485 |
+
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
486 |
+
grads w.r.t. the scales.
|
487 |
+
p: The parameter to update
|
488 |
+
state: The state-dict of p
|
489 |
+
"""
|
490 |
+
|
491 |
+
param_rms = state["param_rms"]
|
492 |
+
beta1, beta2 = group["betas"]
|
493 |
+
size_lr = group["lr"] * group["scalar_lr_scale"]
|
494 |
+
param_min_rms = group["param_min_rms"]
|
495 |
+
param_max_rms = group["param_max_rms"]
|
496 |
+
eps = group["eps"]
|
497 |
+
step = state["step"]
|
498 |
+
batch_size = p.shape[0]
|
499 |
+
|
500 |
+
size_update_period = scale_grads.shape[0]
|
501 |
+
# correct beta2 for the size update period: we will have
|
502 |
+
# faster decay at this level.
|
503 |
+
beta2_corr = beta2**size_update_period
|
504 |
+
|
505 |
+
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
506 |
+
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
507 |
+
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
508 |
+
alpha=1 - beta2_corr,
|
509 |
+
) # shape is (batch_size, 1, 1, ...)
|
510 |
+
|
511 |
+
# The 1st time we reach here is when size_step == 1.
|
512 |
+
size_step = (step + 1) // size_update_period
|
513 |
+
bias_correction2 = 1 - beta2_corr**size_step
|
514 |
+
# we don't bother with bias_correction1; this will help prevent divergence
|
515 |
+
# at the start of training.
|
516 |
+
|
517 |
+
denom = scale_exp_avg_sq.sqrt() + eps
|
518 |
+
|
519 |
+
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
|
520 |
+
|
521 |
+
is_too_small = param_rms < param_min_rms
|
522 |
+
is_too_large = param_rms > param_max_rms
|
523 |
+
|
524 |
+
# when the param gets too small, just don't shrink it any further.
|
525 |
+
scale_step.masked_fill_(is_too_small, 0.0)
|
526 |
+
# when it gets too large, stop it from getting any larger.
|
527 |
+
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
|
528 |
+
delta = state["delta"]
|
529 |
+
# the factor of (1-beta1) relates to momentum.
|
530 |
+
delta.add_(p * scale_step, alpha=(1 - beta1))
|
531 |
+
|
532 |
+
def _step(self, group: dict, p: Tensor, state: dict):
|
533 |
+
"""
|
534 |
+
This function does the core update of self.step(), in the case where the members of
|
535 |
+
the batch have more than 1 element.
|
536 |
+
|
537 |
+
Args:
|
538 |
+
group: A dict which will be used to look up configuration values
|
539 |
+
p: The parameter to be updated
|
540 |
+
grad: The grad of p
|
541 |
+
state: The state-dict corresponding to parameter p
|
542 |
+
|
543 |
+
This function modifies p.
|
544 |
+
"""
|
545 |
+
grad = p.grad
|
546 |
+
lr = group["lr"]
|
547 |
+
beta1, beta2 = group["betas"]
|
548 |
+
eps = group["eps"]
|
549 |
+
param_min_rms = group["param_min_rms"]
|
550 |
+
step = state["step"]
|
551 |
+
|
552 |
+
exp_avg_sq = state["exp_avg_sq"]
|
553 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
554 |
+
|
555 |
+
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
|
556 |
+
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
557 |
+
if bias_correction2 < 0.99:
|
558 |
+
# note: not in-place.
|
559 |
+
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
560 |
+
|
561 |
+
denom = exp_avg_sq.sqrt()
|
562 |
+
denom += eps
|
563 |
+
grad = grad / denom
|
564 |
+
|
565 |
+
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
|
566 |
+
|
567 |
+
delta = state["delta"]
|
568 |
+
delta.add_(grad * alpha)
|
569 |
+
p.add_(delta)
|
570 |
+
|
571 |
+
def _step_scalar(self, group: dict, p: Tensor, state: dict):
|
572 |
+
"""
|
573 |
+
A simplified form of the core update for scalar tensors, where we cannot get a good
|
574 |
+
estimate of the parameter rms.
|
575 |
+
"""
|
576 |
+
beta1, beta2 = group["betas"]
|
577 |
+
scalar_max = group["scalar_max"]
|
578 |
+
eps = group["eps"]
|
579 |
+
lr = group["lr"] * group["scalar_lr_scale"]
|
580 |
+
grad = p.grad
|
581 |
+
|
582 |
+
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
583 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
584 |
+
|
585 |
+
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
586 |
+
# slower update at the start will help stability anyway.
|
587 |
+
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
588 |
+
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
589 |
+
|
590 |
+
delta = state["delta"]
|
591 |
+
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
|
592 |
+
p.clamp_(min=-scalar_max, max=scalar_max)
|
593 |
+
p.add_(delta)
|
GPT_SoVITS/AR/modules/patched_mha_with_cache.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn.functional import *
|
2 |
+
from torch.nn.functional import (
|
3 |
+
_mha_shape_check,
|
4 |
+
_canonical_mask,
|
5 |
+
_none_or_dtype,
|
6 |
+
_in_projection_packed,
|
7 |
+
)
|
8 |
+
import torch
|
9 |
+
# Tensor = torch.Tensor
|
10 |
+
# from typing import Callable, List, Optional, Tuple, Union
|
11 |
+
|
12 |
+
|
13 |
+
def multi_head_attention_forward_patched(
|
14 |
+
query,
|
15 |
+
key,
|
16 |
+
value,
|
17 |
+
embed_dim_to_check,
|
18 |
+
num_heads,
|
19 |
+
in_proj_weight,
|
20 |
+
in_proj_bias,
|
21 |
+
bias_k,
|
22 |
+
bias_v,
|
23 |
+
add_zero_attn,
|
24 |
+
dropout_p: float,
|
25 |
+
out_proj_weight,
|
26 |
+
out_proj_bias,
|
27 |
+
training=True,
|
28 |
+
key_padding_mask=None,
|
29 |
+
need_weights=True,
|
30 |
+
attn_mask=None,
|
31 |
+
use_separate_proj_weight=False,
|
32 |
+
q_proj_weight=None,
|
33 |
+
k_proj_weight=None,
|
34 |
+
v_proj_weight=None,
|
35 |
+
static_k=None,
|
36 |
+
static_v=None,
|
37 |
+
average_attn_weights=True,
|
38 |
+
is_causal=False,
|
39 |
+
cache=None,
|
40 |
+
):
|
41 |
+
r"""
|
42 |
+
Args:
|
43 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
44 |
+
See "Attention Is All You Need" for more details.
|
45 |
+
embed_dim_to_check: total dimension of the model.
|
46 |
+
num_heads: parallel attention heads.
|
47 |
+
in_proj_weight, in_proj_bias: input projection weight and bias.
|
48 |
+
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
|
49 |
+
add_zero_attn: add a new batch of zeros to the key and
|
50 |
+
value sequences at dim=1.
|
51 |
+
dropout_p: probability of an element to be zeroed.
|
52 |
+
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
53 |
+
training: apply dropout if is ``True``.
|
54 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
55 |
+
be ignored by the attention. This is an binary mask. When the value is True,
|
56 |
+
the corresponding value on the attention layer will be filled with -inf.
|
57 |
+
need_weights: output attn_output_weights.
|
58 |
+
Default: `True`
|
59 |
+
Note: `needs_weight` defaults to `True`, but should be set to `False`
|
60 |
+
For best performance when attention weights are not nedeeded.
|
61 |
+
*Setting needs_weights to `True`
|
62 |
+
leads to a significant performance degradation.*
|
63 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
64 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
65 |
+
is_causal: If specified, applies a causal mask as attention mask, and ignores
|
66 |
+
attn_mask for computing scaled dot product attention.
|
67 |
+
Default: ``False``.
|
68 |
+
.. warning::
|
69 |
+
is_causal is provides a hint that the attn_mask is the
|
70 |
+
causal mask.Providing incorrect hints can result in
|
71 |
+
incorrect execution, including forward and backward
|
72 |
+
compatibility.
|
73 |
+
use_separate_proj_weight: the function accept the proj. weights for query, key,
|
74 |
+
and value in different forms. If false, in_proj_weight will be used, which is
|
75 |
+
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
|
76 |
+
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
|
77 |
+
static_k, static_v: static key and value used for attention operators.
|
78 |
+
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
|
79 |
+
Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
|
80 |
+
when ``need_weights=True.``. Default: True
|
81 |
+
|
82 |
+
|
83 |
+
Shape:
|
84 |
+
Inputs:
|
85 |
+
- query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
86 |
+
the embedding dimension.
|
87 |
+
- key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
88 |
+
the embedding dimension.
|
89 |
+
- value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
90 |
+
the embedding dimension.
|
91 |
+
- key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
92 |
+
If a FloatTensor is provided, it will be directly added to the value.
|
93 |
+
If a BoolTensor is provided, the positions with the
|
94 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
95 |
+
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
96 |
+
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
97 |
+
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
98 |
+
positions. If a BoolTensor is provided, positions with ``True``
|
99 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
100 |
+
is provided, it will be added to the attention weight.
|
101 |
+
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
102 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
103 |
+
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
|
104 |
+
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
|
105 |
+
|
106 |
+
Outputs:
|
107 |
+
- attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
108 |
+
E is the embedding dimension.
|
109 |
+
- attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
|
110 |
+
attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
111 |
+
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
112 |
+
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
113 |
+
head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
|
114 |
+
"""
|
115 |
+
tens_ops = (
|
116 |
+
query,
|
117 |
+
key,
|
118 |
+
value,
|
119 |
+
in_proj_weight,
|
120 |
+
in_proj_bias,
|
121 |
+
bias_k,
|
122 |
+
bias_v,
|
123 |
+
out_proj_weight,
|
124 |
+
out_proj_bias,
|
125 |
+
)
|
126 |
+
if has_torch_function(tens_ops):
|
127 |
+
return handle_torch_function(
|
128 |
+
multi_head_attention_forward,
|
129 |
+
tens_ops,
|
130 |
+
query,
|
131 |
+
key,
|
132 |
+
value,
|
133 |
+
embed_dim_to_check,
|
134 |
+
num_heads,
|
135 |
+
in_proj_weight,
|
136 |
+
in_proj_bias,
|
137 |
+
bias_k,
|
138 |
+
bias_v,
|
139 |
+
add_zero_attn,
|
140 |
+
dropout_p,
|
141 |
+
out_proj_weight,
|
142 |
+
out_proj_bias,
|
143 |
+
training=training,
|
144 |
+
key_padding_mask=key_padding_mask,
|
145 |
+
need_weights=need_weights,
|
146 |
+
attn_mask=attn_mask,
|
147 |
+
is_causal=is_causal,
|
148 |
+
use_separate_proj_weight=use_separate_proj_weight,
|
149 |
+
q_proj_weight=q_proj_weight,
|
150 |
+
k_proj_weight=k_proj_weight,
|
151 |
+
v_proj_weight=v_proj_weight,
|
152 |
+
static_k=static_k,
|
153 |
+
static_v=static_v,
|
154 |
+
average_attn_weights=average_attn_weights,
|
155 |
+
cache=cache,
|
156 |
+
)
|
157 |
+
|
158 |
+
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
159 |
+
|
160 |
+
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
161 |
+
# is batched, run the computation and before returning squeeze the
|
162 |
+
# batch dimension so that the output doesn't carry this temporary batch dimension.
|
163 |
+
if not is_batched:
|
164 |
+
# unsqueeze if the input is unbatched
|
165 |
+
query = query.unsqueeze(1)
|
166 |
+
key = key.unsqueeze(1)
|
167 |
+
value = value.unsqueeze(1)
|
168 |
+
if key_padding_mask is not None:
|
169 |
+
key_padding_mask = key_padding_mask.unsqueeze(0)
|
170 |
+
|
171 |
+
# set up shape vars
|
172 |
+
tgt_len, bsz, embed_dim = query.shape
|
173 |
+
src_len, _, _ = key.shape
|
174 |
+
|
175 |
+
key_padding_mask = _canonical_mask(
|
176 |
+
mask=key_padding_mask,
|
177 |
+
mask_name="key_padding_mask",
|
178 |
+
other_type=_none_or_dtype(attn_mask),
|
179 |
+
other_name="attn_mask",
|
180 |
+
target_type=query.dtype,
|
181 |
+
)
|
182 |
+
|
183 |
+
if is_causal and attn_mask is None:
|
184 |
+
raise RuntimeError(
|
185 |
+
"Need attn_mask if specifying the is_causal hint. "
|
186 |
+
"You may use the Transformer module method "
|
187 |
+
"`generate_square_subsequent_mask` to create this mask."
|
188 |
+
)
|
189 |
+
|
190 |
+
if is_causal and key_padding_mask is None and not need_weights:
|
191 |
+
# when we have a kpm or need weights, we need attn_mask
|
192 |
+
# Otherwise, we use the is_causal hint go as is_causal
|
193 |
+
# indicator to SDPA.
|
194 |
+
attn_mask = None
|
195 |
+
else:
|
196 |
+
attn_mask = _canonical_mask(
|
197 |
+
mask=attn_mask,
|
198 |
+
mask_name="attn_mask",
|
199 |
+
other_type=None,
|
200 |
+
other_name="",
|
201 |
+
target_type=query.dtype,
|
202 |
+
check_other=False,
|
203 |
+
)
|
204 |
+
|
205 |
+
if key_padding_mask is not None:
|
206 |
+
# We have the attn_mask, and use that to merge kpm into it.
|
207 |
+
# Turn off use of is_causal hint, as the merged mask is no
|
208 |
+
# longer causal.
|
209 |
+
is_causal = False
|
210 |
+
|
211 |
+
assert embed_dim == embed_dim_to_check, (
|
212 |
+
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
213 |
+
)
|
214 |
+
if isinstance(embed_dim, torch.Tensor):
|
215 |
+
# embed_dim can be a tensor when JIT tracing
|
216 |
+
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
|
217 |
+
else:
|
218 |
+
head_dim = embed_dim // num_heads
|
219 |
+
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
220 |
+
if use_separate_proj_weight:
|
221 |
+
# allow MHA to have different embedding dimensions when separate projection weights are used
|
222 |
+
assert key.shape[:2] == value.shape[:2], (
|
223 |
+
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
224 |
+
)
|
225 |
+
else:
|
226 |
+
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
227 |
+
|
228 |
+
#
|
229 |
+
# compute in-projection
|
230 |
+
#
|
231 |
+
if not use_separate_proj_weight:
|
232 |
+
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
|
233 |
+
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
|
234 |
+
else:
|
235 |
+
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
|
236 |
+
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
|
237 |
+
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
|
238 |
+
if in_proj_bias is None:
|
239 |
+
b_q = b_k = b_v = None
|
240 |
+
else:
|
241 |
+
b_q, b_k, b_v = in_proj_bias.chunk(3)
|
242 |
+
q, k, v = _in_projection(
|
243 |
+
query,
|
244 |
+
key,
|
245 |
+
value,
|
246 |
+
q_proj_weight,
|
247 |
+
k_proj_weight,
|
248 |
+
v_proj_weight,
|
249 |
+
b_q,
|
250 |
+
b_k,
|
251 |
+
b_v,
|
252 |
+
)
|
253 |
+
if cache != None:
|
254 |
+
if cache["first_infer"] == 1:
|
255 |
+
cache["k"][cache["stage"]] = k
|
256 |
+
# print(0,cache["k"].shape)
|
257 |
+
cache["v"][cache["stage"]] = v
|
258 |
+
else: ###12个layer每个都要留自己的cache_kv
|
259 |
+
# print(1,cache["k"].shape)
|
260 |
+
cache["k"][cache["stage"]] = torch.cat(
|
261 |
+
[cache["k"][cache["stage"]], k], 0
|
262 |
+
) ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了
|
263 |
+
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
|
264 |
+
# print(2, cache["k"].shape)
|
265 |
+
src_len = cache["k"][cache["stage"]].shape[0]
|
266 |
+
k = cache["k"][cache["stage"]]
|
267 |
+
v = cache["v"][cache["stage"]]
|
268 |
+
# if attn_mask is not None:
|
269 |
+
# attn_mask=attn_mask[-1:,]
|
270 |
+
# print(attn_mask.shape,attn_mask)
|
271 |
+
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
|
272 |
+
# print(2333,cache)
|
273 |
+
# prep attention mask
|
274 |
+
|
275 |
+
attn_mask = _canonical_mask(
|
276 |
+
mask=attn_mask,
|
277 |
+
mask_name="attn_mask",
|
278 |
+
other_type=None,
|
279 |
+
other_name="",
|
280 |
+
target_type=q.dtype,
|
281 |
+
check_other=False,
|
282 |
+
)
|
283 |
+
|
284 |
+
if attn_mask is not None:
|
285 |
+
# ensure attn_mask's dim is 3
|
286 |
+
if attn_mask.dim() == 2:
|
287 |
+
correct_2d_size = (tgt_len, src_len)
|
288 |
+
if attn_mask.shape != correct_2d_size:
|
289 |
+
raise RuntimeError(
|
290 |
+
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
|
291 |
+
)
|
292 |
+
attn_mask = attn_mask.unsqueeze(0)
|
293 |
+
elif attn_mask.dim() == 3:
|
294 |
+
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
|
295 |
+
if attn_mask.shape != correct_3d_size:
|
296 |
+
raise RuntimeError(
|
297 |
+
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
|
298 |
+
)
|
299 |
+
else:
|
300 |
+
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
301 |
+
|
302 |
+
# add bias along batch dimension (currently second)
|
303 |
+
if bias_k is not None and bias_v is not None:
|
304 |
+
assert static_k is None, "bias cannot be added to static key."
|
305 |
+
assert static_v is None, "bias cannot be added to static value."
|
306 |
+
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
307 |
+
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
308 |
+
if attn_mask is not None:
|
309 |
+
attn_mask = pad(attn_mask, (0, 1))
|
310 |
+
if key_padding_mask is not None:
|
311 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
312 |
+
else:
|
313 |
+
assert bias_k is None
|
314 |
+
assert bias_v is None
|
315 |
+
|
316 |
+
#
|
317 |
+
# reshape q, k, v for multihead attention and make em batch first
|
318 |
+
#
|
319 |
+
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
320 |
+
if static_k is None:
|
321 |
+
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
322 |
+
else:
|
323 |
+
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
324 |
+
assert static_k.size(0) == bsz * num_heads, (
|
325 |
+
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
326 |
+
)
|
327 |
+
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
328 |
+
k = static_k
|
329 |
+
if static_v is None:
|
330 |
+
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
331 |
+
else:
|
332 |
+
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
333 |
+
assert static_v.size(0) == bsz * num_heads, (
|
334 |
+
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
335 |
+
)
|
336 |
+
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
337 |
+
v = static_v
|
338 |
+
|
339 |
+
# add zero attention along batch dimension (now first)
|
340 |
+
if add_zero_attn:
|
341 |
+
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
342 |
+
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
343 |
+
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
344 |
+
if attn_mask is not None:
|
345 |
+
attn_mask = pad(attn_mask, (0, 1))
|
346 |
+
if key_padding_mask is not None:
|
347 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
348 |
+
|
349 |
+
# update source sequence length after adjustments
|
350 |
+
src_len = k.size(1)
|
351 |
+
|
352 |
+
# merge key padding and attention masks
|
353 |
+
if key_padding_mask is not None:
|
354 |
+
assert key_padding_mask.shape == (
|
355 |
+
bsz,
|
356 |
+
src_len,
|
357 |
+
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
358 |
+
key_padding_mask = (
|
359 |
+
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
360 |
+
)
|
361 |
+
if attn_mask is None:
|
362 |
+
attn_mask = key_padding_mask
|
363 |
+
else:
|
364 |
+
attn_mask = attn_mask + key_padding_mask
|
365 |
+
|
366 |
+
# adjust dropout probability
|
367 |
+
if not training:
|
368 |
+
dropout_p = 0.0
|
369 |
+
|
370 |
+
#
|
371 |
+
# (deep breath) calculate attention and out projection
|
372 |
+
#
|
373 |
+
|
374 |
+
if need_weights:
|
375 |
+
B, Nt, E = q.shape
|
376 |
+
q_scaled = q / math.sqrt(E)
|
377 |
+
|
378 |
+
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
|
379 |
+
|
380 |
+
if attn_mask is not None:
|
381 |
+
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
382 |
+
else:
|
383 |
+
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
384 |
+
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
385 |
+
if dropout_p > 0.0:
|
386 |
+
attn_output_weights = dropout(attn_output_weights, p=dropout_p)
|
387 |
+
|
388 |
+
attn_output = torch.bmm(attn_output_weights, v)
|
389 |
+
|
390 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
391 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
392 |
+
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
393 |
+
|
394 |
+
# optionally average attention weights over heads
|
395 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
396 |
+
if average_attn_weights:
|
397 |
+
attn_output_weights = attn_output_weights.mean(dim=1)
|
398 |
+
|
399 |
+
if not is_batched:
|
400 |
+
# squeeze the output if input was unbatched
|
401 |
+
attn_output = attn_output.squeeze(1)
|
402 |
+
attn_output_weights = attn_output_weights.squeeze(0)
|
403 |
+
return attn_output, attn_output_weights
|
404 |
+
else:
|
405 |
+
# attn_mask can be either (L,S) or (N*num_heads, L, S)
|
406 |
+
# if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
|
407 |
+
# in order to match the input for SDPA of (N, num_heads, L, S)
|
408 |
+
if attn_mask is not None:
|
409 |
+
if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
|
410 |
+
attn_mask = attn_mask.unsqueeze(0)
|
411 |
+
else:
|
412 |
+
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
|
413 |
+
|
414 |
+
q = q.view(bsz, num_heads, tgt_len, head_dim)
|
415 |
+
k = k.view(bsz, num_heads, src_len, head_dim)
|
416 |
+
v = v.view(bsz, num_heads, src_len, head_dim)
|
417 |
+
|
418 |
+
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
419 |
+
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
420 |
+
|
421 |
+
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
422 |
+
|
423 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
424 |
+
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
425 |
+
if not is_batched:
|
426 |
+
# squeeze the output if input was unbatched
|
427 |
+
attn_output = attn_output.squeeze(1)
|
428 |
+
return attn_output, None
|
GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn.functional import *
|
2 |
+
from torch.nn.functional import (
|
3 |
+
_canonical_mask,
|
4 |
+
)
|
5 |
+
|
6 |
+
|
7 |
+
def multi_head_attention_forward_patched(
|
8 |
+
query,
|
9 |
+
key,
|
10 |
+
value,
|
11 |
+
embed_dim_to_check: int,
|
12 |
+
num_heads: int,
|
13 |
+
in_proj_weight,
|
14 |
+
in_proj_bias: Optional[Tensor],
|
15 |
+
bias_k: Optional[Tensor],
|
16 |
+
bias_v: Optional[Tensor],
|
17 |
+
add_zero_attn: bool,
|
18 |
+
dropout_p: float,
|
19 |
+
out_proj_weight: Tensor,
|
20 |
+
out_proj_bias: Optional[Tensor],
|
21 |
+
training: bool = True,
|
22 |
+
key_padding_mask: Optional[Tensor] = None,
|
23 |
+
need_weights: bool = True,
|
24 |
+
attn_mask: Optional[Tensor] = None,
|
25 |
+
use_separate_proj_weight: bool = False,
|
26 |
+
q_proj_weight: Optional[Tensor] = None,
|
27 |
+
k_proj_weight: Optional[Tensor] = None,
|
28 |
+
v_proj_weight: Optional[Tensor] = None,
|
29 |
+
static_k: Optional[Tensor] = None,
|
30 |
+
static_v: Optional[Tensor] = None,
|
31 |
+
average_attn_weights: bool = True,
|
32 |
+
is_causal: bool = False,
|
33 |
+
cache=None,
|
34 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
35 |
+
# set up shape vars
|
36 |
+
_, _, embed_dim = query.shape
|
37 |
+
attn_mask = _canonical_mask(
|
38 |
+
mask=attn_mask,
|
39 |
+
mask_name="attn_mask",
|
40 |
+
other_type=None,
|
41 |
+
other_name="",
|
42 |
+
target_type=query.dtype,
|
43 |
+
check_other=False,
|
44 |
+
)
|
45 |
+
head_dim = embed_dim // num_heads
|
46 |
+
|
47 |
+
proj_qkv = linear(query, in_proj_weight, in_proj_bias)
|
48 |
+
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
|
49 |
+
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
|
50 |
+
|
51 |
+
if cache["first_infer"] == 1:
|
52 |
+
cache["k"][cache["stage"]] = k
|
53 |
+
cache["v"][cache["stage"]] = v
|
54 |
+
else:
|
55 |
+
cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
|
56 |
+
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
|
57 |
+
k = cache["k"][cache["stage"]]
|
58 |
+
v = cache["v"][cache["stage"]]
|
59 |
+
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
|
60 |
+
|
61 |
+
attn_mask = _canonical_mask(
|
62 |
+
mask=attn_mask,
|
63 |
+
mask_name="attn_mask",
|
64 |
+
other_type=None,
|
65 |
+
other_name="",
|
66 |
+
target_type=q.dtype,
|
67 |
+
check_other=False,
|
68 |
+
)
|
69 |
+
attn_mask = attn_mask.unsqueeze(0)
|
70 |
+
|
71 |
+
q = q.view(-1, num_heads, head_dim).transpose(0, 1)
|
72 |
+
k = k.view(-1, num_heads, head_dim).transpose(0, 1)
|
73 |
+
v = v.view(-1, num_heads, head_dim).transpose(0, 1)
|
74 |
+
|
75 |
+
dropout_p = 0.0
|
76 |
+
attn_mask = attn_mask.unsqueeze(0)
|
77 |
+
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
|
78 |
+
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
|
79 |
+
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
|
80 |
+
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
81 |
+
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
82 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
83 |
+
attn_output = attn_output.view(-1, 1, attn_output.size(1))
|
84 |
+
|
85 |
+
return attn_output
|
GPT_SoVITS/AR/modules/scaling.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import random
|
17 |
+
from typing import Optional
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
from torch import Tensor
|
23 |
+
|
24 |
+
|
25 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
26 |
+
"""
|
27 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
28 |
+
This is a definition, originally motivated by its close numerical
|
29 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
30 |
+
|
31 |
+
Memory-efficient derivative computation:
|
32 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
33 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
34 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
35 |
+
double_swish'(x) = x * s'(x) + s(x).
|
36 |
+
= x * s(x) * (1-s(x)) + s(x).
|
37 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
38 |
+
... so we just need to remember s(x) but not x itself.
|
39 |
+
"""
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
43 |
+
requires_grad = x.requires_grad
|
44 |
+
x_dtype = x.dtype
|
45 |
+
if x.dtype == torch.float16:
|
46 |
+
x = x.to(torch.float32)
|
47 |
+
|
48 |
+
s = torch.sigmoid(x - 1.0)
|
49 |
+
y = x * s
|
50 |
+
|
51 |
+
if requires_grad:
|
52 |
+
deriv = y * (1 - s) + s
|
53 |
+
# notes on derivative of x * sigmoid(x - 1):
|
54 |
+
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
|
55 |
+
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund
|
56 |
+
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
|
57 |
+
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
|
58 |
+
# floors), should be expectation-preserving.
|
59 |
+
floor = -0.043637
|
60 |
+
ceil = 1.2
|
61 |
+
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
|
62 |
+
if __name__ == "__main__":
|
63 |
+
# for self-testing only.
|
64 |
+
assert d_scaled.min() >= 0.0
|
65 |
+
assert d_scaled.max() < 256.0
|
66 |
+
d_int = d_scaled.to(torch.uint8)
|
67 |
+
ctx.save_for_backward(d_int)
|
68 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
69 |
+
y = y.to(torch.float16)
|
70 |
+
return y
|
71 |
+
|
72 |
+
@staticmethod
|
73 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
74 |
+
(d,) = ctx.saved_tensors
|
75 |
+
# the same constants as used in forward pass.
|
76 |
+
floor = -0.043637
|
77 |
+
ceil = 1.2
|
78 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
79 |
+
return y_grad * d
|
80 |
+
|
81 |
+
|
82 |
+
class DoubleSwish(torch.nn.Module):
|
83 |
+
def forward(self, x: Tensor) -> Tensor:
|
84 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
85 |
+
that we approximate closely with x * sigmoid(x-1).
|
86 |
+
"""
|
87 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
88 |
+
return x * torch.sigmoid(x - 1.0)
|
89 |
+
return DoubleSwishFunction.apply(x)
|
90 |
+
|
91 |
+
|
92 |
+
class ActivationBalancerFunction(torch.autograd.Function):
|
93 |
+
@staticmethod
|
94 |
+
def forward(
|
95 |
+
ctx,
|
96 |
+
x: Tensor,
|
97 |
+
scale_factor: Tensor,
|
98 |
+
sign_factor: Optional[Tensor],
|
99 |
+
channel_dim: int,
|
100 |
+
) -> Tensor:
|
101 |
+
if channel_dim < 0:
|
102 |
+
channel_dim += x.ndim
|
103 |
+
ctx.channel_dim = channel_dim
|
104 |
+
xgt0 = x > 0
|
105 |
+
if sign_factor is None:
|
106 |
+
ctx.save_for_backward(xgt0, scale_factor)
|
107 |
+
else:
|
108 |
+
ctx.save_for_backward(xgt0, scale_factor, sign_factor)
|
109 |
+
return x
|
110 |
+
|
111 |
+
@staticmethod
|
112 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
113 |
+
if len(ctx.saved_tensors) == 3:
|
114 |
+
xgt0, scale_factor, sign_factor = ctx.saved_tensors
|
115 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
116 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
117 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
118 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
119 |
+
else:
|
120 |
+
xgt0, scale_factor = ctx.saved_tensors
|
121 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
122 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
123 |
+
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
124 |
+
neg_delta_grad = x_grad.abs() * factor
|
125 |
+
return (
|
126 |
+
x_grad - neg_delta_grad,
|
127 |
+
None,
|
128 |
+
None,
|
129 |
+
None,
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
def _compute_scale_factor(
|
134 |
+
x: Tensor,
|
135 |
+
channel_dim: int,
|
136 |
+
min_abs: float,
|
137 |
+
max_abs: float,
|
138 |
+
gain_factor: float,
|
139 |
+
max_factor: float,
|
140 |
+
) -> Tensor:
|
141 |
+
if channel_dim < 0:
|
142 |
+
channel_dim += x.ndim
|
143 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
144 |
+
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
145 |
+
|
146 |
+
if min_abs == 0.0:
|
147 |
+
below_threshold = 0.0
|
148 |
+
else:
|
149 |
+
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
150 |
+
# x_abs)_mean , min_abs.
|
151 |
+
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
|
152 |
+
|
153 |
+
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
|
154 |
+
|
155 |
+
return below_threshold - above_threshold
|
156 |
+
|
157 |
+
|
158 |
+
def _compute_sign_factor(
|
159 |
+
x: Tensor,
|
160 |
+
channel_dim: int,
|
161 |
+
min_positive: float,
|
162 |
+
max_positive: float,
|
163 |
+
gain_factor: float,
|
164 |
+
max_factor: float,
|
165 |
+
) -> Tensor:
|
166 |
+
if channel_dim < 0:
|
167 |
+
channel_dim += x.ndim
|
168 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
169 |
+
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
|
170 |
+
if min_positive == 0.0:
|
171 |
+
factor1 = 0.0
|
172 |
+
else:
|
173 |
+
# 0 if proportion_positive >= min_positive, else can be
|
174 |
+
# as large as max_factor.
|
175 |
+
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
|
176 |
+
|
177 |
+
if max_positive == 1.0:
|
178 |
+
factor2 = 0.0
|
179 |
+
else:
|
180 |
+
# 0 if self.proportion_positive <= max_positive, else can be
|
181 |
+
# as large as -max_factor.
|
182 |
+
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
|
183 |
+
min=0, max=max_factor
|
184 |
+
)
|
185 |
+
sign_factor = factor1 - factor2
|
186 |
+
# require min_positive != 0 or max_positive != 1:
|
187 |
+
assert not isinstance(sign_factor, float)
|
188 |
+
return sign_factor
|
189 |
+
|
190 |
+
|
191 |
+
class ActivationBalancer(torch.nn.Module):
|
192 |
+
"""
|
193 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
194 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
195 |
+
time. It does this by multiplying negative derivative values by up to
|
196 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
197 |
+
interpolated from 1 at the threshold to those extremal values when none
|
198 |
+
of the inputs are positive.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
num_channels: the number of channels
|
202 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
203 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
204 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
205 |
+
that (x > 0), below which we start to modify the derivatives.
|
206 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
207 |
+
that (x > 0), above which we start to modify the derivatives.
|
208 |
+
max_factor: the maximum factor by which we modify the derivatives for
|
209 |
+
either the sign constraint or the magnitude constraint;
|
210 |
+
e.g. with max_factor=0.02, the the derivatives would be multiplied by
|
211 |
+
values in the range [0.98..1.02].
|
212 |
+
sign_gain_factor: determines the 'gain' with which we increase the
|
213 |
+
change in gradient once the constraints on min_positive and max_positive
|
214 |
+
are violated.
|
215 |
+
scale_gain_factor: determines the 'gain' with which we increase the
|
216 |
+
change in gradient once the constraints on min_abs and max_abs
|
217 |
+
are violated.
|
218 |
+
min_abs: the minimum average-absolute-value difference from the mean
|
219 |
+
value per channel, which we allow, before we start to modify
|
220 |
+
the derivatives to prevent this.
|
221 |
+
max_abs: the maximum average-absolute-value difference from the mean
|
222 |
+
value per channel, which we allow, before we start to modify
|
223 |
+
the derivatives to prevent this.
|
224 |
+
min_prob: determines the minimum probability with which we modify the
|
225 |
+
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
226 |
+
on each forward(). This is done randomly to prevent all layers
|
227 |
+
from doing it at the same time. Early in training we may use
|
228 |
+
higher probabilities than this; it will decay to this value.
|
229 |
+
"""
|
230 |
+
|
231 |
+
def __init__(
|
232 |
+
self,
|
233 |
+
num_channels: int,
|
234 |
+
channel_dim: int,
|
235 |
+
min_positive: float = 0.05,
|
236 |
+
max_positive: float = 0.95,
|
237 |
+
max_factor: float = 0.04,
|
238 |
+
sign_gain_factor: float = 0.01,
|
239 |
+
scale_gain_factor: float = 0.02,
|
240 |
+
min_abs: float = 0.2,
|
241 |
+
max_abs: float = 100.0,
|
242 |
+
min_prob: float = 0.1,
|
243 |
+
):
|
244 |
+
super(ActivationBalancer, self).__init__()
|
245 |
+
self.num_channels = num_channels
|
246 |
+
self.channel_dim = channel_dim
|
247 |
+
self.min_positive = min_positive
|
248 |
+
self.max_positive = max_positive
|
249 |
+
self.max_factor = max_factor
|
250 |
+
self.min_abs = min_abs
|
251 |
+
self.max_abs = max_abs
|
252 |
+
self.min_prob = min_prob
|
253 |
+
self.sign_gain_factor = sign_gain_factor
|
254 |
+
self.scale_gain_factor = scale_gain_factor
|
255 |
+
|
256 |
+
# count measures how many times the forward() function has been called.
|
257 |
+
# We occasionally sync this to a tensor called `count`, that exists to
|
258 |
+
# make sure it is synced to disk when we load and save the model.
|
259 |
+
self.cpu_count = 0
|
260 |
+
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
|
261 |
+
|
262 |
+
def forward(self, x: Tensor) -> Tensor:
|
263 |
+
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
|
264 |
+
return _no_op(x)
|
265 |
+
|
266 |
+
count = self.cpu_count
|
267 |
+
self.cpu_count += 1
|
268 |
+
|
269 |
+
if random.random() < 0.01:
|
270 |
+
# Occasionally sync self.cpu_count with self.count.
|
271 |
+
# count affects the decay of 'prob'. don't do this on every iter,
|
272 |
+
# because syncing with the GPU is slow.
|
273 |
+
self.cpu_count = max(self.cpu_count, self.count.item())
|
274 |
+
self.count.fill_(self.cpu_count)
|
275 |
+
|
276 |
+
# the prob of doing some work exponentially decreases from 0.5 till it hits
|
277 |
+
# a floor at min_prob (==0.1, by default)
|
278 |
+
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
|
279 |
+
|
280 |
+
if random.random() < prob:
|
281 |
+
sign_gain_factor = 0.5
|
282 |
+
if self.min_positive != 0.0 or self.max_positive != 1.0:
|
283 |
+
sign_factor = _compute_sign_factor(
|
284 |
+
x,
|
285 |
+
self.channel_dim,
|
286 |
+
self.min_positive,
|
287 |
+
self.max_positive,
|
288 |
+
gain_factor=self.sign_gain_factor / prob,
|
289 |
+
max_factor=self.max_factor,
|
290 |
+
)
|
291 |
+
else:
|
292 |
+
sign_factor = None
|
293 |
+
|
294 |
+
scale_factor = _compute_scale_factor(
|
295 |
+
x.detach(),
|
296 |
+
self.channel_dim,
|
297 |
+
min_abs=self.min_abs,
|
298 |
+
max_abs=self.max_abs,
|
299 |
+
gain_factor=self.scale_gain_factor / prob,
|
300 |
+
max_factor=self.max_factor,
|
301 |
+
)
|
302 |
+
return ActivationBalancerFunction.apply(
|
303 |
+
x,
|
304 |
+
scale_factor,
|
305 |
+
sign_factor,
|
306 |
+
self.channel_dim,
|
307 |
+
)
|
308 |
+
else:
|
309 |
+
return _no_op(x)
|
310 |
+
|
311 |
+
|
312 |
+
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
|
313 |
+
"""
|
314 |
+
ActivationBalancer -> DoubleSwish
|
315 |
+
"""
|
316 |
+
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
|
317 |
+
return nn.Sequential(
|
318 |
+
balancer,
|
319 |
+
DoubleSwish(),
|
320 |
+
)
|
GPT_SoVITS/AR/modules/transformer.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
|
2 |
+
import copy
|
3 |
+
import numbers
|
4 |
+
from functools import partial
|
5 |
+
from typing import Any
|
6 |
+
from typing import Callable
|
7 |
+
from typing import List
|
8 |
+
from typing import Optional
|
9 |
+
from typing import Tuple
|
10 |
+
from typing import Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from AR.modules.activation import MultiheadAttention
|
14 |
+
from AR.modules.scaling import BalancedDoubleSwish
|
15 |
+
from torch import nn
|
16 |
+
from torch import Tensor
|
17 |
+
from torch.nn import functional as F
|
18 |
+
|
19 |
+
_shape_t = Union[int, List[int], torch.Size]
|
20 |
+
|
21 |
+
|
22 |
+
class LayerNorm(nn.Module):
|
23 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
24 |
+
normalized_shape: Tuple[int, ...]
|
25 |
+
eps: float
|
26 |
+
elementwise_affine: bool
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
normalized_shape: _shape_t,
|
31 |
+
eps: float = 1e-5,
|
32 |
+
elementwise_affine: bool = True,
|
33 |
+
device=None,
|
34 |
+
dtype=None,
|
35 |
+
) -> None:
|
36 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
37 |
+
super(LayerNorm, self).__init__()
|
38 |
+
if isinstance(normalized_shape, numbers.Integral):
|
39 |
+
# mypy error: incompatible types in assignment
|
40 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
41 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
42 |
+
self.eps = eps
|
43 |
+
self.elementwise_affine = elementwise_affine
|
44 |
+
if self.elementwise_affine:
|
45 |
+
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
46 |
+
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
47 |
+
else:
|
48 |
+
self.register_parameter("weight", None)
|
49 |
+
self.register_parameter("bias", None)
|
50 |
+
|
51 |
+
self.reset_parameters()
|
52 |
+
|
53 |
+
def reset_parameters(self) -> None:
|
54 |
+
if self.elementwise_affine:
|
55 |
+
nn.init.ones_(self.weight)
|
56 |
+
nn.init.zeros_(self.bias)
|
57 |
+
|
58 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
59 |
+
if isinstance(input, tuple):
|
60 |
+
input, embedding = input
|
61 |
+
return (
|
62 |
+
F.layer_norm(
|
63 |
+
input,
|
64 |
+
self.normalized_shape,
|
65 |
+
self.weight,
|
66 |
+
self.bias,
|
67 |
+
self.eps,
|
68 |
+
),
|
69 |
+
embedding,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert embedding is None
|
73 |
+
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
74 |
+
|
75 |
+
def extra_repr(self) -> str:
|
76 |
+
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
77 |
+
|
78 |
+
|
79 |
+
class IdentityNorm(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
d_model: int,
|
83 |
+
eps: float = 1e-5,
|
84 |
+
device=None,
|
85 |
+
dtype=None,
|
86 |
+
) -> None:
|
87 |
+
super(IdentityNorm, self).__init__()
|
88 |
+
|
89 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
90 |
+
if isinstance(input, tuple):
|
91 |
+
return input
|
92 |
+
|
93 |
+
assert embedding is None
|
94 |
+
return input
|
95 |
+
|
96 |
+
|
97 |
+
class TransformerEncoder(nn.Module):
|
98 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
99 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
103 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
104 |
+
norm: the layer normalization component (optional).
|
105 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
106 |
+
(and convert back on output). This will improve the overall performance of
|
107 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
108 |
+
|
109 |
+
Examples::
|
110 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
111 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
112 |
+
>>> src = torch.rand(10, 32, 512)
|
113 |
+
>>> out = transformer_encoder(src)
|
114 |
+
"""
|
115 |
+
|
116 |
+
__constants__ = ["norm"]
|
117 |
+
|
118 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
119 |
+
super(TransformerEncoder, self).__init__()
|
120 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
121 |
+
self.num_layers = num_layers
|
122 |
+
self.norm = norm
|
123 |
+
|
124 |
+
def forward(
|
125 |
+
self,
|
126 |
+
src: Tensor,
|
127 |
+
mask: Optional[Tensor] = None,
|
128 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
129 |
+
return_layer_states: bool = False,
|
130 |
+
cache=None,
|
131 |
+
) -> Tensor:
|
132 |
+
r"""Pass the input through the encoder layers in turn.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
src: the sequence to the encoder (required).
|
136 |
+
mask: the mask for the src sequence (optional).
|
137 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
138 |
+
return_layer_states: return layers' state (optional).
|
139 |
+
|
140 |
+
Shape:
|
141 |
+
see the docs in Transformer class.
|
142 |
+
"""
|
143 |
+
if return_layer_states:
|
144 |
+
layer_states = [] # layers' output
|
145 |
+
output = src
|
146 |
+
for mod in self.layers:
|
147 |
+
output = mod(
|
148 |
+
output,
|
149 |
+
src_mask=mask,
|
150 |
+
src_key_padding_mask=src_key_padding_mask,
|
151 |
+
cache=cache,
|
152 |
+
)
|
153 |
+
layer_states.append(output[0])
|
154 |
+
|
155 |
+
if self.norm is not None:
|
156 |
+
output = self.norm(output)
|
157 |
+
|
158 |
+
return layer_states, output
|
159 |
+
|
160 |
+
output = src
|
161 |
+
for mod in self.layers:
|
162 |
+
output = mod(
|
163 |
+
output,
|
164 |
+
src_mask=mask,
|
165 |
+
src_key_padding_mask=src_key_padding_mask,
|
166 |
+
cache=cache,
|
167 |
+
)
|
168 |
+
|
169 |
+
if self.norm is not None:
|
170 |
+
output = self.norm(output)
|
171 |
+
|
172 |
+
return output
|
173 |
+
|
174 |
+
|
175 |
+
class TransformerEncoderLayer(nn.Module):
|
176 |
+
__constants__ = ["batch_first", "norm_first"]
|
177 |
+
|
178 |
+
def __init__(
|
179 |
+
self,
|
180 |
+
d_model: int,
|
181 |
+
nhead: int,
|
182 |
+
dim_feedforward: int = 2048,
|
183 |
+
dropout: float = 0.1,
|
184 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
185 |
+
batch_first: bool = False,
|
186 |
+
norm_first: bool = False,
|
187 |
+
device=None,
|
188 |
+
dtype=None,
|
189 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
190 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
191 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
192 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
193 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
194 |
+
layer_norm_eps: float = 1e-5,
|
195 |
+
adaptive_layer_norm=False,
|
196 |
+
) -> None:
|
197 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
198 |
+
super(TransformerEncoderLayer, self).__init__()
|
199 |
+
# print(233333333333,d_model,nhead)
|
200 |
+
# import os
|
201 |
+
# os._exit(2333333)
|
202 |
+
self.self_attn = MultiheadAttention(
|
203 |
+
d_model, # 512 16
|
204 |
+
nhead,
|
205 |
+
dropout=dropout,
|
206 |
+
batch_first=batch_first,
|
207 |
+
linear1_cls=linear1_self_attention_cls,
|
208 |
+
linear2_cls=linear2_self_attention_cls,
|
209 |
+
**factory_kwargs,
|
210 |
+
)
|
211 |
+
|
212 |
+
# Implementation of Feedforward model
|
213 |
+
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
214 |
+
self.dropout = nn.Dropout(dropout)
|
215 |
+
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
216 |
+
|
217 |
+
self.norm_first = norm_first
|
218 |
+
self.dropout1 = nn.Dropout(dropout)
|
219 |
+
self.dropout2 = nn.Dropout(dropout)
|
220 |
+
|
221 |
+
# Legacy string support for activation function.
|
222 |
+
if isinstance(activation, str):
|
223 |
+
activation = _get_activation_fn(activation)
|
224 |
+
elif isinstance(activation, partial):
|
225 |
+
activation = activation(d_model)
|
226 |
+
elif activation == BalancedDoubleSwish:
|
227 |
+
activation = BalancedDoubleSwish(d_model)
|
228 |
+
|
229 |
+
# # We can't test self.activation in forward() in TorchScript,
|
230 |
+
# # so stash some information about it instead.
|
231 |
+
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
232 |
+
# self.activation_relu_or_gelu = 1
|
233 |
+
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
234 |
+
# self.activation_relu_or_gelu = 2
|
235 |
+
# else:
|
236 |
+
# self.activation_relu_or_gelu = 0
|
237 |
+
self.activation = activation
|
238 |
+
|
239 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
240 |
+
if layer_norm_cls == IdentityNorm:
|
241 |
+
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
242 |
+
else:
|
243 |
+
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
244 |
+
|
245 |
+
if adaptive_layer_norm:
|
246 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
247 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
248 |
+
else:
|
249 |
+
self.norm1 = norm1
|
250 |
+
self.norm2 = norm2
|
251 |
+
|
252 |
+
def __setstate__(self, state):
|
253 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
254 |
+
if not hasattr(self, "activation"):
|
255 |
+
self.activation = F.relu
|
256 |
+
|
257 |
+
def forward(
|
258 |
+
self,
|
259 |
+
src: Tensor,
|
260 |
+
src_mask: Optional[Tensor] = None,
|
261 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
262 |
+
cache=None,
|
263 |
+
) -> Tensor:
|
264 |
+
r"""Pass the input through the encoder layer.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
src: the sequence to the encoder layer (required).
|
268 |
+
src_mask: the mask for the src sequence (optional).
|
269 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
270 |
+
|
271 |
+
Shape:
|
272 |
+
see the docs in Transformer class.
|
273 |
+
"""
|
274 |
+
x, stage_embedding = src, None
|
275 |
+
is_src_tuple = False
|
276 |
+
if isinstance(src, tuple):
|
277 |
+
x, stage_embedding = src
|
278 |
+
is_src_tuple = True
|
279 |
+
|
280 |
+
if src_key_padding_mask is not None:
|
281 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
282 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
|
283 |
+
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
284 |
+
|
285 |
+
if self.norm_first:
|
286 |
+
x = x + self._sa_block(
|
287 |
+
self.norm1(x, stage_embedding),
|
288 |
+
src_mask,
|
289 |
+
src_key_padding_mask,
|
290 |
+
cache=cache,
|
291 |
+
)
|
292 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
293 |
+
else:
|
294 |
+
x = self.norm1(
|
295 |
+
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
|
296 |
+
stage_embedding,
|
297 |
+
)
|
298 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
299 |
+
|
300 |
+
if is_src_tuple:
|
301 |
+
return (x, stage_embedding)
|
302 |
+
return x
|
303 |
+
|
304 |
+
# self-attention block
|
305 |
+
def _sa_block(
|
306 |
+
self,
|
307 |
+
x: Tensor,
|
308 |
+
attn_mask: Optional[Tensor],
|
309 |
+
key_padding_mask: Optional[Tensor],
|
310 |
+
cache=None,
|
311 |
+
) -> Tensor:
|
312 |
+
# print(x.shape,attn_mask.shape,key_padding_mask)
|
313 |
+
# torch.Size([1, 188, 512]) torch.Size([188, 188]) None
|
314 |
+
# import os
|
315 |
+
# os._exit(23333)
|
316 |
+
x = self.self_attn(
|
317 |
+
x,
|
318 |
+
x,
|
319 |
+
x,
|
320 |
+
attn_mask=attn_mask,
|
321 |
+
key_padding_mask=key_padding_mask,
|
322 |
+
need_weights=False,
|
323 |
+
cache=cache,
|
324 |
+
)[0]
|
325 |
+
return self.dropout1(x)
|
326 |
+
|
327 |
+
# feed forward block
|
328 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
329 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
330 |
+
return self.dropout2(x)
|
331 |
+
|
332 |
+
|
333 |
+
class AdaptiveLayerNorm(nn.Module):
|
334 |
+
r"""Adaptive Layer Normalization"""
|
335 |
+
|
336 |
+
def __init__(self, d_model, norm) -> None:
|
337 |
+
super(AdaptiveLayerNorm, self).__init__()
|
338 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
339 |
+
self.norm = norm
|
340 |
+
self.d_model = d_model
|
341 |
+
self.eps = self.norm.eps
|
342 |
+
|
343 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
344 |
+
if isinstance(input, tuple):
|
345 |
+
input, embedding = input
|
346 |
+
weight, bias = torch.split(
|
347 |
+
self.project_layer(embedding),
|
348 |
+
split_size_or_sections=self.d_model,
|
349 |
+
dim=-1,
|
350 |
+
)
|
351 |
+
return (weight * self.norm(input) + bias, embedding)
|
352 |
+
|
353 |
+
weight, bias = torch.split(
|
354 |
+
self.project_layer(embedding),
|
355 |
+
split_size_or_sections=self.d_model,
|
356 |
+
dim=-1,
|
357 |
+
)
|
358 |
+
return weight * self.norm(input) + bias
|
359 |
+
|
360 |
+
|
361 |
+
def _get_clones(module, N):
|
362 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
GPT_SoVITS/AR/modules/transformer_onnx.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
|
2 |
+
import copy
|
3 |
+
import numbers
|
4 |
+
from functools import partial
|
5 |
+
from typing import Any
|
6 |
+
from typing import Callable
|
7 |
+
from typing import List
|
8 |
+
from typing import Optional
|
9 |
+
from typing import Tuple
|
10 |
+
from typing import Union
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from AR.modules.activation_onnx import MultiheadAttention
|
14 |
+
from AR.modules.scaling import BalancedDoubleSwish
|
15 |
+
from torch import nn
|
16 |
+
from torch import Tensor
|
17 |
+
from torch.nn import functional as F
|
18 |
+
|
19 |
+
_shape_t = Union[int, List[int], torch.Size]
|
20 |
+
|
21 |
+
|
22 |
+
class LayerNorm(nn.Module):
|
23 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
24 |
+
normalized_shape: Tuple[int, ...]
|
25 |
+
eps: float
|
26 |
+
elementwise_affine: bool
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
normalized_shape: _shape_t,
|
31 |
+
eps: float = 1e-5,
|
32 |
+
elementwise_affine: bool = True,
|
33 |
+
device=None,
|
34 |
+
dtype=None,
|
35 |
+
) -> None:
|
36 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
37 |
+
super(LayerNorm, self).__init__()
|
38 |
+
if isinstance(normalized_shape, numbers.Integral):
|
39 |
+
# mypy error: incompatible types in assignment
|
40 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
41 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
42 |
+
self.eps = eps
|
43 |
+
self.elementwise_affine = elementwise_affine
|
44 |
+
if self.elementwise_affine:
|
45 |
+
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
46 |
+
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
47 |
+
else:
|
48 |
+
self.register_parameter("weight", None)
|
49 |
+
self.register_parameter("bias", None)
|
50 |
+
|
51 |
+
self.reset_parameters()
|
52 |
+
|
53 |
+
def reset_parameters(self) -> None:
|
54 |
+
if self.elementwise_affine:
|
55 |
+
nn.init.ones_(self.weight)
|
56 |
+
nn.init.zeros_(self.bias)
|
57 |
+
|
58 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
59 |
+
if isinstance(input, tuple):
|
60 |
+
input, embedding = input
|
61 |
+
return (
|
62 |
+
F.layer_norm(
|
63 |
+
input,
|
64 |
+
self.normalized_shape,
|
65 |
+
self.weight,
|
66 |
+
self.bias,
|
67 |
+
self.eps,
|
68 |
+
),
|
69 |
+
embedding,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert embedding is None
|
73 |
+
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
74 |
+
|
75 |
+
def extra_repr(self) -> str:
|
76 |
+
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
77 |
+
|
78 |
+
|
79 |
+
class IdentityNorm(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
d_model: int,
|
83 |
+
eps: float = 1e-5,
|
84 |
+
device=None,
|
85 |
+
dtype=None,
|
86 |
+
) -> None:
|
87 |
+
super(IdentityNorm, self).__init__()
|
88 |
+
|
89 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
90 |
+
if isinstance(input, tuple):
|
91 |
+
return input
|
92 |
+
|
93 |
+
assert embedding is None
|
94 |
+
return input
|
95 |
+
|
96 |
+
|
97 |
+
class TransformerEncoder(nn.Module):
|
98 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
99 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
103 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
104 |
+
norm: the layer normalization component (optional).
|
105 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
106 |
+
(and convert back on output). This will improve the overall performance of
|
107 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
108 |
+
|
109 |
+
Examples::
|
110 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
111 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
112 |
+
>>> src = torch.rand(10, 32, 512)
|
113 |
+
>>> out = transformer_encoder(src)
|
114 |
+
"""
|
115 |
+
|
116 |
+
__constants__ = ["norm"]
|
117 |
+
|
118 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
119 |
+
super(TransformerEncoder, self).__init__()
|
120 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
121 |
+
self.num_layers = num_layers
|
122 |
+
self.norm = norm
|
123 |
+
|
124 |
+
def forward(
|
125 |
+
self,
|
126 |
+
src: Tensor,
|
127 |
+
mask: Optional[Tensor] = None,
|
128 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
129 |
+
return_layer_states: bool = False,
|
130 |
+
cache=None,
|
131 |
+
) -> Tensor:
|
132 |
+
output = src
|
133 |
+
for mod in self.layers:
|
134 |
+
output = mod(
|
135 |
+
output,
|
136 |
+
src_mask=mask,
|
137 |
+
src_key_padding_mask=src_key_padding_mask,
|
138 |
+
cache=cache,
|
139 |
+
)
|
140 |
+
|
141 |
+
if self.norm is not None:
|
142 |
+
output = self.norm(output)
|
143 |
+
|
144 |
+
return output
|
145 |
+
|
146 |
+
|
147 |
+
class TransformerEncoderLayer(nn.Module):
|
148 |
+
__constants__ = ["batch_first", "norm_first"]
|
149 |
+
|
150 |
+
def __init__(
|
151 |
+
self,
|
152 |
+
d_model: int,
|
153 |
+
nhead: int,
|
154 |
+
dim_feedforward: int = 2048,
|
155 |
+
dropout: float = 0.1,
|
156 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
157 |
+
batch_first: bool = False,
|
158 |
+
norm_first: bool = False,
|
159 |
+
device=None,
|
160 |
+
dtype=None,
|
161 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
162 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
163 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
164 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
165 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
166 |
+
layer_norm_eps: float = 1e-5,
|
167 |
+
adaptive_layer_norm=False,
|
168 |
+
) -> None:
|
169 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
170 |
+
super(TransformerEncoderLayer, self).__init__()
|
171 |
+
self.self_attn = MultiheadAttention(
|
172 |
+
d_model, # 512 16
|
173 |
+
nhead,
|
174 |
+
dropout=dropout,
|
175 |
+
batch_first=batch_first,
|
176 |
+
linear1_cls=linear1_self_attention_cls,
|
177 |
+
linear2_cls=linear2_self_attention_cls,
|
178 |
+
**factory_kwargs,
|
179 |
+
)
|
180 |
+
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
181 |
+
self.dropout = nn.Dropout(dropout)
|
182 |
+
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
183 |
+
self.norm_first = norm_first
|
184 |
+
self.dropout1 = nn.Dropout(dropout)
|
185 |
+
self.dropout2 = nn.Dropout(dropout)
|
186 |
+
if isinstance(activation, str):
|
187 |
+
activation = _get_activation_fn(activation)
|
188 |
+
elif isinstance(activation, partial):
|
189 |
+
activation = activation(d_model)
|
190 |
+
elif activation == BalancedDoubleSwish:
|
191 |
+
activation = BalancedDoubleSwish(d_model)
|
192 |
+
self.activation = activation
|
193 |
+
|
194 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
195 |
+
if layer_norm_cls == IdentityNorm:
|
196 |
+
norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
|
197 |
+
else:
|
198 |
+
norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
199 |
+
|
200 |
+
if adaptive_layer_norm:
|
201 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
202 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
203 |
+
else:
|
204 |
+
self.norm1 = norm1
|
205 |
+
self.norm2 = norm2
|
206 |
+
|
207 |
+
def __setstate__(self, state):
|
208 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
209 |
+
if not hasattr(self, "activation"):
|
210 |
+
self.activation = F.relu
|
211 |
+
|
212 |
+
def forward(
|
213 |
+
self,
|
214 |
+
src: Tensor,
|
215 |
+
src_mask: Optional[Tensor] = None,
|
216 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
217 |
+
cache=None,
|
218 |
+
) -> Tensor:
|
219 |
+
x = src
|
220 |
+
stage_embedding = None
|
221 |
+
x = self.norm1(
|
222 |
+
x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
|
223 |
+
stage_embedding,
|
224 |
+
)
|
225 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
226 |
+
|
227 |
+
return x
|
228 |
+
|
229 |
+
def _sa_block(
|
230 |
+
self,
|
231 |
+
x: Tensor,
|
232 |
+
attn_mask: Optional[Tensor],
|
233 |
+
key_padding_mask: Optional[Tensor],
|
234 |
+
cache=None,
|
235 |
+
) -> Tensor:
|
236 |
+
x = self.self_attn(
|
237 |
+
x,
|
238 |
+
x,
|
239 |
+
x,
|
240 |
+
attn_mask=attn_mask,
|
241 |
+
key_padding_mask=key_padding_mask,
|
242 |
+
need_weights=False,
|
243 |
+
cache=cache,
|
244 |
+
)
|
245 |
+
return self.dropout1(x)
|
246 |
+
|
247 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
248 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
249 |
+
return self.dropout2(x)
|
250 |
+
|
251 |
+
|
252 |
+
class AdaptiveLayerNorm(nn.Module):
|
253 |
+
r"""Adaptive Layer Normalization"""
|
254 |
+
|
255 |
+
def __init__(self, d_model, norm) -> None:
|
256 |
+
super(AdaptiveLayerNorm, self).__init__()
|
257 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
258 |
+
self.norm = norm
|
259 |
+
self.d_model = d_model
|
260 |
+
self.eps = self.norm.eps
|
261 |
+
|
262 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
263 |
+
if isinstance(input, tuple):
|
264 |
+
input, embedding = input
|
265 |
+
weight, bias = torch.split(
|
266 |
+
self.project_layer(embedding),
|
267 |
+
split_size_or_sections=self.d_model,
|
268 |
+
dim=-1,
|
269 |
+
)
|
270 |
+
return (weight * self.norm(input) + bias, embedding)
|
271 |
+
|
272 |
+
weight, bias = torch.split(
|
273 |
+
self.project_layer(embedding),
|
274 |
+
split_size_or_sections=self.d_model,
|
275 |
+
dim=-1,
|
276 |
+
)
|
277 |
+
return weight * self.norm(input) + bias
|
278 |
+
|
279 |
+
|
280 |
+
def _get_clones(module, N):
|
281 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
GPT_SoVITS/AR/text_processing/__init__.py
ADDED
File without changes
|
GPT_SoVITS/AR/text_processing/phonemizer.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import itertools
|
4 |
+
import re
|
5 |
+
from typing import Dict
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
import regex
|
9 |
+
from gruut import sentences
|
10 |
+
from gruut.const import Sentence
|
11 |
+
from gruut.const import Word
|
12 |
+
from AR.text_processing.symbols import SYMBOL_TO_ID
|
13 |
+
|
14 |
+
|
15 |
+
class GruutPhonemizer:
|
16 |
+
def __init__(self, language: str):
|
17 |
+
self._phonemizer = sentences
|
18 |
+
self.lang = language
|
19 |
+
self.symbol_to_id = SYMBOL_TO_ID
|
20 |
+
self._special_cases_dict: Dict[str] = {
|
21 |
+
r"\.\.\.": "... ",
|
22 |
+
";": "; ",
|
23 |
+
":": ": ",
|
24 |
+
",": ", ",
|
25 |
+
r"\.": ". ",
|
26 |
+
"!": "! ",
|
27 |
+
r"\?": "? ",
|
28 |
+
"—": "—",
|
29 |
+
"…": "… ",
|
30 |
+
"«": "«",
|
31 |
+
"»": "»",
|
32 |
+
}
|
33 |
+
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
|
34 |
+
|
35 |
+
def _normalize_punctuation(self, text: str) -> str:
|
36 |
+
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
|
37 |
+
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
|
38 |
+
text = regex.sub(r"\pZ+", r" ", text)
|
39 |
+
return text.strip()
|
40 |
+
|
41 |
+
def _convert_punctuation(self, word: Word) -> str:
|
42 |
+
if not word.phonemes:
|
43 |
+
return ""
|
44 |
+
if word.phonemes[0] in ["‖", "|"]:
|
45 |
+
return word.text.strip()
|
46 |
+
|
47 |
+
phonemes = "".join(word.phonemes)
|
48 |
+
# remove modifier characters ˈˌː with regex
|
49 |
+
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
|
50 |
+
return phonemes.strip()
|
51 |
+
|
52 |
+
def phonemize(self, text: str, espeak: bool = False) -> str:
|
53 |
+
text_to_phonemize: str = self._normalize_punctuation(text)
|
54 |
+
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
|
55 |
+
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
|
56 |
+
return " ".join(words)
|
57 |
+
|
58 |
+
def transform(self, phonemes):
|
59 |
+
# convert phonemes to ids
|
60 |
+
# dictionary is in symbols.py
|
61 |
+
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == "__main__":
|
65 |
+
phonemizer = GruutPhonemizer("en-us")
|
66 |
+
# text -> IPA
|
67 |
+
phonemes = phonemizer.phonemize("Hello, wor-ld ?")
|
68 |
+
print("phonemes:", phonemes)
|
69 |
+
print("len(phonemes):", len(phonemes))
|
70 |
+
phoneme_ids = phonemizer.transform(phonemes)
|
71 |
+
print("phoneme_ids:", phoneme_ids)
|
72 |
+
print("len(phoneme_ids):", len(phoneme_ids))
|
GPT_SoVITS/AR/text_processing/symbols.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
PAD = "_"
|
4 |
+
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
5 |
+
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
6 |
+
IPA_LETTERS = (
|
7 |
+
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
8 |
+
)
|
9 |
+
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
|
10 |
+
SPACE_ID = SYMBOLS.index(" ")
|
11 |
+
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
|
12 |
+
ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
|
GPT_SoVITS/AR/utils/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
|
4 |
+
def str2bool(str):
|
5 |
+
return True if str.lower() == "true" else False
|
6 |
+
|
7 |
+
|
8 |
+
def get_newest_ckpt(string_list):
|
9 |
+
# 定义一个正则表达式模式,用于匹配字符串中的数字
|
10 |
+
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
|
11 |
+
|
12 |
+
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
|
13 |
+
extracted_info = []
|
14 |
+
for string in string_list:
|
15 |
+
match = re.match(pattern, string)
|
16 |
+
if match:
|
17 |
+
epoch = int(match.group(1))
|
18 |
+
step = int(match.group(2))
|
19 |
+
extracted_info.append((epoch, step, string))
|
20 |
+
# 按照 epoch 后面的数字和 step 后面的数字进行排序
|
21 |
+
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
|
22 |
+
# 获取最新的 ckpt 文件名
|
23 |
+
newest_ckpt = sorted_info[0][2]
|
24 |
+
return newest_ckpt
|
25 |
+
|
26 |
+
|
27 |
+
# 文本存在且不为空时 return True
|
28 |
+
def check_txt_file(file_path):
|
29 |
+
try:
|
30 |
+
with open(file_path, "r") as file:
|
31 |
+
text = file.readline().strip()
|
32 |
+
assert text.strip() != ""
|
33 |
+
return text
|
34 |
+
except Exception:
|
35 |
+
return False
|
36 |
+
return False
|
GPT_SoVITS/AR/utils/initialize.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""Initialize modules for espnet2 neural networks."""
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from typeguard import check_argument_types
|
6 |
+
|
7 |
+
|
8 |
+
def initialize(model: torch.nn.Module, init: str):
|
9 |
+
"""Initialize weights of a neural network module.
|
10 |
+
|
11 |
+
Parameters are initialized using the given method or distribution.
|
12 |
+
|
13 |
+
Custom initialization routines can be implemented into submodules
|
14 |
+
as function `espnet_initialization_fn` within the custom module.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
model: Target.
|
18 |
+
init: Method of initialization.
|
19 |
+
"""
|
20 |
+
assert check_argument_types()
|
21 |
+
print("init with", init)
|
22 |
+
|
23 |
+
# weight init
|
24 |
+
for p in model.parameters():
|
25 |
+
if p.dim() > 1:
|
26 |
+
if init == "xavier_uniform":
|
27 |
+
torch.nn.init.xavier_uniform_(p.data)
|
28 |
+
elif init == "xavier_normal":
|
29 |
+
torch.nn.init.xavier_normal_(p.data)
|
30 |
+
elif init == "kaiming_uniform":
|
31 |
+
torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
|
32 |
+
elif init == "kaiming_normal":
|
33 |
+
torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
|
34 |
+
else:
|
35 |
+
raise ValueError("Unknown initialization: " + init)
|
36 |
+
# bias init
|
37 |
+
for name, p in model.named_parameters():
|
38 |
+
if ".bias" in name and p.dim() == 1:
|
39 |
+
p.data.zero_()
|
GPT_SoVITS/AR/utils/io.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
|
7 |
+
def load_yaml_config(path):
|
8 |
+
with open(path) as f:
|
9 |
+
config = yaml.full_load(f)
|
10 |
+
return config
|
11 |
+
|
12 |
+
|
13 |
+
def save_config_to_yaml(config, path):
|
14 |
+
assert path.endswith(".yaml")
|
15 |
+
with open(path, "w") as f:
|
16 |
+
f.write(yaml.dump(config))
|
17 |
+
f.close()
|
18 |
+
|
19 |
+
|
20 |
+
def write_args(args, path):
|
21 |
+
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
|
22 |
+
with open(path, "a") as args_file:
|
23 |
+
args_file.write("==> torch version: {}\n".format(torch.__version__))
|
24 |
+
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
|
25 |
+
args_file.write("==> Cmd:\n")
|
26 |
+
args_file.write(str(sys.argv))
|
27 |
+
args_file.write("\n==> args:\n")
|
28 |
+
for k, v in sorted(args_dict.items()):
|
29 |
+
args_file.write(" %s: %s\n" % (str(k), str(v)))
|
30 |
+
args_file.close()
|
GPT_SoVITS/download.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
now_dir = os.getcwd()
|
5 |
+
sys.path.insert(0, now_dir)
|
6 |
+
from text.g2pw import G2PWPinyin
|
7 |
+
|
8 |
+
g2pw = G2PWPinyin(
|
9 |
+
model_dir="GPT_SoVITS/text/G2PWModel",
|
10 |
+
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
11 |
+
v_to_u=False,
|
12 |
+
neutral_tone_with_five=True,
|
13 |
+
)
|
GPT_SoVITS/export_torch_script.py
ADDED
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
2 |
+
# reference: https://github.com/lifeiteng/vall-e
|
3 |
+
import argparse
|
4 |
+
from typing import Optional
|
5 |
+
from my_utils import load_audio
|
6 |
+
import torch
|
7 |
+
import torchaudio
|
8 |
+
|
9 |
+
from torch import IntTensor, LongTensor, Tensor, nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
13 |
+
from feature_extractor import cnhubert
|
14 |
+
|
15 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
16 |
+
from module.models_onnx import SynthesizerTrn
|
17 |
+
|
18 |
+
from inference_webui import get_phones_and_bert
|
19 |
+
|
20 |
+
import os
|
21 |
+
import soundfile
|
22 |
+
|
23 |
+
default_config = {
|
24 |
+
"embedding_dim": 512,
|
25 |
+
"hidden_dim": 512,
|
26 |
+
"num_head": 8,
|
27 |
+
"num_layers": 12,
|
28 |
+
"num_codebook": 8,
|
29 |
+
"p_dropout": 0.0,
|
30 |
+
"vocab_size": 1024 + 1,
|
31 |
+
"phoneme_vocab_size": 512,
|
32 |
+
"EOS": 1024,
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
37 |
+
config = dict_s1["config"]
|
38 |
+
config["model"]["dropout"] = float(config["model"]["dropout"])
|
39 |
+
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
40 |
+
t2s_model.load_state_dict(dict_s1["weight"])
|
41 |
+
t2s_model = t2s_model.eval()
|
42 |
+
return t2s_model
|
43 |
+
|
44 |
+
|
45 |
+
@torch.jit.script
|
46 |
+
def logits_to_probs(
|
47 |
+
logits,
|
48 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
49 |
+
temperature: float = 1.0,
|
50 |
+
top_k: Optional[int] = None,
|
51 |
+
top_p: Optional[int] = None,
|
52 |
+
repetition_penalty: float = 1.0,
|
53 |
+
):
|
54 |
+
# if previous_tokens is not None:
|
55 |
+
# previous_tokens = previous_tokens.squeeze()
|
56 |
+
# print(logits.shape,previous_tokens.shape)
|
57 |
+
# pdb.set_trace()
|
58 |
+
if previous_tokens is not None and repetition_penalty != 1.0:
|
59 |
+
previous_tokens = previous_tokens.long()
|
60 |
+
score = torch.gather(logits, dim=1, index=previous_tokens)
|
61 |
+
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
62 |
+
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
63 |
+
|
64 |
+
if top_p is not None and top_p < 1.0:
|
65 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
66 |
+
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
67 |
+
sorted_indices_to_remove = cum_probs > top_p
|
68 |
+
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
69 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
70 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
71 |
+
|
72 |
+
logits = logits / max(temperature, 1e-5)
|
73 |
+
|
74 |
+
if top_k is not None:
|
75 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
76 |
+
pivot = v[:, -1].unsqueeze(-1)
|
77 |
+
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
78 |
+
|
79 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
80 |
+
return probs
|
81 |
+
|
82 |
+
|
83 |
+
@torch.jit.script
|
84 |
+
def multinomial_sample_one_no_sync(probs_sort):
|
85 |
+
# Does multinomial sampling without a cuda synchronization
|
86 |
+
q = torch.randn_like(probs_sort)
|
87 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
88 |
+
|
89 |
+
|
90 |
+
@torch.jit.script
|
91 |
+
def sample(
|
92 |
+
logits,
|
93 |
+
previous_tokens,
|
94 |
+
temperature: float = 1.0,
|
95 |
+
top_k: Optional[int] = None,
|
96 |
+
top_p: Optional[int] = None,
|
97 |
+
repetition_penalty: float = 1.0,
|
98 |
+
):
|
99 |
+
probs = logits_to_probs(
|
100 |
+
logits=logits,
|
101 |
+
previous_tokens=previous_tokens,
|
102 |
+
temperature=temperature,
|
103 |
+
top_k=top_k,
|
104 |
+
top_p=top_p,
|
105 |
+
repetition_penalty=repetition_penalty,
|
106 |
+
)
|
107 |
+
idx_next = multinomial_sample_one_no_sync(probs)
|
108 |
+
return idx_next, probs
|
109 |
+
|
110 |
+
|
111 |
+
@torch.jit.script
|
112 |
+
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
|
113 |
+
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
114 |
+
y = torch.nn.functional.pad(
|
115 |
+
y.unsqueeze(1),
|
116 |
+
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
117 |
+
mode="reflect",
|
118 |
+
)
|
119 |
+
y = y.squeeze(1)
|
120 |
+
spec = torch.stft(
|
121 |
+
y,
|
122 |
+
n_fft,
|
123 |
+
hop_length=hop_size,
|
124 |
+
win_length=win_size,
|
125 |
+
window=hann_window,
|
126 |
+
center=center,
|
127 |
+
pad_mode="reflect",
|
128 |
+
normalized=False,
|
129 |
+
onesided=True,
|
130 |
+
return_complex=False,
|
131 |
+
)
|
132 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
133 |
+
return spec
|
134 |
+
|
135 |
+
|
136 |
+
class DictToAttrRecursive(dict):
|
137 |
+
def __init__(self, input_dict):
|
138 |
+
super().__init__(input_dict)
|
139 |
+
for key, value in input_dict.items():
|
140 |
+
if isinstance(value, dict):
|
141 |
+
value = DictToAttrRecursive(value)
|
142 |
+
self[key] = value
|
143 |
+
setattr(self, key, value)
|
144 |
+
|
145 |
+
def __getattr__(self, item):
|
146 |
+
try:
|
147 |
+
return self[item]
|
148 |
+
except KeyError:
|
149 |
+
raise AttributeError(f"Attribute {item} not found")
|
150 |
+
|
151 |
+
def __setattr__(self, key, value):
|
152 |
+
if isinstance(value, dict):
|
153 |
+
value = DictToAttrRecursive(value)
|
154 |
+
super(DictToAttrRecursive, self).__setitem__(key, value)
|
155 |
+
super().__setattr__(key, value)
|
156 |
+
|
157 |
+
def __delattr__(self, item):
|
158 |
+
try:
|
159 |
+
del self[item]
|
160 |
+
except KeyError:
|
161 |
+
raise AttributeError(f"Attribute {item} not found")
|
162 |
+
|
163 |
+
|
164 |
+
@torch.jit.script
|
165 |
+
class T2SMLP:
|
166 |
+
def __init__(self, w1, b1, w2, b2):
|
167 |
+
self.w1 = w1
|
168 |
+
self.b1 = b1
|
169 |
+
self.w2 = w2
|
170 |
+
self.b2 = b2
|
171 |
+
|
172 |
+
def forward(self, x):
|
173 |
+
x = F.relu(F.linear(x, self.w1, self.b1))
|
174 |
+
x = F.linear(x, self.w2, self.b2)
|
175 |
+
return x
|
176 |
+
|
177 |
+
|
178 |
+
@torch.jit.script
|
179 |
+
class T2SBlock:
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
num_heads: int,
|
183 |
+
hidden_dim: int,
|
184 |
+
mlp: T2SMLP,
|
185 |
+
qkv_w,
|
186 |
+
qkv_b,
|
187 |
+
out_w,
|
188 |
+
out_b,
|
189 |
+
norm_w1,
|
190 |
+
norm_b1,
|
191 |
+
norm_eps1: float,
|
192 |
+
norm_w2,
|
193 |
+
norm_b2,
|
194 |
+
norm_eps2: float,
|
195 |
+
):
|
196 |
+
self.num_heads = num_heads
|
197 |
+
self.mlp = mlp
|
198 |
+
self.hidden_dim: int = hidden_dim
|
199 |
+
self.qkv_w = qkv_w
|
200 |
+
self.qkv_b = qkv_b
|
201 |
+
self.out_w = out_w
|
202 |
+
self.out_b = out_b
|
203 |
+
self.norm_w1 = norm_w1
|
204 |
+
self.norm_b1 = norm_b1
|
205 |
+
self.norm_eps1 = norm_eps1
|
206 |
+
self.norm_w2 = norm_w2
|
207 |
+
self.norm_b2 = norm_b2
|
208 |
+
self.norm_eps2 = norm_eps2
|
209 |
+
|
210 |
+
self.false = torch.tensor(False, dtype=torch.bool)
|
211 |
+
|
212 |
+
@torch.jit.ignore
|
213 |
+
def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
|
214 |
+
if padding_mask is None:
|
215 |
+
return x
|
216 |
+
|
217 |
+
if padding_mask.dtype == torch.bool:
|
218 |
+
return x.masked_fill(padding_mask, 0)
|
219 |
+
else:
|
220 |
+
return x * padding_mask
|
221 |
+
|
222 |
+
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
223 |
+
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
224 |
+
|
225 |
+
batch_size = q.shape[0]
|
226 |
+
q_len = q.shape[1]
|
227 |
+
kv_len = k.shape[1]
|
228 |
+
|
229 |
+
q = self.to_mask(q, padding_mask)
|
230 |
+
k_cache = self.to_mask(k, padding_mask)
|
231 |
+
v_cache = self.to_mask(v, padding_mask)
|
232 |
+
|
233 |
+
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
234 |
+
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
235 |
+
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
236 |
+
|
237 |
+
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
238 |
+
|
239 |
+
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
240 |
+
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
241 |
+
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
242 |
+
|
243 |
+
if padding_mask is not None:
|
244 |
+
for i in range(batch_size):
|
245 |
+
# mask = padding_mask[i,:,0]
|
246 |
+
if self.false.device != padding_mask.device:
|
247 |
+
self.false = self.false.to(padding_mask.device)
|
248 |
+
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
|
249 |
+
x_item = x[i, idx, :].unsqueeze(0)
|
250 |
+
attn_item = attn[i, idx, :].unsqueeze(0)
|
251 |
+
x_item = x_item + attn_item
|
252 |
+
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
253 |
+
x_item = x_item + self.mlp.forward(x_item)
|
254 |
+
x_item = F.layer_norm(
|
255 |
+
x_item,
|
256 |
+
[self.hidden_dim],
|
257 |
+
self.norm_w2,
|
258 |
+
self.norm_b2,
|
259 |
+
self.norm_eps2,
|
260 |
+
)
|
261 |
+
x[i, idx, :] = x_item.squeeze(0)
|
262 |
+
x = self.to_mask(x, padding_mask)
|
263 |
+
else:
|
264 |
+
x = x + attn
|
265 |
+
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
266 |
+
x = x + self.mlp.forward(x)
|
267 |
+
x = F.layer_norm(
|
268 |
+
x,
|
269 |
+
[self.hidden_dim],
|
270 |
+
self.norm_w2,
|
271 |
+
self.norm_b2,
|
272 |
+
self.norm_eps2,
|
273 |
+
)
|
274 |
+
return x, k_cache, v_cache
|
275 |
+
|
276 |
+
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
277 |
+
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
278 |
+
|
279 |
+
k_cache = torch.cat([k_cache, k], dim=1)
|
280 |
+
v_cache = torch.cat([v_cache, v], dim=1)
|
281 |
+
|
282 |
+
batch_size = q.shape[0]
|
283 |
+
q_len = q.shape[1]
|
284 |
+
kv_len = k_cache.shape[1]
|
285 |
+
|
286 |
+
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
287 |
+
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
288 |
+
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
289 |
+
|
290 |
+
attn = F.scaled_dot_product_attention(q, k, v)
|
291 |
+
|
292 |
+
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
293 |
+
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
294 |
+
attn = F.linear(attn, self.out_w, self.out_b)
|
295 |
+
|
296 |
+
x = x + attn
|
297 |
+
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
298 |
+
x = x + self.mlp.forward(x)
|
299 |
+
x = F.layer_norm(
|
300 |
+
x,
|
301 |
+
[self.hidden_dim],
|
302 |
+
self.norm_w2,
|
303 |
+
self.norm_b2,
|
304 |
+
self.norm_eps2,
|
305 |
+
)
|
306 |
+
return x, k_cache, v_cache
|
307 |
+
|
308 |
+
|
309 |
+
@torch.jit.script
|
310 |
+
class T2STransformer:
|
311 |
+
def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
|
312 |
+
self.num_blocks: int = num_blocks
|
313 |
+
self.blocks = blocks
|
314 |
+
|
315 |
+
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
316 |
+
k_cache: list[torch.Tensor] = []
|
317 |
+
v_cache: list[torch.Tensor] = []
|
318 |
+
for i in range(self.num_blocks):
|
319 |
+
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
|
320 |
+
k_cache.append(k_cache_)
|
321 |
+
v_cache.append(v_cache_)
|
322 |
+
return x, k_cache, v_cache
|
323 |
+
|
324 |
+
def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
|
325 |
+
for i in range(self.num_blocks):
|
326 |
+
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
|
327 |
+
return x, k_cache, v_cache
|
328 |
+
|
329 |
+
|
330 |
+
class VitsModel(nn.Module):
|
331 |
+
def __init__(self, vits_path):
|
332 |
+
super().__init__()
|
333 |
+
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
334 |
+
dict_s2 = torch.load(vits_path)
|
335 |
+
self.hps = dict_s2["config"]
|
336 |
+
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
337 |
+
self.hps["model"]["version"] = "v1"
|
338 |
+
else:
|
339 |
+
self.hps["model"]["version"] = "v2"
|
340 |
+
|
341 |
+
self.hps = DictToAttrRecursive(self.hps)
|
342 |
+
self.hps.model.semantic_frame_rate = "25hz"
|
343 |
+
self.vq_model = SynthesizerTrn(
|
344 |
+
self.hps.data.filter_length // 2 + 1,
|
345 |
+
self.hps.train.segment_size // self.hps.data.hop_length,
|
346 |
+
n_speakers=self.hps.data.n_speakers,
|
347 |
+
**self.hps.model,
|
348 |
+
)
|
349 |
+
self.vq_model.eval()
|
350 |
+
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
351 |
+
|
352 |
+
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
|
353 |
+
refer = spectrogram_torch(
|
354 |
+
ref_audio,
|
355 |
+
self.hps.data.filter_length,
|
356 |
+
self.hps.data.sampling_rate,
|
357 |
+
self.hps.data.hop_length,
|
358 |
+
self.hps.data.win_length,
|
359 |
+
center=False,
|
360 |
+
)
|
361 |
+
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
|
362 |
+
|
363 |
+
|
364 |
+
class T2SModel(nn.Module):
|
365 |
+
def __init__(self, raw_t2s: Text2SemanticLightningModule):
|
366 |
+
super(T2SModel, self).__init__()
|
367 |
+
self.model_dim = raw_t2s.model.model_dim
|
368 |
+
self.embedding_dim = raw_t2s.model.embedding_dim
|
369 |
+
self.num_head = raw_t2s.model.num_head
|
370 |
+
self.num_layers = raw_t2s.model.num_layers
|
371 |
+
self.vocab_size = raw_t2s.model.vocab_size
|
372 |
+
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
|
373 |
+
# self.p_dropout = float(raw_t2s.model.p_dropout)
|
374 |
+
self.EOS: int = int(raw_t2s.model.EOS)
|
375 |
+
self.norm_first = raw_t2s.model.norm_first
|
376 |
+
assert self.EOS == self.vocab_size - 1
|
377 |
+
self.hz = 50
|
378 |
+
|
379 |
+
self.bert_proj = raw_t2s.model.bert_proj
|
380 |
+
self.ar_text_embedding = raw_t2s.model.ar_text_embedding
|
381 |
+
self.ar_text_position = raw_t2s.model.ar_text_position
|
382 |
+
self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
|
383 |
+
self.ar_audio_position = raw_t2s.model.ar_audio_position
|
384 |
+
|
385 |
+
# self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
386 |
+
# self.t2s_transformer = raw_t2s.model.t2s_transformer
|
387 |
+
|
388 |
+
blocks = []
|
389 |
+
h = raw_t2s.model.h
|
390 |
+
|
391 |
+
for i in range(self.num_layers):
|
392 |
+
layer = h.layers[i]
|
393 |
+
t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
|
394 |
+
|
395 |
+
block = T2SBlock(
|
396 |
+
self.num_head,
|
397 |
+
self.model_dim,
|
398 |
+
t2smlp,
|
399 |
+
layer.self_attn.in_proj_weight,
|
400 |
+
layer.self_attn.in_proj_bias,
|
401 |
+
layer.self_attn.out_proj.weight,
|
402 |
+
layer.self_attn.out_proj.bias,
|
403 |
+
layer.norm1.weight,
|
404 |
+
layer.norm1.bias,
|
405 |
+
layer.norm1.eps,
|
406 |
+
layer.norm2.weight,
|
407 |
+
layer.norm2.bias,
|
408 |
+
layer.norm2.eps,
|
409 |
+
)
|
410 |
+
|
411 |
+
blocks.append(block)
|
412 |
+
|
413 |
+
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
414 |
+
|
415 |
+
# self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
|
416 |
+
self.ar_predict_layer = raw_t2s.model.ar_predict_layer
|
417 |
+
# self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
|
418 |
+
self.max_sec = raw_t2s.config["data"]["max_sec"]
|
419 |
+
self.top_k = int(raw_t2s.config["inference"]["top_k"])
|
420 |
+
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
421 |
+
|
422 |
+
def forward(
|
423 |
+
self,
|
424 |
+
prompts: LongTensor,
|
425 |
+
ref_seq: LongTensor,
|
426 |
+
text_seq: LongTensor,
|
427 |
+
ref_bert: torch.Tensor,
|
428 |
+
text_bert: torch.Tensor,
|
429 |
+
top_k: LongTensor,
|
430 |
+
):
|
431 |
+
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
432 |
+
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
433 |
+
bert = bert.unsqueeze(0)
|
434 |
+
|
435 |
+
x = self.ar_text_embedding(all_phoneme_ids)
|
436 |
+
x = x + self.bert_proj(bert.transpose(1, 2))
|
437 |
+
x: torch.Tensor = self.ar_text_position(x)
|
438 |
+
|
439 |
+
early_stop_num = self.early_stop_num
|
440 |
+
|
441 |
+
# [1,N,512] [1,N]
|
442 |
+
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
443 |
+
y = prompts
|
444 |
+
# x_example = x[:,:,0] * 0.0
|
445 |
+
|
446 |
+
x_len = x.shape[1]
|
447 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
448 |
+
|
449 |
+
y_emb = self.ar_audio_embedding(y)
|
450 |
+
y_len = y_emb.shape[1]
|
451 |
+
prefix_len = y.shape[1]
|
452 |
+
y_pos = self.ar_audio_position(y_emb)
|
453 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
454 |
+
|
455 |
+
bsz = x.shape[0]
|
456 |
+
src_len = x_len + y_len
|
457 |
+
x_attn_mask_pad = F.pad(
|
458 |
+
x_attn_mask,
|
459 |
+
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
460 |
+
value=True,
|
461 |
+
)
|
462 |
+
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
463 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
464 |
+
(x_len, 0),
|
465 |
+
value=False,
|
466 |
+
)
|
467 |
+
xy_attn_mask = (
|
468 |
+
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
469 |
+
.unsqueeze(0)
|
470 |
+
.expand(bsz * self.num_head, -1, -1)
|
471 |
+
.view(bsz, self.num_head, src_len, src_len)
|
472 |
+
.to(device=x.device, dtype=torch.bool)
|
473 |
+
)
|
474 |
+
|
475 |
+
idx = 0
|
476 |
+
top_k = int(top_k)
|
477 |
+
|
478 |
+
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
479 |
+
|
480 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
481 |
+
logits = logits[:, :-1]
|
482 |
+
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
483 |
+
y = torch.concat([y, samples], dim=1)
|
484 |
+
y_emb = self.ar_audio_embedding(y[:, -1:])
|
485 |
+
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
486 |
+
:, y_len + idx
|
487 |
+
].to(dtype=y_emb.dtype, device=y_emb.device)
|
488 |
+
|
489 |
+
stop = False
|
490 |
+
# for idx in range(1, 50):
|
491 |
+
for idx in range(1, 1500):
|
492 |
+
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
493 |
+
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
|
494 |
+
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
495 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
496 |
+
|
497 |
+
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
498 |
+
logits = logits[:, :-1]
|
499 |
+
|
500 |
+
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
501 |
+
|
502 |
+
y = torch.concat([y, samples], dim=1)
|
503 |
+
|
504 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
505 |
+
stop = True
|
506 |
+
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
507 |
+
stop = True
|
508 |
+
if stop:
|
509 |
+
if y.shape[1] == 0:
|
510 |
+
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
511 |
+
break
|
512 |
+
|
513 |
+
y_emb = self.ar_audio_embedding(y[:, -1:])
|
514 |
+
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
515 |
+
:, y_len + idx
|
516 |
+
].to(dtype=y_emb.dtype, device=y_emb.device)
|
517 |
+
|
518 |
+
y[0, -1] = 0
|
519 |
+
|
520 |
+
return y[:, -idx:].unsqueeze(0)
|
521 |
+
|
522 |
+
|
523 |
+
bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
|
524 |
+
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
525 |
+
cnhubert.cnhubert_base_path = cnhubert_base_path
|
526 |
+
|
527 |
+
|
528 |
+
@torch.jit.script
|
529 |
+
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
|
530 |
+
phone_level_feature = []
|
531 |
+
for i in range(word2ph.shape[0]):
|
532 |
+
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
|
533 |
+
phone_level_feature.append(repeat_feature)
|
534 |
+
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
535 |
+
# [sum(word2ph), 1024]
|
536 |
+
return phone_level_feature
|
537 |
+
|
538 |
+
|
539 |
+
class MyBertModel(torch.nn.Module):
|
540 |
+
def __init__(self, bert_model):
|
541 |
+
super(MyBertModel, self).__init__()
|
542 |
+
self.bert = bert_model
|
543 |
+
|
544 |
+
def forward(
|
545 |
+
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
|
546 |
+
):
|
547 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
548 |
+
# res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
|
549 |
+
res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
|
550 |
+
return build_phone_level_feature(res, word2ph)
|
551 |
+
|
552 |
+
|
553 |
+
class SSLModel(torch.nn.Module):
|
554 |
+
def __init__(self):
|
555 |
+
super().__init__()
|
556 |
+
self.ssl = cnhubert.get_model().model
|
557 |
+
|
558 |
+
def forward(self, ref_audio_16k) -> torch.Tensor:
|
559 |
+
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
560 |
+
return ssl_content
|
561 |
+
|
562 |
+
|
563 |
+
class ExportSSLModel(torch.nn.Module):
|
564 |
+
def __init__(self, ssl: SSLModel):
|
565 |
+
super().__init__()
|
566 |
+
self.ssl = ssl
|
567 |
+
|
568 |
+
def forward(self, ref_audio: torch.Tensor):
|
569 |
+
return self.ssl(ref_audio)
|
570 |
+
|
571 |
+
@torch.jit.export
|
572 |
+
def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
573 |
+
audio = resamplex(ref_audio, src_sr, dst_sr).float()
|
574 |
+
return audio
|
575 |
+
|
576 |
+
|
577 |
+
def export_bert(output_path):
|
578 |
+
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
579 |
+
|
580 |
+
text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
|
581 |
+
ref_bert_inputs = tokenizer(text, return_tensors="pt")
|
582 |
+
word2ph = []
|
583 |
+
for c in text:
|
584 |
+
if c in [",", "。", ":", "?", ",", ".", "?"]:
|
585 |
+
word2ph.append(1)
|
586 |
+
else:
|
587 |
+
word2ph.append(2)
|
588 |
+
ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
|
589 |
+
|
590 |
+
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
|
591 |
+
my_bert_model = MyBertModel(bert_model)
|
592 |
+
|
593 |
+
ref_bert_inputs = {
|
594 |
+
"input_ids": ref_bert_inputs["input_ids"],
|
595 |
+
"attention_mask": ref_bert_inputs["attention_mask"],
|
596 |
+
"token_type_ids": ref_bert_inputs["token_type_ids"],
|
597 |
+
"word2ph": ref_bert_inputs["word2ph"],
|
598 |
+
}
|
599 |
+
|
600 |
+
torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
|
601 |
+
torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
|
602 |
+
torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
|
603 |
+
torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
|
604 |
+
|
605 |
+
my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
|
606 |
+
output_path = os.path.join(output_path, "bert_model.pt")
|
607 |
+
my_bert_model.save(output_path)
|
608 |
+
print("#### exported bert ####")
|
609 |
+
|
610 |
+
|
611 |
+
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
|
612 |
+
if not os.path.exists(output_path):
|
613 |
+
os.makedirs(output_path)
|
614 |
+
print(f"目录已创建: {output_path}")
|
615 |
+
else:
|
616 |
+
print(f"目录已存在: {output_path}")
|
617 |
+
|
618 |
+
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
619 |
+
ssl = SSLModel()
|
620 |
+
if export_bert_and_ssl:
|
621 |
+
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
|
622 |
+
ssl_path = os.path.join(output_path, "ssl_model.pt")
|
623 |
+
torch.jit.script(s).save(ssl_path)
|
624 |
+
print("#### exported ssl ####")
|
625 |
+
export_bert(output_path)
|
626 |
+
else:
|
627 |
+
s = ExportSSLModel(ssl)
|
628 |
+
|
629 |
+
print(f"device: {device}")
|
630 |
+
|
631 |
+
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
632 |
+
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
633 |
+
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
634 |
+
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
635 |
+
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
|
636 |
+
)
|
637 |
+
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
638 |
+
text_bert = text_bert_T.T.to(text_seq.device)
|
639 |
+
|
640 |
+
ssl_content = ssl(ref_audio).to(device)
|
641 |
+
|
642 |
+
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
643 |
+
vits = VitsModel(vits_path).to(device)
|
644 |
+
vits.eval()
|
645 |
+
|
646 |
+
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
647 |
+
# dict_s1 = torch.load(gpt_path, map_location=device)
|
648 |
+
dict_s1 = torch.load(gpt_path)
|
649 |
+
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
650 |
+
print("#### get_raw_t2s_model ####")
|
651 |
+
print(raw_t2s.config)
|
652 |
+
t2s_m = T2SModel(raw_t2s)
|
653 |
+
t2s_m.eval()
|
654 |
+
t2s = torch.jit.script(t2s_m).to(device)
|
655 |
+
print("#### script t2s_m ####")
|
656 |
+
|
657 |
+
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
658 |
+
gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
|
659 |
+
gpt_sovits.eval()
|
660 |
+
|
661 |
+
ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
|
662 |
+
|
663 |
+
torch._dynamo.mark_dynamic(ssl_content, 2)
|
664 |
+
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
665 |
+
torch._dynamo.mark_dynamic(ref_seq, 1)
|
666 |
+
torch._dynamo.mark_dynamic(text_seq, 1)
|
667 |
+
torch._dynamo.mark_dynamic(ref_bert, 0)
|
668 |
+
torch._dynamo.mark_dynamic(text_bert, 0)
|
669 |
+
|
670 |
+
top_k = torch.LongTensor([5]).to(device)
|
671 |
+
|
672 |
+
with torch.no_grad():
|
673 |
+
gpt_sovits_export = torch.jit.trace(
|
674 |
+
gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
675 |
+
)
|
676 |
+
|
677 |
+
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
678 |
+
gpt_sovits_export.save(gpt_sovits_path)
|
679 |
+
print("#### exported gpt_sovits ####")
|
680 |
+
|
681 |
+
|
682 |
+
@torch.jit.script
|
683 |
+
def parse_audio(ref_audio):
|
684 |
+
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
|
685 |
+
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
|
686 |
+
return ref_audio_16k, ref_audio_sr
|
687 |
+
|
688 |
+
|
689 |
+
@torch.jit.script
|
690 |
+
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
691 |
+
return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
|
692 |
+
|
693 |
+
|
694 |
+
class GPT_SoVITS(nn.Module):
|
695 |
+
def __init__(self, t2s: T2SModel, vits: VitsModel):
|
696 |
+
super().__init__()
|
697 |
+
self.t2s = t2s
|
698 |
+
self.vits = vits
|
699 |
+
|
700 |
+
def forward(
|
701 |
+
self,
|
702 |
+
ssl_content: torch.Tensor,
|
703 |
+
ref_audio_sr: torch.Tensor,
|
704 |
+
ref_seq: Tensor,
|
705 |
+
text_seq: Tensor,
|
706 |
+
ref_bert: Tensor,
|
707 |
+
text_bert: Tensor,
|
708 |
+
top_k: LongTensor,
|
709 |
+
speed=1.0,
|
710 |
+
):
|
711 |
+
codes = self.vits.vq_model.extract_latent(ssl_content)
|
712 |
+
prompt_semantic = codes[0, 0]
|
713 |
+
prompts = prompt_semantic.unsqueeze(0)
|
714 |
+
|
715 |
+
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
716 |
+
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
717 |
+
return audio
|
718 |
+
|
719 |
+
|
720 |
+
def test():
|
721 |
+
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
722 |
+
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
723 |
+
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
724 |
+
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
725 |
+
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
726 |
+
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
727 |
+
|
728 |
+
args = parser.parse_args()
|
729 |
+
gpt_path = args.gpt_model
|
730 |
+
vits_path = args.sovits_model
|
731 |
+
ref_audio_path = args.ref_audio
|
732 |
+
ref_text = args.ref_text
|
733 |
+
|
734 |
+
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
735 |
+
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
|
736 |
+
# bert = MyBertModel(bert_model)
|
737 |
+
my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
|
738 |
+
|
739 |
+
# dict_s1 = torch.load(gpt_path, map_location="cuda")
|
740 |
+
# raw_t2s = get_raw_t2s_model(dict_s1)
|
741 |
+
# t2s = T2SModel(raw_t2s)
|
742 |
+
# t2s.eval()
|
743 |
+
# t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
|
744 |
+
|
745 |
+
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
746 |
+
# vits = VitsModel(vits_path)
|
747 |
+
# vits.eval()
|
748 |
+
|
749 |
+
# ssl = ExportSSLModel(SSLModel()).to('cuda')
|
750 |
+
# ssl.eval()
|
751 |
+
ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
|
752 |
+
|
753 |
+
# gpt_sovits = GPT_SoVITS(t2s,vits)
|
754 |
+
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
|
755 |
+
|
756 |
+
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
757 |
+
ref_seq = torch.LongTensor([ref_seq_id])
|
758 |
+
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
759 |
+
# text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
|
760 |
+
text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
|
761 |
+
|
762 |
+
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
|
763 |
+
|
764 |
+
test_bert = tokenizer(text, return_tensors="pt")
|
765 |
+
word2ph = []
|
766 |
+
for c in text:
|
767 |
+
if c in [",", "。", ":", "?", "?", ",", "."]:
|
768 |
+
word2ph.append(1)
|
769 |
+
else:
|
770 |
+
word2ph.append(2)
|
771 |
+
test_bert["word2ph"] = torch.Tensor(word2ph).int()
|
772 |
+
|
773 |
+
test_bert = my_bert(
|
774 |
+
test_bert["input_ids"].to("cuda"),
|
775 |
+
test_bert["attention_mask"].to("cuda"),
|
776 |
+
test_bert["token_type_ids"].to("cuda"),
|
777 |
+
test_bert["word2ph"].to("cuda"),
|
778 |
+
)
|
779 |
+
|
780 |
+
text_seq = torch.LongTensor([text_seq_id])
|
781 |
+
text_bert = text_bert_T.T.to(text_seq.device)
|
782 |
+
|
783 |
+
print("text_bert:", text_bert.shape, text_bert)
|
784 |
+
print("test_bert:", test_bert.shape, test_bert)
|
785 |
+
print(torch.allclose(text_bert.to("cuda"), test_bert))
|
786 |
+
|
787 |
+
print("text_seq:", text_seq.shape)
|
788 |
+
print("text_bert:", text_bert.shape, text_bert.type())
|
789 |
+
|
790 |
+
# [1,N]
|
791 |
+
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
|
792 |
+
print("ref_audio:", ref_audio.shape)
|
793 |
+
|
794 |
+
ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
|
795 |
+
print("start ssl")
|
796 |
+
ssl_content = ssl(ref_audio)
|
797 |
+
|
798 |
+
print("start gpt_sovits:")
|
799 |
+
print("ssl_content:", ssl_content.shape)
|
800 |
+
print("ref_audio_sr:", ref_audio_sr.shape)
|
801 |
+
print("ref_seq:", ref_seq.shape)
|
802 |
+
ref_seq = ref_seq.to("cuda")
|
803 |
+
print("text_seq:", text_seq.shape)
|
804 |
+
text_seq = text_seq.to("cuda")
|
805 |
+
print("ref_bert:", ref_bert.shape)
|
806 |
+
ref_bert = ref_bert.to("cuda")
|
807 |
+
print("text_bert:", text_bert.shape)
|
808 |
+
text_bert = text_bert.to("cuda")
|
809 |
+
|
810 |
+
top_k = torch.LongTensor([5]).to("cuda")
|
811 |
+
|
812 |
+
with torch.no_grad():
|
813 |
+
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
|
814 |
+
print("start write wav")
|
815 |
+
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
816 |
+
|
817 |
+
|
818 |
+
import text
|
819 |
+
import json
|
820 |
+
|
821 |
+
|
822 |
+
def export_symbel(version="v2"):
|
823 |
+
if version == "v1":
|
824 |
+
symbols = text._symbol_to_id_v1
|
825 |
+
with open("onnx/symbols_v1.json", "w") as file:
|
826 |
+
json.dump(symbols, file, indent=4)
|
827 |
+
else:
|
828 |
+
symbols = text._symbol_to_id_v2
|
829 |
+
with open("onnx/symbols_v2.json", "w") as file:
|
830 |
+
json.dump(symbols, file, indent=4)
|
831 |
+
|
832 |
+
|
833 |
+
def main():
|
834 |
+
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
835 |
+
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
836 |
+
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
837 |
+
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
838 |
+
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
839 |
+
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
840 |
+
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
|
841 |
+
parser.add_argument("--device", help="Device to use")
|
842 |
+
|
843 |
+
args = parser.parse_args()
|
844 |
+
export(
|
845 |
+
gpt_path=args.gpt_model,
|
846 |
+
vits_path=args.sovits_model,
|
847 |
+
ref_audio_path=args.ref_audio,
|
848 |
+
ref_text=args.ref_text,
|
849 |
+
output_path=args.output_path,
|
850 |
+
device=args.device,
|
851 |
+
export_bert_and_ssl=args.export_common_model,
|
852 |
+
)
|
853 |
+
|
854 |
+
|
855 |
+
import inference_webui
|
856 |
+
|
857 |
+
if __name__ == "__main__":
|
858 |
+
inference_webui.is_half = False
|
859 |
+
inference_webui.dtype = torch.float32
|
860 |
+
main()
|
861 |
+
# test()
|
GPT_SoVITS/export_torch_script_v3.py
ADDED
@@ -0,0 +1,1035 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from export_torch_script import (
|
3 |
+
T2SModel,
|
4 |
+
get_raw_t2s_model,
|
5 |
+
resamplex,
|
6 |
+
spectrogram_torch,
|
7 |
+
)
|
8 |
+
from f5_tts.model.backbones.dit import DiT
|
9 |
+
from inference_webui import get_phones_and_bert
|
10 |
+
import librosa
|
11 |
+
from module import commons
|
12 |
+
from module.mel_processing import mel_spectrogram_torch
|
13 |
+
from module.models_onnx import CFM, SynthesizerTrnV3
|
14 |
+
import numpy as np
|
15 |
+
import torch._dynamo.config
|
16 |
+
import torchaudio
|
17 |
+
import logging
|
18 |
+
import uvicorn
|
19 |
+
import torch
|
20 |
+
import soundfile
|
21 |
+
from librosa.filters import mel as librosa_mel_fn
|
22 |
+
|
23 |
+
|
24 |
+
from inference_webui import get_spepc, norm_spec, resample, ssl_model
|
25 |
+
|
26 |
+
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
|
27 |
+
logger = logging.getLogger("uvicorn")
|
28 |
+
|
29 |
+
is_half = True
|
30 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
+
now_dir = os.getcwd()
|
32 |
+
|
33 |
+
|
34 |
+
class MelSpectrgram(torch.nn.Module):
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
dtype,
|
38 |
+
device,
|
39 |
+
n_fft,
|
40 |
+
num_mels,
|
41 |
+
sampling_rate,
|
42 |
+
hop_size,
|
43 |
+
win_size,
|
44 |
+
fmin,
|
45 |
+
fmax,
|
46 |
+
center=False,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
|
50 |
+
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
51 |
+
self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
|
52 |
+
self.n_fft: int = n_fft
|
53 |
+
self.hop_size: int = hop_size
|
54 |
+
self.win_size: int = win_size
|
55 |
+
self.center: bool = center
|
56 |
+
|
57 |
+
def forward(self, y):
|
58 |
+
y = torch.nn.functional.pad(
|
59 |
+
y.unsqueeze(1),
|
60 |
+
(
|
61 |
+
int((self.n_fft - self.hop_size) / 2),
|
62 |
+
int((self.n_fft - self.hop_size) / 2),
|
63 |
+
),
|
64 |
+
mode="reflect",
|
65 |
+
)
|
66 |
+
y = y.squeeze(1)
|
67 |
+
spec = torch.stft(
|
68 |
+
y,
|
69 |
+
self.n_fft,
|
70 |
+
hop_length=self.hop_size,
|
71 |
+
win_length=self.win_size,
|
72 |
+
window=self.hann_window,
|
73 |
+
center=self.center,
|
74 |
+
pad_mode="reflect",
|
75 |
+
normalized=False,
|
76 |
+
onesided=True,
|
77 |
+
return_complex=False,
|
78 |
+
)
|
79 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
|
80 |
+
spec = torch.matmul(self.mel_basis, spec)
|
81 |
+
# spec = spectral_normalize_torch(spec)
|
82 |
+
spec = torch.log(torch.clamp(spec, min=1e-5))
|
83 |
+
return spec
|
84 |
+
|
85 |
+
|
86 |
+
class ExportDitBlocks(torch.nn.Module):
|
87 |
+
def __init__(self, dit: DiT):
|
88 |
+
super().__init__()
|
89 |
+
self.transformer_blocks = dit.transformer_blocks
|
90 |
+
self.norm_out = dit.norm_out
|
91 |
+
self.proj_out = dit.proj_out
|
92 |
+
self.depth = dit.depth
|
93 |
+
|
94 |
+
def forward(self, x, t, mask, rope):
|
95 |
+
for block in self.transformer_blocks:
|
96 |
+
x = block(x, t, mask=mask, rope=(rope, 1.0))
|
97 |
+
x = self.norm_out(x, t)
|
98 |
+
output = self.proj_out(x)
|
99 |
+
return output
|
100 |
+
|
101 |
+
|
102 |
+
class ExportDitEmbed(torch.nn.Module):
|
103 |
+
def __init__(self, dit: DiT):
|
104 |
+
super().__init__()
|
105 |
+
self.time_embed = dit.time_embed
|
106 |
+
self.d_embed = dit.d_embed
|
107 |
+
self.text_embed = dit.text_embed
|
108 |
+
self.input_embed = dit.input_embed
|
109 |
+
self.rotary_embed = dit.rotary_embed
|
110 |
+
self.rotary_embed.inv_freq.to(device)
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
x0: torch.Tensor, # nosied input audio # noqa: F722
|
115 |
+
cond0: torch.Tensor, # masked cond audio # noqa: F722
|
116 |
+
x_lens: torch.Tensor,
|
117 |
+
time: torch.Tensor, # time step # noqa: F821 F722
|
118 |
+
dt_base_bootstrap: torch.Tensor,
|
119 |
+
text0: torch.Tensor, # noqa: F722#####condition feature
|
120 |
+
):
|
121 |
+
x = x0.transpose(2, 1)
|
122 |
+
cond = cond0.transpose(2, 1)
|
123 |
+
text = text0.transpose(2, 1)
|
124 |
+
mask = commons.sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
|
125 |
+
|
126 |
+
t = self.time_embed(time) + self.d_embed(dt_base_bootstrap)
|
127 |
+
text_embed = self.text_embed(text, x.shape[1])
|
128 |
+
rope_t = torch.arange(x.shape[1], device=device)
|
129 |
+
rope, _ = self.rotary_embed(rope_t)
|
130 |
+
x = self.input_embed(x, cond, text_embed)
|
131 |
+
return x, t, mask, rope
|
132 |
+
|
133 |
+
|
134 |
+
class ExportDiT(torch.nn.Module):
|
135 |
+
def __init__(self, dit: DiT):
|
136 |
+
super().__init__()
|
137 |
+
if dit != None:
|
138 |
+
self.embed = ExportDitEmbed(dit)
|
139 |
+
self.blocks = ExportDitBlocks(dit)
|
140 |
+
else:
|
141 |
+
self.embed = None
|
142 |
+
self.blocks = None
|
143 |
+
|
144 |
+
def forward( # x, prompt_x, x_lens, t, style,cond
|
145 |
+
self, # d is channel,n is T
|
146 |
+
x0: torch.Tensor, # nosied input audio # noqa: F722
|
147 |
+
cond0: torch.Tensor, # masked cond audio # noqa: F722
|
148 |
+
x_lens: torch.Tensor,
|
149 |
+
time: torch.Tensor, # time step # noqa: F821 F722
|
150 |
+
dt_base_bootstrap: torch.Tensor,
|
151 |
+
text0: torch.Tensor, # noqa: F722#####condition feature
|
152 |
+
):
|
153 |
+
x, t, mask, rope = self.embed(x0, cond0, x_lens, time, dt_base_bootstrap, text0)
|
154 |
+
output = self.blocks(x, t, mask, rope)
|
155 |
+
return output
|
156 |
+
|
157 |
+
|
158 |
+
class ExportCFM(torch.nn.Module):
|
159 |
+
def __init__(self, cfm: CFM):
|
160 |
+
super().__init__()
|
161 |
+
self.cfm = cfm
|
162 |
+
|
163 |
+
def forward(
|
164 |
+
self,
|
165 |
+
fea_ref: torch.Tensor,
|
166 |
+
fea_todo_chunk: torch.Tensor,
|
167 |
+
mel2: torch.Tensor,
|
168 |
+
sample_steps: torch.LongTensor,
|
169 |
+
):
|
170 |
+
T_min = fea_ref.size(2)
|
171 |
+
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
172 |
+
cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
|
173 |
+
cfm_res = cfm_res[:, :, mel2.shape[2] :]
|
174 |
+
mel2 = cfm_res[:, :, -T_min:]
|
175 |
+
fea_ref = fea_todo_chunk[:, :, -T_min:]
|
176 |
+
return cfm_res, fea_ref, mel2
|
177 |
+
|
178 |
+
|
179 |
+
mel_fn = lambda x: mel_spectrogram_torch(
|
180 |
+
x,
|
181 |
+
**{
|
182 |
+
"n_fft": 1024,
|
183 |
+
"win_size": 1024,
|
184 |
+
"hop_size": 256,
|
185 |
+
"num_mels": 100,
|
186 |
+
"sampling_rate": 24000,
|
187 |
+
"fmin": 0,
|
188 |
+
"fmax": None,
|
189 |
+
"center": False,
|
190 |
+
},
|
191 |
+
)
|
192 |
+
|
193 |
+
spec_min = -12
|
194 |
+
spec_max = 2
|
195 |
+
|
196 |
+
|
197 |
+
@torch.jit.script
|
198 |
+
def norm_spec(x):
|
199 |
+
spec_min = -12
|
200 |
+
spec_max = 2
|
201 |
+
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
|
202 |
+
|
203 |
+
|
204 |
+
def denorm_spec(x):
|
205 |
+
spec_min = -12
|
206 |
+
spec_max = 2
|
207 |
+
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
|
208 |
+
|
209 |
+
|
210 |
+
class ExportGPTSovitsHalf(torch.nn.Module):
|
211 |
+
def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
|
212 |
+
super().__init__()
|
213 |
+
self.hps = hps
|
214 |
+
self.t2s_m = t2s_m
|
215 |
+
self.vq_model = vq_model
|
216 |
+
self.mel2 = MelSpectrgram(
|
217 |
+
dtype=torch.float32,
|
218 |
+
device=device,
|
219 |
+
n_fft=1024,
|
220 |
+
num_mels=100,
|
221 |
+
sampling_rate=24000,
|
222 |
+
hop_size=256,
|
223 |
+
win_size=1024,
|
224 |
+
fmin=0,
|
225 |
+
fmax=None,
|
226 |
+
center=False,
|
227 |
+
)
|
228 |
+
# self.dtype = dtype
|
229 |
+
self.filter_length: int = hps.data.filter_length
|
230 |
+
self.sampling_rate: int = hps.data.sampling_rate
|
231 |
+
self.hop_length: int = hps.data.hop_length
|
232 |
+
self.win_length: int = hps.data.win_length
|
233 |
+
|
234 |
+
def forward(
|
235 |
+
self,
|
236 |
+
ssl_content,
|
237 |
+
ref_audio_32k: torch.FloatTensor,
|
238 |
+
phoneme_ids0,
|
239 |
+
phoneme_ids1,
|
240 |
+
bert1,
|
241 |
+
bert2,
|
242 |
+
top_k,
|
243 |
+
):
|
244 |
+
refer = spectrogram_torch(
|
245 |
+
ref_audio_32k,
|
246 |
+
self.filter_length,
|
247 |
+
self.sampling_rate,
|
248 |
+
self.hop_length,
|
249 |
+
self.win_length,
|
250 |
+
center=False,
|
251 |
+
).to(ssl_content.dtype)
|
252 |
+
|
253 |
+
codes = self.vq_model.extract_latent(ssl_content)
|
254 |
+
prompt_semantic = codes[0, 0]
|
255 |
+
prompt = prompt_semantic.unsqueeze(0)
|
256 |
+
# print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
257 |
+
|
258 |
+
pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
|
259 |
+
# print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
260 |
+
|
261 |
+
ge = self.vq_model.create_ge(refer)
|
262 |
+
# print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
263 |
+
|
264 |
+
prompt_ = prompt.unsqueeze(0)
|
265 |
+
fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
|
266 |
+
# print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
267 |
+
# print(prompt_.shape, phoneme_ids0.shape, ge.shape)
|
268 |
+
# print(fea_ref.shape)
|
269 |
+
|
270 |
+
ref_24k = resamplex(ref_audio_32k, 32000, 24000)
|
271 |
+
mel2 = norm_spec(self.mel2(ref_24k)).to(ssl_content.dtype)
|
272 |
+
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
273 |
+
mel2 = mel2[:, :, :T_min]
|
274 |
+
fea_ref = fea_ref[:, :, :T_min]
|
275 |
+
if T_min > 468:
|
276 |
+
mel2 = mel2[:, :, -468:]
|
277 |
+
fea_ref = fea_ref[:, :, -468:]
|
278 |
+
T_min = 468
|
279 |
+
|
280 |
+
fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
|
281 |
+
# print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
282 |
+
# print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
|
283 |
+
# print(fea_todo.shape)
|
284 |
+
|
285 |
+
return fea_ref, fea_todo, mel2
|
286 |
+
|
287 |
+
|
288 |
+
class GPTSoVITSV3(torch.nn.Module):
|
289 |
+
def __init__(self, gpt_sovits_half, cfm, bigvgan):
|
290 |
+
super().__init__()
|
291 |
+
self.gpt_sovits_half = gpt_sovits_half
|
292 |
+
self.cfm = cfm
|
293 |
+
self.bigvgan = bigvgan
|
294 |
+
|
295 |
+
def forward(
|
296 |
+
self,
|
297 |
+
ssl_content,
|
298 |
+
ref_audio_32k: torch.FloatTensor,
|
299 |
+
phoneme_ids0: torch.LongTensor,
|
300 |
+
phoneme_ids1: torch.LongTensor,
|
301 |
+
bert1,
|
302 |
+
bert2,
|
303 |
+
top_k: torch.LongTensor,
|
304 |
+
sample_steps: torch.LongTensor,
|
305 |
+
):
|
306 |
+
# current_time = datetime.now()
|
307 |
+
# print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
308 |
+
fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
|
309 |
+
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
|
310 |
+
)
|
311 |
+
chunk_len = 934 - fea_ref.shape[2]
|
312 |
+
wav_gen_list = []
|
313 |
+
idx = 0
|
314 |
+
wav_gen_length = fea_todo.shape[2] * 256
|
315 |
+
while 1:
|
316 |
+
# current_time = datetime.now()
|
317 |
+
# print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
318 |
+
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
|
319 |
+
if fea_todo_chunk.shape[-1] == 0:
|
320 |
+
break
|
321 |
+
|
322 |
+
# 因为导出的模型在不同shape时会重新编译还是怎么的,会卡顿10s这样,
|
323 |
+
# 所以在这里补0让他shape维持不变
|
324 |
+
# 但是这样会导致生成的音频长度不对,所以在最后截取一下。
|
325 |
+
# 经过 bigvgan 之后音频长度就是 fea_todo.shape[2] * 256
|
326 |
+
complete_len = chunk_len - fea_todo_chunk.shape[-1]
|
327 |
+
if complete_len != 0:
|
328 |
+
fea_todo_chunk = torch.cat(
|
329 |
+
[
|
330 |
+
fea_todo_chunk,
|
331 |
+
torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
|
332 |
+
],
|
333 |
+
2,
|
334 |
+
)
|
335 |
+
|
336 |
+
cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
|
337 |
+
idx += chunk_len
|
338 |
+
|
339 |
+
cfm_res = denorm_spec(cfm_res)
|
340 |
+
bigvgan_res = self.bigvgan(cfm_res)
|
341 |
+
wav_gen_list.append(bigvgan_res)
|
342 |
+
|
343 |
+
wav_gen = torch.cat(wav_gen_list, 2)
|
344 |
+
return wav_gen[0][0][:wav_gen_length]
|
345 |
+
|
346 |
+
|
347 |
+
def init_bigvgan():
|
348 |
+
global bigvgan_model
|
349 |
+
from BigVGAN import bigvgan
|
350 |
+
|
351 |
+
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
|
352 |
+
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
|
353 |
+
use_cuda_kernel=False,
|
354 |
+
) # if True, RuntimeError: Ninja is required to load C++ extensions
|
355 |
+
# remove weight norm in the model and set to eval mode
|
356 |
+
bigvgan_model.remove_weight_norm()
|
357 |
+
bigvgan_model = bigvgan_model.eval()
|
358 |
+
if is_half == True:
|
359 |
+
bigvgan_model = bigvgan_model.half().to(device)
|
360 |
+
else:
|
361 |
+
bigvgan_model = bigvgan_model.to(device)
|
362 |
+
|
363 |
+
|
364 |
+
class Sovits:
|
365 |
+
def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
|
366 |
+
self.vq_model = vq_model
|
367 |
+
self.hps = hps
|
368 |
+
cfm.estimator = ExportDiT(cfm.estimator)
|
369 |
+
self.cfm = cfm
|
370 |
+
|
371 |
+
|
372 |
+
class DictToAttrRecursive(dict):
|
373 |
+
def __init__(self, input_dict):
|
374 |
+
super().__init__(input_dict)
|
375 |
+
for key, value in input_dict.items():
|
376 |
+
if isinstance(value, dict):
|
377 |
+
value = DictToAttrRecursive(value)
|
378 |
+
self[key] = value
|
379 |
+
setattr(self, key, value)
|
380 |
+
|
381 |
+
def __getattr__(self, item):
|
382 |
+
try:
|
383 |
+
return self[item]
|
384 |
+
except KeyError:
|
385 |
+
raise AttributeError(f"Attribute {item} not found")
|
386 |
+
|
387 |
+
def __setattr__(self, key, value):
|
388 |
+
if isinstance(value, dict):
|
389 |
+
value = DictToAttrRecursive(value)
|
390 |
+
super(DictToAttrRecursive, self).__setitem__(key, value)
|
391 |
+
super().__setattr__(key, value)
|
392 |
+
|
393 |
+
def __delattr__(self, item):
|
394 |
+
try:
|
395 |
+
del self[item]
|
396 |
+
except KeyError:
|
397 |
+
raise AttributeError(f"Attribute {item} not found")
|
398 |
+
|
399 |
+
|
400 |
+
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
401 |
+
|
402 |
+
|
403 |
+
def get_sovits_weights(sovits_path):
|
404 |
+
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
405 |
+
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
406 |
+
|
407 |
+
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
408 |
+
if if_lora_v3 == True and is_exist_s2gv3 == False:
|
409 |
+
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
410 |
+
|
411 |
+
dict_s2 = load_sovits_new(sovits_path)
|
412 |
+
hps = dict_s2["config"]
|
413 |
+
hps = DictToAttrRecursive(hps)
|
414 |
+
hps.model.semantic_frame_rate = "25hz"
|
415 |
+
if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
|
416 |
+
hps.model.version = "v2" # v3model,v2sybomls
|
417 |
+
elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
418 |
+
hps.model.version = "v1"
|
419 |
+
else:
|
420 |
+
hps.model.version = "v2"
|
421 |
+
|
422 |
+
if model_version == "v3":
|
423 |
+
hps.model.version = "v3"
|
424 |
+
|
425 |
+
logger.info(f"hps: {hps}")
|
426 |
+
|
427 |
+
vq_model = SynthesizerTrnV3(
|
428 |
+
hps.data.filter_length // 2 + 1,
|
429 |
+
hps.train.segment_size // hps.data.hop_length,
|
430 |
+
n_speakers=hps.data.n_speakers,
|
431 |
+
**hps.model,
|
432 |
+
)
|
433 |
+
# init_bigvgan()
|
434 |
+
model_version = hps.model.version
|
435 |
+
logger.info(f"模型版本: {model_version}")
|
436 |
+
|
437 |
+
if is_half == True:
|
438 |
+
vq_model = vq_model.half().to(device)
|
439 |
+
else:
|
440 |
+
vq_model = vq_model.to(device)
|
441 |
+
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
442 |
+
vq_model.eval()
|
443 |
+
|
444 |
+
cfm = vq_model.cfm
|
445 |
+
del vq_model.cfm
|
446 |
+
|
447 |
+
sovits = Sovits(vq_model, cfm, hps)
|
448 |
+
return sovits
|
449 |
+
|
450 |
+
|
451 |
+
logger.info(f"torch version {torch.__version__}")
|
452 |
+
# ssl_model = cnhubert.get_model()
|
453 |
+
# if is_half:
|
454 |
+
# ssl_model = ssl_model.half().to(device)
|
455 |
+
# else:
|
456 |
+
# ssl_model = ssl_model.to(device)
|
457 |
+
|
458 |
+
|
459 |
+
def export_cfm(
|
460 |
+
e_cfm: ExportCFM,
|
461 |
+
mu: torch.Tensor,
|
462 |
+
x_lens: torch.LongTensor,
|
463 |
+
prompt: torch.Tensor,
|
464 |
+
n_timesteps: torch.IntTensor,
|
465 |
+
temperature=1.0,
|
466 |
+
):
|
467 |
+
cfm = e_cfm.cfm
|
468 |
+
|
469 |
+
B, T = mu.size(0), mu.size(1)
|
470 |
+
x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
|
471 |
+
print("x:", x.shape, x.dtype)
|
472 |
+
prompt_len = prompt.size(-1)
|
473 |
+
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
|
474 |
+
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
|
475 |
+
x[..., :prompt_len] = 0.0
|
476 |
+
mu = mu.transpose(2, 1)
|
477 |
+
|
478 |
+
ntimestep = int(n_timesteps)
|
479 |
+
|
480 |
+
t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
481 |
+
d = torch.tensor(1.0 / ntimestep, dtype=x.dtype, device=x.device)
|
482 |
+
|
483 |
+
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
|
484 |
+
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
|
485 |
+
|
486 |
+
print(
|
487 |
+
"cfm input shapes:",
|
488 |
+
x.shape,
|
489 |
+
prompt_x.shape,
|
490 |
+
x_lens.shape,
|
491 |
+
t_tensor.shape,
|
492 |
+
d_tensor.shape,
|
493 |
+
mu.shape,
|
494 |
+
)
|
495 |
+
|
496 |
+
print("cfm input dtypes:", x.dtype, prompt_x.dtype, x_lens.dtype, t_tensor.dtype, d_tensor.dtype, mu.dtype)
|
497 |
+
|
498 |
+
estimator: ExportDiT = torch.jit.trace(
|
499 |
+
cfm.estimator,
|
500 |
+
optimize=True,
|
501 |
+
example_inputs=(x, prompt_x, x_lens, t_tensor, d_tensor, mu),
|
502 |
+
)
|
503 |
+
estimator.save("onnx/ad/estimator.pt")
|
504 |
+
# torch.onnx.export(
|
505 |
+
# cfm.estimator,
|
506 |
+
# (x, prompt_x, x_lens, t_tensor, d_tensor, mu),
|
507 |
+
# "onnx/ad/dit.onnx",
|
508 |
+
# input_names=["x", "prompt_x", "x_lens", "t", "d", "mu"],
|
509 |
+
# output_names=["output"],
|
510 |
+
# dynamic_axes={
|
511 |
+
# "x": [2],
|
512 |
+
# "prompt_x": [2],
|
513 |
+
# "mu": [2],
|
514 |
+
# },
|
515 |
+
# )
|
516 |
+
print("save estimator ok")
|
517 |
+
cfm.estimator = estimator
|
518 |
+
export_cfm = torch.jit.script(e_cfm)
|
519 |
+
export_cfm.save("onnx/ad/cfm.pt")
|
520 |
+
# sovits.cfm = cfm
|
521 |
+
# cfm.save("onnx/ad/cfm.pt")
|
522 |
+
return export_cfm
|
523 |
+
|
524 |
+
|
525 |
+
def export():
|
526 |
+
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
527 |
+
|
528 |
+
init_bigvgan()
|
529 |
+
|
530 |
+
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
|
531 |
+
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
532 |
+
print("#### get_raw_t2s_model ####")
|
533 |
+
print(raw_t2s.config)
|
534 |
+
|
535 |
+
if is_half:
|
536 |
+
raw_t2s = raw_t2s.half().to(device)
|
537 |
+
|
538 |
+
t2s_m = T2SModel(raw_t2s)
|
539 |
+
t2s_m.eval()
|
540 |
+
script_t2s = torch.jit.script(t2s_m).to(device)
|
541 |
+
|
542 |
+
hps = sovits.hps
|
543 |
+
ref_wav_path = "onnx/ad/ref.wav"
|
544 |
+
speed = 1.0
|
545 |
+
sample_steps = 32
|
546 |
+
dtype = torch.float16 if is_half == True else torch.float32
|
547 |
+
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
548 |
+
zero_wav = np.zeros(
|
549 |
+
int(hps.data.sampling_rate * 0.3),
|
550 |
+
dtype=np.float16 if is_half == True else np.float32,
|
551 |
+
)
|
552 |
+
|
553 |
+
with torch.no_grad():
|
554 |
+
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
555 |
+
wav16k = torch.from_numpy(wav16k)
|
556 |
+
zero_wav_torch = torch.from_numpy(zero_wav)
|
557 |
+
|
558 |
+
if is_half == True:
|
559 |
+
wav16k = wav16k.half().to(device)
|
560 |
+
zero_wav_torch = zero_wav_torch.half().to(device)
|
561 |
+
else:
|
562 |
+
wav16k = wav16k.to(device)
|
563 |
+
zero_wav_torch = zero_wav_torch.to(device)
|
564 |
+
wav16k = torch.cat([wav16k, zero_wav_torch])
|
565 |
+
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
566 |
+
codes = sovits.vq_model.extract_latent(ssl_content)
|
567 |
+
prompt_semantic = codes[0, 0]
|
568 |
+
prompt = prompt_semantic.unsqueeze(0).to(device)
|
569 |
+
|
570 |
+
phones1, bert1, norm_text1 = get_phones_and_bert(
|
571 |
+
"你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
|
572 |
+
)
|
573 |
+
phones2, bert2, norm_text2 = get_phones_and_bert(
|
574 |
+
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
|
575 |
+
"auto",
|
576 |
+
"v3",
|
577 |
+
)
|
578 |
+
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
579 |
+
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
580 |
+
|
581 |
+
# codes = sovits.vq_model.extract_latent(ssl_content)
|
582 |
+
# prompt_semantic = codes[0, 0]
|
583 |
+
# prompts = prompt_semantic.unsqueeze(0)
|
584 |
+
|
585 |
+
top_k = torch.LongTensor([15]).to(device)
|
586 |
+
print("topk", top_k)
|
587 |
+
|
588 |
+
bert1 = bert1.T.to(device)
|
589 |
+
bert2 = bert2.T.to(device)
|
590 |
+
print(
|
591 |
+
prompt.dtype,
|
592 |
+
phoneme_ids0.dtype,
|
593 |
+
phoneme_ids1.dtype,
|
594 |
+
bert1.dtype,
|
595 |
+
bert2.dtype,
|
596 |
+
top_k.dtype,
|
597 |
+
)
|
598 |
+
print(
|
599 |
+
prompt.shape,
|
600 |
+
phoneme_ids0.shape,
|
601 |
+
phoneme_ids1.shape,
|
602 |
+
bert1.shape,
|
603 |
+
bert2.shape,
|
604 |
+
top_k.shape,
|
605 |
+
)
|
606 |
+
pred_semantic = t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
|
607 |
+
|
608 |
+
ge = sovits.vq_model.create_ge(refer)
|
609 |
+
prompt_ = prompt.unsqueeze(0)
|
610 |
+
|
611 |
+
torch._dynamo.mark_dynamic(prompt_, 2)
|
612 |
+
torch._dynamo.mark_dynamic(phoneme_ids0, 1)
|
613 |
+
|
614 |
+
fea_ref = sovits.vq_model(prompt_, phoneme_ids0, ge)
|
615 |
+
|
616 |
+
inputs = {
|
617 |
+
"forward": (prompt_, phoneme_ids0, ge),
|
618 |
+
"extract_latent": ssl_content,
|
619 |
+
"create_ge": refer,
|
620 |
+
}
|
621 |
+
|
622 |
+
trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
|
623 |
+
trace_vq_model.save("onnx/ad/vq_model.pt")
|
624 |
+
|
625 |
+
print(fea_ref.shape, fea_ref.dtype, ge.shape)
|
626 |
+
print(prompt_.shape, phoneme_ids0.shape, ge.shape)
|
627 |
+
|
628 |
+
# vq_model = torch.jit.trace(
|
629 |
+
# sovits.vq_model,
|
630 |
+
# optimize=True,
|
631 |
+
# # strict=False,
|
632 |
+
# example_inputs=(prompt_, phoneme_ids0, ge),
|
633 |
+
# )
|
634 |
+
# vq_model = sovits.vq_model
|
635 |
+
vq_model = trace_vq_model
|
636 |
+
|
637 |
+
gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
|
638 |
+
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
|
639 |
+
|
640 |
+
ref_audio, sr = torchaudio.load(ref_wav_path)
|
641 |
+
ref_audio = ref_audio.to(device).float()
|
642 |
+
if ref_audio.shape[0] == 2:
|
643 |
+
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
644 |
+
if sr != 24000:
|
645 |
+
ref_audio = resample(ref_audio, sr)
|
646 |
+
# mel2 = mel_fn(ref_audio)
|
647 |
+
mel2 = norm_spec(mel_fn(ref_audio))
|
648 |
+
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
649 |
+
fea_ref = fea_ref[:, :, :T_min]
|
650 |
+
print("fea_ref:", fea_ref.shape, T_min)
|
651 |
+
if T_min > 468:
|
652 |
+
mel2 = mel2[:, :, -468:]
|
653 |
+
fea_ref = fea_ref[:, :, -468:]
|
654 |
+
T_min = 468
|
655 |
+
chunk_len = 934 - T_min
|
656 |
+
mel2 = mel2.to(dtype)
|
657 |
+
|
658 |
+
# fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
|
659 |
+
fea_todo = vq_model(pred_semantic, phoneme_ids1, ge)
|
660 |
+
|
661 |
+
cfm_resss = []
|
662 |
+
idx = 0
|
663 |
+
sample_steps = torch.LongTensor([sample_steps]).to(device)
|
664 |
+
export_cfm_ = ExportCFM(sovits.cfm)
|
665 |
+
while 1:
|
666 |
+
print("idx:", idx)
|
667 |
+
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
|
668 |
+
if fea_todo_chunk.shape[-1] == 0:
|
669 |
+
break
|
670 |
+
|
671 |
+
print(
|
672 |
+
"export_cfm:",
|
673 |
+
fea_ref.shape,
|
674 |
+
fea_todo_chunk.shape,
|
675 |
+
mel2.shape,
|
676 |
+
sample_steps.shape,
|
677 |
+
)
|
678 |
+
if idx == 0:
|
679 |
+
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
680 |
+
export_cfm_ = export_cfm(
|
681 |
+
export_cfm_,
|
682 |
+
fea,
|
683 |
+
torch.LongTensor([fea.size(1)]).to(fea.device),
|
684 |
+
mel2,
|
685 |
+
sample_steps,
|
686 |
+
)
|
687 |
+
# torch.onnx.export(
|
688 |
+
# export_cfm_,
|
689 |
+
# (
|
690 |
+
# fea_ref,
|
691 |
+
# fea_todo_chunk,
|
692 |
+
# mel2,
|
693 |
+
# sample_steps,
|
694 |
+
# ),
|
695 |
+
# "onnx/ad/cfm.onnx",
|
696 |
+
# input_names=["fea_ref", "fea_todo_chunk", "mel2", "sample_steps"],
|
697 |
+
# output_names=["cfm_res", "fea_ref_", "mel2_"],
|
698 |
+
# dynamic_axes={
|
699 |
+
# "fea_ref": [2],
|
700 |
+
# "fea_todo_chunk": [2],
|
701 |
+
# "mel2": [2],
|
702 |
+
# },
|
703 |
+
# )
|
704 |
+
|
705 |
+
idx += chunk_len
|
706 |
+
|
707 |
+
cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
|
708 |
+
cfm_resss.append(cfm_res)
|
709 |
+
continue
|
710 |
+
|
711 |
+
cmf_res = torch.cat(cfm_resss, 2)
|
712 |
+
cmf_res = denorm_spec(cmf_res).to(device)
|
713 |
+
print("cmf_res:", cmf_res.shape, cmf_res.dtype)
|
714 |
+
with torch.inference_mode():
|
715 |
+
cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
|
716 |
+
torch._dynamo.mark_dynamic(cmf_res_rand, 2)
|
717 |
+
bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
|
718 |
+
bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
|
719 |
+
wav_gen = bigvgan_model(cmf_res)
|
720 |
+
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
721 |
+
audio = wav_gen[0][0].cpu().detach().numpy()
|
722 |
+
|
723 |
+
sr = 24000
|
724 |
+
soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
|
725 |
+
|
726 |
+
|
727 |
+
from datetime import datetime
|
728 |
+
|
729 |
+
|
730 |
+
def test_export(
|
731 |
+
todo_text,
|
732 |
+
gpt_sovits_v3_half,
|
733 |
+
cfm,
|
734 |
+
bigvgan,
|
735 |
+
output,
|
736 |
+
):
|
737 |
+
# hps = sovits.hps
|
738 |
+
ref_wav_path = "onnx/ad/ref.wav"
|
739 |
+
speed = 1.0
|
740 |
+
sample_steps = 8
|
741 |
+
|
742 |
+
dtype = torch.float16 if is_half == True else torch.float32
|
743 |
+
|
744 |
+
zero_wav = np.zeros(
|
745 |
+
int(16000 * 0.3),
|
746 |
+
dtype=np.float16 if is_half == True else np.float32,
|
747 |
+
)
|
748 |
+
|
749 |
+
with torch.no_grad():
|
750 |
+
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
751 |
+
wav16k = torch.from_numpy(wav16k)
|
752 |
+
zero_wav_torch = torch.from_numpy(zero_wav)
|
753 |
+
|
754 |
+
if is_half == True:
|
755 |
+
wav16k = wav16k.half().to(device)
|
756 |
+
zero_wav_torch = zero_wav_torch.half().to(device)
|
757 |
+
else:
|
758 |
+
wav16k = wav16k.to(device)
|
759 |
+
zero_wav_torch = zero_wav_torch.to(device)
|
760 |
+
wav16k = torch.cat([wav16k, zero_wav_torch])
|
761 |
+
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
762 |
+
|
763 |
+
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
|
764 |
+
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
|
765 |
+
|
766 |
+
phones1, bert1, norm_text1 = get_phones_and_bert(
|
767 |
+
"你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
|
768 |
+
)
|
769 |
+
phones2, bert2, norm_text2 = get_phones_and_bert(
|
770 |
+
todo_text,
|
771 |
+
"zh",
|
772 |
+
"v3",
|
773 |
+
)
|
774 |
+
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
775 |
+
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
776 |
+
|
777 |
+
bert1 = bert1.T.to(device)
|
778 |
+
bert2 = bert2.T.to(device)
|
779 |
+
top_k = torch.LongTensor([15]).to(device)
|
780 |
+
|
781 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
782 |
+
logger.info("start inference %s", current_time)
|
783 |
+
print(
|
784 |
+
ssl_content.shape,
|
785 |
+
ref_audio_32k.shape,
|
786 |
+
phoneme_ids0.shape,
|
787 |
+
phoneme_ids1.shape,
|
788 |
+
bert1.shape,
|
789 |
+
bert2.shape,
|
790 |
+
top_k.shape,
|
791 |
+
)
|
792 |
+
fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
|
793 |
+
ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
|
794 |
+
)
|
795 |
+
chunk_len = 934 - fea_ref.shape[2]
|
796 |
+
print(fea_ref.shape, fea_todo.shape, mel2.shape)
|
797 |
+
|
798 |
+
cfm_resss = []
|
799 |
+
sample_steps = torch.LongTensor([sample_steps])
|
800 |
+
idx = 0
|
801 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
802 |
+
logger.info("start cfm %s", current_time)
|
803 |
+
wav_gen_length = fea_todo.shape[2] * 256
|
804 |
+
|
805 |
+
while 1:
|
806 |
+
current_time = datetime.now()
|
807 |
+
print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
808 |
+
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
|
809 |
+
if fea_todo_chunk.shape[-1] == 0:
|
810 |
+
break
|
811 |
+
|
812 |
+
complete_len = chunk_len - fea_todo_chunk.shape[-1]
|
813 |
+
if complete_len != 0:
|
814 |
+
fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(device).to(dtype)], 2)
|
815 |
+
|
816 |
+
cfm_res, fea_ref, mel2 = cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
|
817 |
+
# if complete_len > 0 :
|
818 |
+
# cfm_res = cfm_res[:, :, :-complete_len]
|
819 |
+
# fea_ref = fea_ref[:, :, :-complete_len]
|
820 |
+
# mel2 = mel2[:, :, :-complete_len]
|
821 |
+
|
822 |
+
idx += chunk_len
|
823 |
+
|
824 |
+
current_time = datetime.now()
|
825 |
+
print("cfm end", current_time.strftime("%Y-%m-%d %H:%M:%S"))
|
826 |
+
cfm_res = denorm_spec(cfm_res).to(device)
|
827 |
+
bigvgan_res = bigvgan(cfm_res)
|
828 |
+
cfm_resss.append(bigvgan_res)
|
829 |
+
|
830 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
831 |
+
logger.info("start bigvgan %s", current_time)
|
832 |
+
wav_gen = torch.cat(cfm_resss, 2)
|
833 |
+
# cmf_res = denorm_spec(cmf_res)
|
834 |
+
# cmf_res = cmf_res.to(device)
|
835 |
+
# print("cmf_res:", cmf_res.shape)
|
836 |
+
|
837 |
+
# cmf_res = torch.cat([cmf_res,torch.zeros([1,100,2000-cmf_res.size(2)],device=device,dtype=cmf_res.dtype)], 2)
|
838 |
+
|
839 |
+
# wav_gen = bigvgan(cmf_res)
|
840 |
+
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
841 |
+
wav_gen = wav_gen[:, :, :wav_gen_length]
|
842 |
+
|
843 |
+
audio = wav_gen[0][0].cpu().detach().numpy()
|
844 |
+
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
845 |
+
sr = 24000
|
846 |
+
soundfile.write(output, (audio * 32768).astype(np.int16), sr)
|
847 |
+
|
848 |
+
|
849 |
+
def test_export1(
|
850 |
+
todo_text,
|
851 |
+
gpt_sovits_v3,
|
852 |
+
output,
|
853 |
+
):
|
854 |
+
# hps = sovits.hps
|
855 |
+
ref_wav_path = "onnx/ad/ref.wav"
|
856 |
+
speed = 1.0
|
857 |
+
sample_steps = torch.LongTensor([16])
|
858 |
+
|
859 |
+
dtype = torch.float16 if is_half == True else torch.float32
|
860 |
+
|
861 |
+
zero_wav = np.zeros(
|
862 |
+
int(24000 * 0.3),
|
863 |
+
dtype=np.float16 if is_half == True else np.float32,
|
864 |
+
)
|
865 |
+
|
866 |
+
with torch.no_grad():
|
867 |
+
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
868 |
+
wav16k = torch.from_numpy(wav16k)
|
869 |
+
zero_wav_torch = torch.from_numpy(zero_wav)
|
870 |
+
|
871 |
+
if is_half == True:
|
872 |
+
wav16k = wav16k.half().to(device)
|
873 |
+
zero_wav_torch = zero_wav_torch.half().to(device)
|
874 |
+
else:
|
875 |
+
wav16k = wav16k.to(device)
|
876 |
+
zero_wav_torch = zero_wav_torch.to(device)
|
877 |
+
wav16k = torch.cat([wav16k, zero_wav_torch])
|
878 |
+
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
879 |
+
print("ssl_content:", ssl_content.shape, ssl_content.dtype)
|
880 |
+
|
881 |
+
ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
|
882 |
+
ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
|
883 |
+
|
884 |
+
phones1, bert1, norm_text1 = get_phones_and_bert(
|
885 |
+
"你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
|
886 |
+
)
|
887 |
+
phones2, bert2, norm_text2 = get_phones_and_bert(
|
888 |
+
todo_text,
|
889 |
+
"zh",
|
890 |
+
"v3",
|
891 |
+
)
|
892 |
+
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
893 |
+
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
894 |
+
|
895 |
+
bert1 = bert1.T.to(device)
|
896 |
+
bert2 = bert2.T.to(device)
|
897 |
+
top_k = torch.LongTensor([15]).to(device)
|
898 |
+
|
899 |
+
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
900 |
+
logger.info("start inference %s", current_time)
|
901 |
+
print(
|
902 |
+
ssl_content.shape,
|
903 |
+
ref_audio_32k.shape,
|
904 |
+
phoneme_ids0.shape,
|
905 |
+
phoneme_ids1.shape,
|
906 |
+
bert1.shape,
|
907 |
+
bert2.shape,
|
908 |
+
top_k.shape,
|
909 |
+
)
|
910 |
+
wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
|
911 |
+
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
912 |
+
|
913 |
+
wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
|
914 |
+
|
915 |
+
audio = wav_gen.cpu().detach().numpy()
|
916 |
+
logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|
917 |
+
sr = 24000
|
918 |
+
soundfile.write(output, (audio * 32768).astype(np.int16), sr)
|
919 |
+
|
920 |
+
|
921 |
+
import time
|
922 |
+
|
923 |
+
|
924 |
+
def test_():
|
925 |
+
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
926 |
+
|
927 |
+
# cfm = ExportCFM(sovits.cfm)
|
928 |
+
# cfm.cfm.estimator = dit
|
929 |
+
sovits.cfm = None
|
930 |
+
|
931 |
+
cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device)
|
932 |
+
# cfm = torch.jit.optimize_for_inference(cfm)
|
933 |
+
cfm = cfm.half().to(device)
|
934 |
+
|
935 |
+
cfm.eval()
|
936 |
+
|
937 |
+
logger.info("cfm ok")
|
938 |
+
|
939 |
+
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
|
940 |
+
# v2 的 gpt 也可以用
|
941 |
+
# dict_s1 = torch.load("GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
|
942 |
+
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
943 |
+
print("#### get_raw_t2s_model ####")
|
944 |
+
print(raw_t2s.config)
|
945 |
+
if is_half:
|
946 |
+
raw_t2s = raw_t2s.half().to(device)
|
947 |
+
t2s_m = T2SModel(raw_t2s).half().to(device)
|
948 |
+
t2s_m.eval()
|
949 |
+
t2s_m = torch.jit.script(t2s_m)
|
950 |
+
t2s_m.eval()
|
951 |
+
# t2s_m.top_k = 15
|
952 |
+
logger.info("t2s_m ok")
|
953 |
+
|
954 |
+
vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
|
955 |
+
# vq_model = torch.jit.optimize_for_inference(vq_model)
|
956 |
+
# vq_model = vq_model.half().to(device)
|
957 |
+
vq_model.eval()
|
958 |
+
# vq_model = sovits.vq_model
|
959 |
+
logger.info("vq_model ok")
|
960 |
+
|
961 |
+
# gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt")
|
962 |
+
# gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half)
|
963 |
+
# gpt_sovits_v3_half = gpt_sovits_v3_half.half()
|
964 |
+
# gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
|
965 |
+
# gpt_sovits_v3_half.eval()
|
966 |
+
gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
|
967 |
+
logger.info("gpt_sovits_v3_half ok")
|
968 |
+
|
969 |
+
# init_bigvgan()
|
970 |
+
# global bigvgan_model
|
971 |
+
bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
|
972 |
+
# bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
|
973 |
+
bigvgan_model = bigvgan_model.half()
|
974 |
+
bigvgan_model = bigvgan_model.cuda()
|
975 |
+
bigvgan_model.eval()
|
976 |
+
|
977 |
+
logger.info("bigvgan ok")
|
978 |
+
|
979 |
+
gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
|
980 |
+
gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
|
981 |
+
gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
|
982 |
+
gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
|
983 |
+
gpt_sovits_v3.eval()
|
984 |
+
print("save gpt_sovits_v3 ok")
|
985 |
+
|
986 |
+
time.sleep(5)
|
987 |
+
# print("thread:", torch.get_num_threads())
|
988 |
+
# print("thread:", torch.get_num_interop_threads())
|
989 |
+
# torch.set_num_interop_threads(1)
|
990 |
+
# torch.set_num_threads(1)
|
991 |
+
|
992 |
+
test_export1(
|
993 |
+
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
|
994 |
+
gpt_sovits_v3,
|
995 |
+
"out.wav",
|
996 |
+
)
|
997 |
+
|
998 |
+
test_export1(
|
999 |
+
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
|
1000 |
+
gpt_sovits_v3,
|
1001 |
+
"out2.wav",
|
1002 |
+
)
|
1003 |
+
|
1004 |
+
# test_export(
|
1005 |
+
# "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP. 哈哈哈...",
|
1006 |
+
# gpt_sovits_v3_half,
|
1007 |
+
# cfm,
|
1008 |
+
# bigvgan_model,
|
1009 |
+
# "out2.wav",
|
1010 |
+
# )
|
1011 |
+
|
1012 |
+
|
1013 |
+
def test_export_gpt_sovits_v3():
|
1014 |
+
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
|
1015 |
+
# test_export1(
|
1016 |
+
# "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
|
1017 |
+
# gpt_sovits_v3,
|
1018 |
+
# "out3.wav",
|
1019 |
+
# )
|
1020 |
+
# test_export1(
|
1021 |
+
# "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
|
1022 |
+
# gpt_sovits_v3,
|
1023 |
+
# "out4.wav",
|
1024 |
+
# )
|
1025 |
+
test_export1(
|
1026 |
+
"风萧萧兮易水寒,壮士一去兮不复还.",
|
1027 |
+
gpt_sovits_v3,
|
1028 |
+
"out5.wav",
|
1029 |
+
)
|
1030 |
+
|
1031 |
+
|
1032 |
+
with torch.no_grad():
|
1033 |
+
# export()
|
1034 |
+
test_()
|
1035 |
+
# test_export_gpt_sovits_v3()
|
GPT_SoVITS/inference_cli.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import soundfile as sf
|
4 |
+
|
5 |
+
from tools.i18n.i18n import I18nAuto
|
6 |
+
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
7 |
+
|
8 |
+
i18n = I18nAuto()
|
9 |
+
|
10 |
+
|
11 |
+
def synthesize(
|
12 |
+
GPT_model_path,
|
13 |
+
SoVITS_model_path,
|
14 |
+
ref_audio_path,
|
15 |
+
ref_text_path,
|
16 |
+
ref_language,
|
17 |
+
target_text_path,
|
18 |
+
target_language,
|
19 |
+
output_path,
|
20 |
+
):
|
21 |
+
# Read reference text
|
22 |
+
with open(ref_text_path, "r", encoding="utf-8") as file:
|
23 |
+
ref_text = file.read()
|
24 |
+
|
25 |
+
# Read target text
|
26 |
+
with open(target_text_path, "r", encoding="utf-8") as file:
|
27 |
+
target_text = file.read()
|
28 |
+
|
29 |
+
# Change model weights
|
30 |
+
change_gpt_weights(gpt_path=GPT_model_path)
|
31 |
+
change_sovits_weights(sovits_path=SoVITS_model_path)
|
32 |
+
|
33 |
+
# Synthesize audio
|
34 |
+
synthesis_result = get_tts_wav(
|
35 |
+
ref_wav_path=ref_audio_path,
|
36 |
+
prompt_text=ref_text,
|
37 |
+
prompt_language=i18n(ref_language),
|
38 |
+
text=target_text,
|
39 |
+
text_language=i18n(target_language),
|
40 |
+
top_p=1,
|
41 |
+
temperature=1,
|
42 |
+
)
|
43 |
+
|
44 |
+
result_list = list(synthesis_result)
|
45 |
+
|
46 |
+
if result_list:
|
47 |
+
last_sampling_rate, last_audio_data = result_list[-1]
|
48 |
+
output_wav_path = os.path.join(output_path, "output.wav")
|
49 |
+
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
50 |
+
print(f"Audio saved to {output_wav_path}")
|
51 |
+
|
52 |
+
|
53 |
+
def main():
|
54 |
+
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
55 |
+
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
56 |
+
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
57 |
+
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
58 |
+
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
59 |
+
parser.add_argument(
|
60 |
+
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
|
61 |
+
)
|
62 |
+
parser.add_argument("--target_text", required=True, help="Path to the target text file")
|
63 |
+
parser.add_argument(
|
64 |
+
"--target_language",
|
65 |
+
required=True,
|
66 |
+
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
|
67 |
+
help="Language of the target text",
|
68 |
+
)
|
69 |
+
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
70 |
+
|
71 |
+
args = parser.parse_args()
|
72 |
+
|
73 |
+
synthesize(
|
74 |
+
args.gpt_model,
|
75 |
+
args.sovits_model,
|
76 |
+
args.ref_audio,
|
77 |
+
args.ref_text,
|
78 |
+
args.ref_language,
|
79 |
+
args.target_text,
|
80 |
+
args.target_language,
|
81 |
+
args.output_path,
|
82 |
+
)
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == "__main__":
|
86 |
+
main()
|
GPT_SoVITS/inference_gui.py
ADDED
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from PyQt5.QtCore import QEvent
|
4 |
+
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
|
5 |
+
from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
|
6 |
+
import soundfile as sf
|
7 |
+
|
8 |
+
from tools.i18n.i18n import I18nAuto
|
9 |
+
|
10 |
+
i18n = I18nAuto()
|
11 |
+
|
12 |
+
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
13 |
+
|
14 |
+
|
15 |
+
class GPTSoVITSGUI(QMainWindow):
|
16 |
+
GPT_Path = gpt_path
|
17 |
+
SoVITS_Path = sovits_path
|
18 |
+
|
19 |
+
def __init__(self):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
self.setWindowTitle("GPT-SoVITS GUI")
|
23 |
+
self.setGeometry(800, 450, 950, 850)
|
24 |
+
|
25 |
+
self.setStyleSheet("""
|
26 |
+
QWidget {
|
27 |
+
background-color: #a3d3b1;
|
28 |
+
}
|
29 |
+
|
30 |
+
QTabWidget::pane {
|
31 |
+
background-color: #a3d3b1;
|
32 |
+
}
|
33 |
+
|
34 |
+
QTabWidget::tab-bar {
|
35 |
+
alignment: left;
|
36 |
+
}
|
37 |
+
|
38 |
+
QTabBar::tab {
|
39 |
+
background: #8da4bf;
|
40 |
+
color: #ffffff;
|
41 |
+
padding: 8px;
|
42 |
+
}
|
43 |
+
|
44 |
+
QTabBar::tab:selected {
|
45 |
+
background: #2a3f54;
|
46 |
+
}
|
47 |
+
|
48 |
+
QLabel {
|
49 |
+
color: #000000;
|
50 |
+
}
|
51 |
+
|
52 |
+
QPushButton {
|
53 |
+
background-color: #4CAF50;
|
54 |
+
color: white;
|
55 |
+
padding: 8px;
|
56 |
+
border: 1px solid #4CAF50;
|
57 |
+
border-radius: 4px;
|
58 |
+
}
|
59 |
+
|
60 |
+
QPushButton:hover {
|
61 |
+
background-color: #45a049;
|
62 |
+
border: 1px solid #45a049;
|
63 |
+
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
|
64 |
+
}
|
65 |
+
""")
|
66 |
+
|
67 |
+
license_text = (
|
68 |
+
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
69 |
+
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
|
70 |
+
)
|
71 |
+
license_label = QLabel(license_text)
|
72 |
+
license_label.setWordWrap(True)
|
73 |
+
|
74 |
+
self.GPT_model_label = QLabel("选择GPT模型:")
|
75 |
+
self.GPT_model_input = QLineEdit()
|
76 |
+
self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
|
77 |
+
self.GPT_model_input.setText(self.GPT_Path)
|
78 |
+
self.GPT_model_input.setReadOnly(True)
|
79 |
+
self.GPT_model_button = QPushButton("选择GPT模型文件")
|
80 |
+
self.GPT_model_button.clicked.connect(self.select_GPT_model)
|
81 |
+
|
82 |
+
self.SoVITS_model_label = QLabel("选择SoVITS模型:")
|
83 |
+
self.SoVITS_model_input = QLineEdit()
|
84 |
+
self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
|
85 |
+
self.SoVITS_model_input.setText(self.SoVITS_Path)
|
86 |
+
self.SoVITS_model_input.setReadOnly(True)
|
87 |
+
self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
|
88 |
+
self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
|
89 |
+
|
90 |
+
self.ref_audio_label = QLabel("上传参考音频:")
|
91 |
+
self.ref_audio_input = QLineEdit()
|
92 |
+
self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
|
93 |
+
self.ref_audio_input.setReadOnly(True)
|
94 |
+
self.ref_audio_button = QPushButton("选择音频文件")
|
95 |
+
self.ref_audio_button.clicked.connect(self.select_ref_audio)
|
96 |
+
|
97 |
+
self.ref_text_label = QLabel("参考音频文本:")
|
98 |
+
self.ref_text_input = QLineEdit()
|
99 |
+
self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
|
100 |
+
self.ref_text_button = QPushButton("上传文本")
|
101 |
+
self.ref_text_button.clicked.connect(self.upload_ref_text)
|
102 |
+
|
103 |
+
self.ref_language_label = QLabel("参考音频语言:")
|
104 |
+
self.ref_language_combobox = QComboBox()
|
105 |
+
self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
|
106 |
+
self.ref_language_combobox.setCurrentText("多语种混合")
|
107 |
+
|
108 |
+
self.target_text_label = QLabel("合成目标文本:")
|
109 |
+
self.target_text_input = QLineEdit()
|
110 |
+
self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
|
111 |
+
self.target_text_button = QPushButton("上传文本")
|
112 |
+
self.target_text_button.clicked.connect(self.upload_target_text)
|
113 |
+
|
114 |
+
self.target_language_label = QLabel("合成音频语言:")
|
115 |
+
self.target_language_combobox = QComboBox()
|
116 |
+
self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
|
117 |
+
self.target_language_combobox.setCurrentText("多语种混合")
|
118 |
+
|
119 |
+
self.output_label = QLabel("输出音频路径:")
|
120 |
+
self.output_input = QLineEdit()
|
121 |
+
self.output_input.setPlaceholderText("拖拽或选择文件")
|
122 |
+
self.output_input.setReadOnly(True)
|
123 |
+
self.output_button = QPushButton("选择文件夹")
|
124 |
+
self.output_button.clicked.connect(self.select_output_path)
|
125 |
+
|
126 |
+
self.output_text = QTextEdit()
|
127 |
+
self.output_text.setReadOnly(True)
|
128 |
+
|
129 |
+
self.add_drag_drop_events(
|
130 |
+
[
|
131 |
+
self.GPT_model_input,
|
132 |
+
self.SoVITS_model_input,
|
133 |
+
self.ref_audio_input,
|
134 |
+
self.ref_text_input,
|
135 |
+
self.target_text_input,
|
136 |
+
self.output_input,
|
137 |
+
]
|
138 |
+
)
|
139 |
+
|
140 |
+
self.synthesize_button = QPushButton("合成")
|
141 |
+
self.synthesize_button.clicked.connect(self.synthesize)
|
142 |
+
|
143 |
+
self.clear_output_button = QPushButton("清空输出")
|
144 |
+
self.clear_output_button.clicked.connect(self.clear_output)
|
145 |
+
|
146 |
+
self.status_bar = QStatusBar()
|
147 |
+
|
148 |
+
main_layout = QVBoxLayout()
|
149 |
+
|
150 |
+
input_layout = QGridLayout(self)
|
151 |
+
input_layout.setSpacing(10)
|
152 |
+
|
153 |
+
input_layout.addWidget(license_label, 0, 0, 1, 3)
|
154 |
+
|
155 |
+
input_layout.addWidget(self.GPT_model_label, 1, 0)
|
156 |
+
input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
|
157 |
+
input_layout.addWidget(self.GPT_model_button, 2, 2)
|
158 |
+
|
159 |
+
input_layout.addWidget(self.SoVITS_model_label, 3, 0)
|
160 |
+
input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
|
161 |
+
input_layout.addWidget(self.SoVITS_model_button, 4, 2)
|
162 |
+
|
163 |
+
input_layout.addWidget(self.ref_audio_label, 5, 0)
|
164 |
+
input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
|
165 |
+
input_layout.addWidget(self.ref_audio_button, 6, 2)
|
166 |
+
|
167 |
+
input_layout.addWidget(self.ref_language_label, 7, 0)
|
168 |
+
input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
|
169 |
+
input_layout.addWidget(self.ref_text_label, 9, 0)
|
170 |
+
input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
|
171 |
+
input_layout.addWidget(self.ref_text_button, 10, 2)
|
172 |
+
|
173 |
+
input_layout.addWidget(self.target_language_label, 11, 0)
|
174 |
+
input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
|
175 |
+
input_layout.addWidget(self.target_text_label, 13, 0)
|
176 |
+
input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
|
177 |
+
input_layout.addWidget(self.target_text_button, 14, 2)
|
178 |
+
|
179 |
+
input_layout.addWidget(self.output_label, 15, 0)
|
180 |
+
input_layout.addWidget(self.output_input, 16, 0, 1, 2)
|
181 |
+
input_layout.addWidget(self.output_button, 16, 2)
|
182 |
+
|
183 |
+
main_layout.addLayout(input_layout)
|
184 |
+
|
185 |
+
output_layout = QVBoxLayout()
|
186 |
+
output_layout.addWidget(self.output_text)
|
187 |
+
main_layout.addLayout(output_layout)
|
188 |
+
|
189 |
+
main_layout.addWidget(self.synthesize_button)
|
190 |
+
|
191 |
+
main_layout.addWidget(self.clear_output_button)
|
192 |
+
|
193 |
+
main_layout.addWidget(self.status_bar)
|
194 |
+
|
195 |
+
self.central_widget = QWidget()
|
196 |
+
self.central_widget.setLayout(main_layout)
|
197 |
+
self.setCentralWidget(self.central_widget)
|
198 |
+
|
199 |
+
def dragEnterEvent(self, event):
|
200 |
+
if event.mimeData().hasUrls():
|
201 |
+
event.acceptProposedAction()
|
202 |
+
|
203 |
+
def dropEvent(self, event):
|
204 |
+
if event.mimeData().hasUrls():
|
205 |
+
file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
|
206 |
+
if len(file_paths) == 1:
|
207 |
+
self.update_ref_audio(file_paths[0])
|
208 |
+
else:
|
209 |
+
self.update_ref_audio(", ".join(file_paths))
|
210 |
+
|
211 |
+
def add_drag_drop_events(self, widgets):
|
212 |
+
for widget in widgets:
|
213 |
+
widget.setAcceptDrops(True)
|
214 |
+
widget.installEventFilter(self)
|
215 |
+
|
216 |
+
def eventFilter(self, obj, event):
|
217 |
+
if event.type() in (QEvent.DragEnter, QEvent.Drop):
|
218 |
+
mime_data = event.mimeData()
|
219 |
+
if mime_data.hasUrls():
|
220 |
+
event.acceptProposedAction()
|
221 |
+
|
222 |
+
return super().eventFilter(obj, event)
|
223 |
+
|
224 |
+
def select_GPT_model(self):
|
225 |
+
file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
|
226 |
+
if file_path:
|
227 |
+
self.GPT_model_input.setText(file_path)
|
228 |
+
|
229 |
+
def select_SoVITS_model(self):
|
230 |
+
file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
|
231 |
+
if file_path:
|
232 |
+
self.SoVITS_model_input.setText(file_path)
|
233 |
+
|
234 |
+
def select_ref_audio(self):
|
235 |
+
file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
|
236 |
+
if file_path:
|
237 |
+
self.update_ref_audio(file_path)
|
238 |
+
|
239 |
+
def upload_ref_text(self):
|
240 |
+
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
241 |
+
if file_path:
|
242 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
243 |
+
content = file.read()
|
244 |
+
self.ref_text_input.setText(content)
|
245 |
+
|
246 |
+
def upload_target_text(self):
|
247 |
+
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
248 |
+
if file_path:
|
249 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
250 |
+
content = file.read()
|
251 |
+
self.target_text_input.setText(content)
|
252 |
+
|
253 |
+
def select_output_path(self):
|
254 |
+
options = QFileDialog.Options()
|
255 |
+
options |= QFileDialog.DontUseNativeDialog
|
256 |
+
options |= QFileDialog.ShowDirsOnly
|
257 |
+
|
258 |
+
folder_dialog = QFileDialog()
|
259 |
+
folder_dialog.setOptions(options)
|
260 |
+
folder_dialog.setFileMode(QFileDialog.Directory)
|
261 |
+
|
262 |
+
if folder_dialog.exec_():
|
263 |
+
folder_path = folder_dialog.selectedFiles()[0]
|
264 |
+
self.output_input.setText(folder_path)
|
265 |
+
|
266 |
+
def update_ref_audio(self, file_path):
|
267 |
+
self.ref_audio_input.setText(file_path)
|
268 |
+
|
269 |
+
def clear_output(self):
|
270 |
+
self.output_text.clear()
|
271 |
+
|
272 |
+
def synthesize(self):
|
273 |
+
GPT_model_path = self.GPT_model_input.text()
|
274 |
+
SoVITS_model_path = self.SoVITS_model_input.text()
|
275 |
+
ref_audio_path = self.ref_audio_input.text()
|
276 |
+
language_combobox = self.ref_language_combobox.currentText()
|
277 |
+
language_combobox = i18n(language_combobox)
|
278 |
+
ref_text = self.ref_text_input.text()
|
279 |
+
target_language_combobox = self.target_language_combobox.currentText()
|
280 |
+
target_language_combobox = i18n(target_language_combobox)
|
281 |
+
target_text = self.target_text_input.text()
|
282 |
+
output_path = self.output_input.text()
|
283 |
+
|
284 |
+
if GPT_model_path != self.GPT_Path:
|
285 |
+
change_gpt_weights(gpt_path=GPT_model_path)
|
286 |
+
self.GPT_Path = GPT_model_path
|
287 |
+
if SoVITS_model_path != self.SoVITS_Path:
|
288 |
+
change_sovits_weights(sovits_path=SoVITS_model_path)
|
289 |
+
self.SoVITS_Path = SoVITS_model_path
|
290 |
+
|
291 |
+
synthesis_result = get_tts_wav(
|
292 |
+
ref_wav_path=ref_audio_path,
|
293 |
+
prompt_text=ref_text,
|
294 |
+
prompt_language=language_combobox,
|
295 |
+
text=target_text,
|
296 |
+
text_language=target_language_combobox,
|
297 |
+
)
|
298 |
+
|
299 |
+
result_list = list(synthesis_result)
|
300 |
+
|
301 |
+
if result_list:
|
302 |
+
last_sampling_rate, last_audio_data = result_list[-1]
|
303 |
+
output_wav_path = os.path.join(output_path, "output.wav")
|
304 |
+
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
305 |
+
|
306 |
+
result = "Audio saved to " + output_wav_path
|
307 |
+
|
308 |
+
self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
|
309 |
+
self.output_text.append("处理结果:\n" + result)
|
310 |
+
|
311 |
+
|
312 |
+
if __name__ == "__main__":
|
313 |
+
app = QApplication(sys.argv)
|
314 |
+
mainWin = GPTSoVITSGUI()
|
315 |
+
mainWin.show()
|
316 |
+
sys.exit(app.exec_())
|
GPT_SoVITS/inference_webui.py
ADDED
@@ -0,0 +1,1281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
按中英混合识别
|
3 |
+
按日英混合识别
|
4 |
+
多语种启动切分识别语种
|
5 |
+
全部按中文识别
|
6 |
+
全部按英文识别
|
7 |
+
全部按日文识别
|
8 |
+
"""
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import traceback
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
import torchaudio
|
15 |
+
|
16 |
+
logging.getLogger("markdown_it").setLevel(logging.ERROR)
|
17 |
+
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
18 |
+
logging.getLogger("httpcore").setLevel(logging.ERROR)
|
19 |
+
logging.getLogger("httpx").setLevel(logging.ERROR)
|
20 |
+
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
21 |
+
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
22 |
+
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
23 |
+
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
|
24 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
25 |
+
|
26 |
+
import json
|
27 |
+
import os
|
28 |
+
import re
|
29 |
+
import sys
|
30 |
+
|
31 |
+
import torch
|
32 |
+
from text.LangSegmenter import LangSegmenter
|
33 |
+
|
34 |
+
try:
|
35 |
+
import gradio.analytics as analytics
|
36 |
+
|
37 |
+
analytics.version_check = lambda: None
|
38 |
+
except:
|
39 |
+
...
|
40 |
+
version = model_version = os.environ.get("version", "v2")
|
41 |
+
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
42 |
+
path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
|
43 |
+
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
44 |
+
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
45 |
+
pretrained_sovits_name = [
|
46 |
+
"GPT_SoVITS/pretrained_models/s2G488k.pth",
|
47 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
|
48 |
+
"GPT_SoVITS/pretrained_models/s2Gv3.pth",
|
49 |
+
"GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
|
50 |
+
]
|
51 |
+
pretrained_gpt_name = [
|
52 |
+
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
53 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
|
54 |
+
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
55 |
+
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
56 |
+
]
|
57 |
+
|
58 |
+
|
59 |
+
_ = [[], []]
|
60 |
+
for i in range(4):
|
61 |
+
if os.path.exists(pretrained_gpt_name[i]):
|
62 |
+
_[0].append(pretrained_gpt_name[i])
|
63 |
+
if os.path.exists(pretrained_sovits_name[i]):
|
64 |
+
_[-1].append(pretrained_sovits_name[i])
|
65 |
+
pretrained_gpt_name, pretrained_sovits_name = _
|
66 |
+
|
67 |
+
|
68 |
+
if os.path.exists("./weight.json"):
|
69 |
+
pass
|
70 |
+
else:
|
71 |
+
with open("./weight.json", "w", encoding="utf-8") as file:
|
72 |
+
json.dump({"GPT": {}, "SoVITS": {}}, file)
|
73 |
+
|
74 |
+
with open("./weight.json", "r", encoding="utf-8") as file:
|
75 |
+
weight_data = file.read()
|
76 |
+
weight_data = json.loads(weight_data)
|
77 |
+
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
|
78 |
+
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
|
79 |
+
if isinstance(gpt_path, list):
|
80 |
+
gpt_path = gpt_path[0]
|
81 |
+
if isinstance(sovits_path, list):
|
82 |
+
sovits_path = sovits_path[0]
|
83 |
+
|
84 |
+
# gpt_path = os.environ.get(
|
85 |
+
# "gpt_path", pretrained_gpt_name
|
86 |
+
# )
|
87 |
+
# sovits_path = os.environ.get("sovits_path", pretrained_sovits_name)
|
88 |
+
cnhubert_base_path = os.environ.get("cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base")
|
89 |
+
bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
|
90 |
+
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
91 |
+
infer_ttswebui = int(infer_ttswebui)
|
92 |
+
is_share = os.environ.get("is_share", "False")
|
93 |
+
is_share = eval(is_share)
|
94 |
+
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
95 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
96 |
+
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
97 |
+
# is_half=False
|
98 |
+
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
|
99 |
+
import gradio as gr
|
100 |
+
import librosa
|
101 |
+
import numpy as np
|
102 |
+
from feature_extractor import cnhubert
|
103 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
104 |
+
|
105 |
+
cnhubert.cnhubert_base_path = cnhubert_base_path
|
106 |
+
|
107 |
+
import random
|
108 |
+
|
109 |
+
from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3,Generator
|
110 |
+
|
111 |
+
|
112 |
+
def set_seed(seed):
|
113 |
+
if seed == -1:
|
114 |
+
seed = random.randint(0, 1000000)
|
115 |
+
seed = int(seed)
|
116 |
+
random.seed(seed)
|
117 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
118 |
+
np.random.seed(seed)
|
119 |
+
torch.manual_seed(seed)
|
120 |
+
torch.cuda.manual_seed(seed)
|
121 |
+
|
122 |
+
|
123 |
+
# set_seed(42)
|
124 |
+
|
125 |
+
from time import time as ttime
|
126 |
+
|
127 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
128 |
+
from peft import LoraConfig, get_peft_model
|
129 |
+
from text import cleaned_text_to_sequence
|
130 |
+
from text.cleaner import clean_text
|
131 |
+
|
132 |
+
from tools.i18n.i18n import I18nAuto, scan_language_list
|
133 |
+
|
134 |
+
language = os.environ.get("language", "Auto")
|
135 |
+
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
136 |
+
i18n = I18nAuto(language=language)
|
137 |
+
|
138 |
+
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
|
139 |
+
|
140 |
+
if torch.cuda.is_available():
|
141 |
+
device = "cuda"
|
142 |
+
else:
|
143 |
+
device = "cpu"
|
144 |
+
|
145 |
+
dict_language_v1 = {
|
146 |
+
i18n("中文"): "all_zh", # 全部按中文识别
|
147 |
+
i18n("英文"): "en", # 全部按英文识别#######不变
|
148 |
+
i18n("日文"): "all_ja", # 全部按日文识别
|
149 |
+
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
150 |
+
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
151 |
+
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
152 |
+
}
|
153 |
+
dict_language_v2 = {
|
154 |
+
i18n("中文"): "all_zh", # 全部按中文识别
|
155 |
+
i18n("英文"): "en", # 全部按英文识别#######不变
|
156 |
+
i18n("日文"): "all_ja", # 全部按日文识别
|
157 |
+
i18n("粤语"): "all_yue", # 全部按中文识别
|
158 |
+
i18n("韩文"): "all_ko", # 全部按韩文识别
|
159 |
+
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
160 |
+
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
161 |
+
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
|
162 |
+
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
|
163 |
+
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
164 |
+
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
|
165 |
+
}
|
166 |
+
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
167 |
+
|
168 |
+
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
169 |
+
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
170 |
+
if is_half == True:
|
171 |
+
bert_model = bert_model.half().to(device)
|
172 |
+
else:
|
173 |
+
bert_model = bert_model.to(device)
|
174 |
+
|
175 |
+
|
176 |
+
def get_bert_feature(text, word2ph):
|
177 |
+
with torch.no_grad():
|
178 |
+
inputs = tokenizer(text, return_tensors="pt")
|
179 |
+
for i in inputs:
|
180 |
+
inputs[i] = inputs[i].to(device)
|
181 |
+
res = bert_model(**inputs, output_hidden_states=True)
|
182 |
+
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
183 |
+
assert len(word2ph) == len(text)
|
184 |
+
phone_level_feature = []
|
185 |
+
for i in range(len(word2ph)):
|
186 |
+
repeat_feature = res[i].repeat(word2ph[i], 1)
|
187 |
+
phone_level_feature.append(repeat_feature)
|
188 |
+
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
189 |
+
return phone_level_feature.T
|
190 |
+
|
191 |
+
|
192 |
+
class DictToAttrRecursive(dict):
|
193 |
+
def __init__(self, input_dict):
|
194 |
+
super().__init__(input_dict)
|
195 |
+
for key, value in input_dict.items():
|
196 |
+
if isinstance(value, dict):
|
197 |
+
value = DictToAttrRecursive(value)
|
198 |
+
self[key] = value
|
199 |
+
setattr(self, key, value)
|
200 |
+
|
201 |
+
def __getattr__(self, item):
|
202 |
+
try:
|
203 |
+
return self[item]
|
204 |
+
except KeyError:
|
205 |
+
raise AttributeError(f"Attribute {item} not found")
|
206 |
+
|
207 |
+
def __setattr__(self, key, value):
|
208 |
+
if isinstance(value, dict):
|
209 |
+
value = DictToAttrRecursive(value)
|
210 |
+
super(DictToAttrRecursive, self).__setitem__(key, value)
|
211 |
+
super().__setattr__(key, value)
|
212 |
+
|
213 |
+
def __delattr__(self, item):
|
214 |
+
try:
|
215 |
+
del self[item]
|
216 |
+
except KeyError:
|
217 |
+
raise AttributeError(f"Attribute {item} not found")
|
218 |
+
|
219 |
+
|
220 |
+
ssl_model = cnhubert.get_model()
|
221 |
+
if is_half == True:
|
222 |
+
ssl_model = ssl_model.half().to(device)
|
223 |
+
else:
|
224 |
+
ssl_model = ssl_model.to(device)
|
225 |
+
|
226 |
+
resample_transform_dict = {}
|
227 |
+
|
228 |
+
|
229 |
+
def resample(audio_tensor, sr0,sr1):
|
230 |
+
global resample_transform_dict
|
231 |
+
key="%s-%s"%(sr0,sr1)
|
232 |
+
if key not in resample_transform_dict:
|
233 |
+
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
|
234 |
+
return resample_transform_dict[key](audio_tensor)
|
235 |
+
|
236 |
+
|
237 |
+
###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
|
238 |
+
# symbol_version-model_version-if_lora_v3
|
239 |
+
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
240 |
+
|
241 |
+
v3v4set={"v3","v4"}
|
242 |
+
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
243 |
+
global vq_model, hps, version, model_version, dict_language, if_lora_v3
|
244 |
+
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
245 |
+
print(sovits_path,version, model_version, if_lora_v3)
|
246 |
+
is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4
|
247 |
+
if if_lora_v3 == True and is_exist == False:
|
248 |
+
info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
|
249 |
+
gr.Warning(info)
|
250 |
+
raise FileExistsError(info)
|
251 |
+
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
252 |
+
if prompt_language is not None and text_language is not None:
|
253 |
+
if prompt_language in list(dict_language.keys()):
|
254 |
+
prompt_text_update, prompt_language_update = (
|
255 |
+
{"__type__": "update"},
|
256 |
+
{"__type__": "update", "value": prompt_language},
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
prompt_text_update = {"__type__": "update", "value": ""}
|
260 |
+
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
|
261 |
+
if text_language in list(dict_language.keys()):
|
262 |
+
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
|
263 |
+
else:
|
264 |
+
text_update = {"__type__": "update", "value": ""}
|
265 |
+
text_language_update = {"__type__": "update", "value": i18n("中文")}
|
266 |
+
if model_version in v3v4set:
|
267 |
+
visible_sample_steps = True
|
268 |
+
visible_inp_refs = False
|
269 |
+
else:
|
270 |
+
visible_sample_steps = False
|
271 |
+
visible_inp_refs = True
|
272 |
+
yield (
|
273 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
274 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
275 |
+
prompt_text_update,
|
276 |
+
prompt_language_update,
|
277 |
+
text_update,
|
278 |
+
text_language_update,
|
279 |
+
{"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
|
280 |
+
{"__type__": "update", "visible": visible_inp_refs},
|
281 |
+
{"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
|
282 |
+
{"__type__": "update", "visible": True if model_version =="v3" else False},
|
283 |
+
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
|
284 |
+
)
|
285 |
+
|
286 |
+
dict_s2 = load_sovits_new(sovits_path)
|
287 |
+
hps = dict_s2["config"]
|
288 |
+
hps = DictToAttrRecursive(hps)
|
289 |
+
hps.model.semantic_frame_rate = "25hz"
|
290 |
+
if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
|
291 |
+
hps.model.version = "v2" # v3model,v2sybomls
|
292 |
+
elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
293 |
+
hps.model.version = "v1"
|
294 |
+
else:
|
295 |
+
hps.model.version = "v2"
|
296 |
+
version = hps.model.version
|
297 |
+
# print("sovits版本:",hps.model.version)
|
298 |
+
if model_version not in v3v4set:
|
299 |
+
vq_model = SynthesizerTrn(
|
300 |
+
hps.data.filter_length // 2 + 1,
|
301 |
+
hps.train.segment_size // hps.data.hop_length,
|
302 |
+
n_speakers=hps.data.n_speakers,
|
303 |
+
**hps.model,
|
304 |
+
)
|
305 |
+
model_version = version
|
306 |
+
else:
|
307 |
+
hps.model.version=model_version
|
308 |
+
vq_model = SynthesizerTrnV3(
|
309 |
+
hps.data.filter_length // 2 + 1,
|
310 |
+
hps.train.segment_size // hps.data.hop_length,
|
311 |
+
n_speakers=hps.data.n_speakers,
|
312 |
+
**hps.model,
|
313 |
+
)
|
314 |
+
if "pretrained" not in sovits_path:
|
315 |
+
try:
|
316 |
+
del vq_model.enc_q
|
317 |
+
except:
|
318 |
+
pass
|
319 |
+
if is_half == True:
|
320 |
+
vq_model = vq_model.half().to(device)
|
321 |
+
else:
|
322 |
+
vq_model = vq_model.to(device)
|
323 |
+
vq_model.eval()
|
324 |
+
if if_lora_v3 == False:
|
325 |
+
print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
326 |
+
else:
|
327 |
+
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
328 |
+
print(
|
329 |
+
"loading sovits_%spretrained_G"%model_version,
|
330 |
+
vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False),
|
331 |
+
)
|
332 |
+
lora_rank = dict_s2["lora_rank"]
|
333 |
+
lora_config = LoraConfig(
|
334 |
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
335 |
+
r=lora_rank,
|
336 |
+
lora_alpha=lora_rank,
|
337 |
+
init_lora_weights=True,
|
338 |
+
)
|
339 |
+
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
|
340 |
+
print("loading sovits_%s_lora%s" % (model_version,lora_rank))
|
341 |
+
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
342 |
+
vq_model.cfm = vq_model.cfm.merge_and_unload()
|
343 |
+
# torch.save(vq_model.state_dict(),"merge_win.pth")
|
344 |
+
vq_model.eval()
|
345 |
+
|
346 |
+
yield (
|
347 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
348 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
349 |
+
prompt_text_update,
|
350 |
+
prompt_language_update,
|
351 |
+
text_update,
|
352 |
+
text_language_update,
|
353 |
+
{"__type__": "update", "visible": visible_sample_steps, "value":32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
|
354 |
+
{"__type__": "update", "visible": visible_inp_refs},
|
355 |
+
{"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
|
356 |
+
{"__type__": "update", "visible": True if model_version =="v3" else False},
|
357 |
+
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
|
358 |
+
)
|
359 |
+
with open("./weight.json") as f:
|
360 |
+
data = f.read()
|
361 |
+
data = json.loads(data)
|
362 |
+
data["SoVITS"][version] = sovits_path
|
363 |
+
with open("./weight.json", "w") as f:
|
364 |
+
f.write(json.dumps(data))
|
365 |
+
|
366 |
+
|
367 |
+
try:
|
368 |
+
next(change_sovits_weights(sovits_path))
|
369 |
+
except:
|
370 |
+
pass
|
371 |
+
|
372 |
+
|
373 |
+
def change_gpt_weights(gpt_path):
|
374 |
+
global hz, max_sec, t2s_model, config
|
375 |
+
hz = 50
|
376 |
+
dict_s1 = torch.load(gpt_path, map_location="cpu")
|
377 |
+
config = dict_s1["config"]
|
378 |
+
max_sec = config["data"]["max_sec"]
|
379 |
+
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
380 |
+
t2s_model.load_state_dict(dict_s1["weight"])
|
381 |
+
if is_half == True:
|
382 |
+
t2s_model = t2s_model.half()
|
383 |
+
t2s_model = t2s_model.to(device)
|
384 |
+
t2s_model.eval()
|
385 |
+
# total = sum([param.nelement() for param in t2s_model.parameters()])
|
386 |
+
# print("Number of parameter: %.2fM" % (total / 1e6))
|
387 |
+
with open("./weight.json") as f:
|
388 |
+
data = f.read()
|
389 |
+
data = json.loads(data)
|
390 |
+
data["GPT"][version] = gpt_path
|
391 |
+
with open("./weight.json", "w") as f:
|
392 |
+
f.write(json.dumps(data))
|
393 |
+
|
394 |
+
|
395 |
+
change_gpt_weights(gpt_path)
|
396 |
+
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
397 |
+
import torch
|
398 |
+
|
399 |
+
now_dir = os.getcwd()
|
400 |
+
|
401 |
+
|
402 |
+
def init_bigvgan():
|
403 |
+
global bigvgan_model,hifigan_model
|
404 |
+
from BigVGAN import bigvgan
|
405 |
+
|
406 |
+
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
|
407 |
+
"%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
|
408 |
+
use_cuda_kernel=False,
|
409 |
+
) # if True, RuntimeError: Ninja is required to load C++ extensions
|
410 |
+
# remove weight norm in the model and set to eval mode
|
411 |
+
bigvgan_model.remove_weight_norm()
|
412 |
+
bigvgan_model = bigvgan_model.eval()
|
413 |
+
if hifigan_model:
|
414 |
+
hifigan_model=hifigan_model.cpu()
|
415 |
+
hifigan_model=None
|
416 |
+
try:torch.cuda.empty_cache()
|
417 |
+
except:pass
|
418 |
+
if is_half == True:
|
419 |
+
bigvgan_model = bigvgan_model.half().to(device)
|
420 |
+
else:
|
421 |
+
bigvgan_model = bigvgan_model.to(device)
|
422 |
+
|
423 |
+
def init_hifigan():
|
424 |
+
global hifigan_model,bigvgan_model
|
425 |
+
hifigan_model = Generator(
|
426 |
+
initial_channel=100,
|
427 |
+
resblock="1",
|
428 |
+
resblock_kernel_sizes=[3, 7, 11],
|
429 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
430 |
+
upsample_rates=[10, 6, 2, 2, 2],
|
431 |
+
upsample_initial_channel=512,
|
432 |
+
upsample_kernel_sizes=[20, 12, 4, 4, 4],
|
433 |
+
gin_channels=0, is_bias=True
|
434 |
+
)
|
435 |
+
hifigan_model.eval()
|
436 |
+
hifigan_model.remove_weight_norm()
|
437 |
+
state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
|
438 |
+
print("loading vocoder",hifigan_model.load_state_dict(state_dict_g))
|
439 |
+
if bigvgan_model:
|
440 |
+
bigvgan_model=bigvgan_model.cpu()
|
441 |
+
bigvgan_model=None
|
442 |
+
try:torch.cuda.empty_cache()
|
443 |
+
except:pass
|
444 |
+
if is_half == True:
|
445 |
+
hifigan_model = hifigan_model.half().to(device)
|
446 |
+
else:
|
447 |
+
hifigan_model = hifigan_model.to(device)
|
448 |
+
|
449 |
+
bigvgan_model=hifigan_model=None
|
450 |
+
if model_version=="v3":
|
451 |
+
init_bigvgan()
|
452 |
+
if model_version=="v4":
|
453 |
+
init_hifigan()
|
454 |
+
|
455 |
+
|
456 |
+
def get_spepc(hps, filename):
|
457 |
+
# audio = load_audio(filename, int(hps.data.sampling_rate))
|
458 |
+
audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
|
459 |
+
audio = torch.FloatTensor(audio)
|
460 |
+
maxx = audio.abs().max()
|
461 |
+
if maxx > 1:
|
462 |
+
audio /= min(2, maxx)
|
463 |
+
audio_norm = audio
|
464 |
+
audio_norm = audio_norm.unsqueeze(0)
|
465 |
+
spec = spectrogram_torch(
|
466 |
+
audio_norm,
|
467 |
+
hps.data.filter_length,
|
468 |
+
hps.data.sampling_rate,
|
469 |
+
hps.data.hop_length,
|
470 |
+
hps.data.win_length,
|
471 |
+
center=False,
|
472 |
+
)
|
473 |
+
return spec
|
474 |
+
|
475 |
+
|
476 |
+
def clean_text_inf(text, language, version):
|
477 |
+
language = language.replace("all_", "")
|
478 |
+
phones, word2ph, norm_text = clean_text(text, language, version)
|
479 |
+
phones = cleaned_text_to_sequence(phones, version)
|
480 |
+
return phones, word2ph, norm_text
|
481 |
+
|
482 |
+
|
483 |
+
dtype = torch.float16 if is_half == True else torch.float32
|
484 |
+
|
485 |
+
|
486 |
+
def get_bert_inf(phones, word2ph, norm_text, language):
|
487 |
+
language = language.replace("all_", "")
|
488 |
+
if language == "zh":
|
489 |
+
bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
|
490 |
+
else:
|
491 |
+
bert = torch.zeros(
|
492 |
+
(1024, len(phones)),
|
493 |
+
dtype=torch.float16 if is_half == True else torch.float32,
|
494 |
+
).to(device)
|
495 |
+
|
496 |
+
return bert
|
497 |
+
|
498 |
+
|
499 |
+
splits = {
|
500 |
+
",",
|
501 |
+
"。",
|
502 |
+
"?",
|
503 |
+
"!",
|
504 |
+
",",
|
505 |
+
".",
|
506 |
+
"?",
|
507 |
+
"!",
|
508 |
+
"~",
|
509 |
+
":",
|
510 |
+
":",
|
511 |
+
"—",
|
512 |
+
"…",
|
513 |
+
}
|
514 |
+
|
515 |
+
|
516 |
+
def get_first(text):
|
517 |
+
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
518 |
+
text = re.split(pattern, text)[0].strip()
|
519 |
+
return text
|
520 |
+
|
521 |
+
|
522 |
+
from text import chinese
|
523 |
+
|
524 |
+
|
525 |
+
def get_phones_and_bert(text, language, version, final=False):
|
526 |
+
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
527 |
+
formattext = text
|
528 |
+
while " " in formattext:
|
529 |
+
formattext = formattext.replace(" ", " ")
|
530 |
+
if language == "all_zh":
|
531 |
+
if re.search(r"[A-Za-z]", formattext):
|
532 |
+
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
533 |
+
formattext = chinese.mix_text_normalize(formattext)
|
534 |
+
return get_phones_and_bert(formattext, "zh", version)
|
535 |
+
else:
|
536 |
+
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
537 |
+
bert = get_bert_feature(norm_text, word2ph).to(device)
|
538 |
+
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
|
539 |
+
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
540 |
+
formattext = chinese.mix_text_normalize(formattext)
|
541 |
+
return get_phones_and_bert(formattext, "yue", version)
|
542 |
+
else:
|
543 |
+
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
|
544 |
+
bert = torch.zeros(
|
545 |
+
(1024, len(phones)),
|
546 |
+
dtype=torch.float16 if is_half == True else torch.float32,
|
547 |
+
).to(device)
|
548 |
+
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
549 |
+
textlist = []
|
550 |
+
langlist = []
|
551 |
+
if language == "auto":
|
552 |
+
for tmp in LangSegmenter.getTexts(text):
|
553 |
+
langlist.append(tmp["lang"])
|
554 |
+
textlist.append(tmp["text"])
|
555 |
+
elif language == "auto_yue":
|
556 |
+
for tmp in LangSegmenter.getTexts(text):
|
557 |
+
if tmp["lang"] == "zh":
|
558 |
+
tmp["lang"] = "yue"
|
559 |
+
langlist.append(tmp["lang"])
|
560 |
+
textlist.append(tmp["text"])
|
561 |
+
else:
|
562 |
+
for tmp in LangSegmenter.getTexts(text):
|
563 |
+
if tmp["lang"] == "en":
|
564 |
+
langlist.append(tmp["lang"])
|
565 |
+
else:
|
566 |
+
# 因无法区别中日韩文汉字,以用户输入为准
|
567 |
+
langlist.append(language)
|
568 |
+
textlist.append(tmp["text"])
|
569 |
+
print(textlist)
|
570 |
+
print(langlist)
|
571 |
+
phones_list = []
|
572 |
+
bert_list = []
|
573 |
+
norm_text_list = []
|
574 |
+
for i in range(len(textlist)):
|
575 |
+
lang = langlist[i]
|
576 |
+
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
|
577 |
+
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
578 |
+
phones_list.append(phones)
|
579 |
+
norm_text_list.append(norm_text)
|
580 |
+
bert_list.append(bert)
|
581 |
+
bert = torch.cat(bert_list, dim=1)
|
582 |
+
phones = sum(phones_list, [])
|
583 |
+
norm_text = "".join(norm_text_list)
|
584 |
+
|
585 |
+
if not final and len(phones) < 6:
|
586 |
+
return get_phones_and_bert("." + text, language, version, final=True)
|
587 |
+
|
588 |
+
return phones, bert.to(dtype), norm_text
|
589 |
+
|
590 |
+
|
591 |
+
from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
|
592 |
+
|
593 |
+
spec_min = -12
|
594 |
+
spec_max = 2
|
595 |
+
|
596 |
+
|
597 |
+
def norm_spec(x):
|
598 |
+
return (x - spec_min) / (spec_max - spec_min) * 2 - 1
|
599 |
+
|
600 |
+
|
601 |
+
def denorm_spec(x):
|
602 |
+
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
|
603 |
+
|
604 |
+
|
605 |
+
mel_fn = lambda x: mel_spectrogram_torch(
|
606 |
+
x,
|
607 |
+
**{
|
608 |
+
"n_fft": 1024,
|
609 |
+
"win_size": 1024,
|
610 |
+
"hop_size": 256,
|
611 |
+
"num_mels": 100,
|
612 |
+
"sampling_rate": 24000,
|
613 |
+
"fmin": 0,
|
614 |
+
"fmax": None,
|
615 |
+
"center": False,
|
616 |
+
},
|
617 |
+
)
|
618 |
+
mel_fn_v4 = lambda x: mel_spectrogram_torch(
|
619 |
+
x,
|
620 |
+
**{
|
621 |
+
"n_fft": 1280,
|
622 |
+
"win_size": 1280,
|
623 |
+
"hop_size": 320,
|
624 |
+
"num_mels": 100,
|
625 |
+
"sampling_rate": 32000,
|
626 |
+
"fmin": 0,
|
627 |
+
"fmax": None,
|
628 |
+
"center": False,
|
629 |
+
},
|
630 |
+
)
|
631 |
+
|
632 |
+
|
633 |
+
def merge_short_text_in_array(texts, threshold):
|
634 |
+
if (len(texts)) < 2:
|
635 |
+
return texts
|
636 |
+
result = []
|
637 |
+
text = ""
|
638 |
+
for ele in texts:
|
639 |
+
text += ele
|
640 |
+
if len(text) >= threshold:
|
641 |
+
result.append(text)
|
642 |
+
text = ""
|
643 |
+
if len(text) > 0:
|
644 |
+
if len(result) == 0:
|
645 |
+
result.append(text)
|
646 |
+
else:
|
647 |
+
result[len(result) - 1] += text
|
648 |
+
return result
|
649 |
+
|
650 |
+
|
651 |
+
sr_model = None
|
652 |
+
|
653 |
+
|
654 |
+
def audio_sr(audio, sr):
|
655 |
+
global sr_model
|
656 |
+
if sr_model == None:
|
657 |
+
from tools.audio_sr import AP_BWE
|
658 |
+
|
659 |
+
try:
|
660 |
+
sr_model = AP_BWE(device, DictToAttrRecursive)
|
661 |
+
except FileNotFoundError:
|
662 |
+
gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
|
663 |
+
return audio.cpu().detach().numpy(), sr
|
664 |
+
return sr_model(audio, sr)
|
665 |
+
|
666 |
+
|
667 |
+
##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
|
668 |
+
# cache_tokens={}#暂未实现清理机制
|
669 |
+
cache = {}
|
670 |
+
|
671 |
+
|
672 |
+
def get_tts_wav(
|
673 |
+
ref_wav_path,
|
674 |
+
prompt_text,
|
675 |
+
prompt_language,
|
676 |
+
text,
|
677 |
+
text_language,
|
678 |
+
how_to_cut=i18n("不切"),
|
679 |
+
top_k=20,
|
680 |
+
top_p=0.6,
|
681 |
+
temperature=0.6,
|
682 |
+
ref_free=False,
|
683 |
+
speed=1,
|
684 |
+
if_freeze=False,
|
685 |
+
inp_refs=None,
|
686 |
+
sample_steps=8,
|
687 |
+
if_sr=False,
|
688 |
+
pause_second=0.3,
|
689 |
+
):
|
690 |
+
global cache
|
691 |
+
if ref_wav_path:
|
692 |
+
pass
|
693 |
+
else:
|
694 |
+
gr.Warning(i18n("请上传参考音频"))
|
695 |
+
if text:
|
696 |
+
pass
|
697 |
+
else:
|
698 |
+
gr.Warning(i18n("请填入推理文本"))
|
699 |
+
t = []
|
700 |
+
if prompt_text is None or len(prompt_text) == 0:
|
701 |
+
ref_free = True
|
702 |
+
if model_version in v3v4set:
|
703 |
+
ref_free = False # s2v3暂不支持ref_free
|
704 |
+
else:
|
705 |
+
if_sr = False
|
706 |
+
t0 = ttime()
|
707 |
+
prompt_language = dict_language[prompt_language]
|
708 |
+
text_language = dict_language[text_language]
|
709 |
+
|
710 |
+
if not ref_free:
|
711 |
+
prompt_text = prompt_text.strip("\n")
|
712 |
+
if prompt_text[-1] not in splits:
|
713 |
+
prompt_text += "。" if prompt_language != "en" else "."
|
714 |
+
print(i18n("实际输入的参考文本:"), prompt_text)
|
715 |
+
text = text.strip("\n")
|
716 |
+
# if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
|
717 |
+
|
718 |
+
print(i18n("实际输入的目标文本:"), text)
|
719 |
+
zero_wav = np.zeros(
|
720 |
+
int(hps.data.sampling_rate * pause_second),
|
721 |
+
dtype=np.float16 if is_half == True else np.float32,
|
722 |
+
)
|
723 |
+
zero_wav_torch = torch.from_numpy(zero_wav)
|
724 |
+
if is_half == True:
|
725 |
+
zero_wav_torch = zero_wav_torch.half().to(device)
|
726 |
+
else:
|
727 |
+
zero_wav_torch = zero_wav_torch.to(device)
|
728 |
+
if not ref_free:
|
729 |
+
with torch.no_grad():
|
730 |
+
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
731 |
+
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
|
732 |
+
gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
|
733 |
+
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
734 |
+
wav16k = torch.from_numpy(wav16k)
|
735 |
+
if is_half == True:
|
736 |
+
wav16k = wav16k.half().to(device)
|
737 |
+
else:
|
738 |
+
wav16k = wav16k.to(device)
|
739 |
+
wav16k = torch.cat([wav16k, zero_wav_torch])
|
740 |
+
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
741 |
+
codes = vq_model.extract_latent(ssl_content)
|
742 |
+
prompt_semantic = codes[0, 0]
|
743 |
+
prompt = prompt_semantic.unsqueeze(0).to(device)
|
744 |
+
|
745 |
+
t1 = ttime()
|
746 |
+
t.append(t1 - t0)
|
747 |
+
|
748 |
+
if how_to_cut == i18n("凑四句一切"):
|
749 |
+
text = cut1(text)
|
750 |
+
elif how_to_cut == i18n("凑50字一切"):
|
751 |
+
text = cut2(text)
|
752 |
+
elif how_to_cut == i18n("按中文句号。切"):
|
753 |
+
text = cut3(text)
|
754 |
+
elif how_to_cut == i18n("按英文句号.切"):
|
755 |
+
text = cut4(text)
|
756 |
+
elif how_to_cut == i18n("按标点符号切"):
|
757 |
+
text = cut5(text)
|
758 |
+
while "\n\n" in text:
|
759 |
+
text = text.replace("\n\n", "\n")
|
760 |
+
print(i18n("实际输入的目标文本(切句后):"), text)
|
761 |
+
texts = text.split("\n")
|
762 |
+
texts = process_text(texts)
|
763 |
+
texts = merge_short_text_in_array(texts, 5)
|
764 |
+
audio_opt = []
|
765 |
+
###s2v3暂不支持ref_free
|
766 |
+
if not ref_free:
|
767 |
+
phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
|
768 |
+
|
769 |
+
for i_text, text in enumerate(texts):
|
770 |
+
# 解决输入目标文本的空行导致报错的问题
|
771 |
+
if len(text.strip()) == 0:
|
772 |
+
continue
|
773 |
+
if text[-1] not in splits:
|
774 |
+
text += "。" if text_language != "en" else "."
|
775 |
+
print(i18n("实际输入的目标文本(每句):"), text)
|
776 |
+
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
|
777 |
+
print(i18n("前端处理后的文本(每句):"), norm_text2)
|
778 |
+
if not ref_free:
|
779 |
+
bert = torch.cat([bert1, bert2], 1)
|
780 |
+
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
781 |
+
else:
|
782 |
+
bert = bert2
|
783 |
+
all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
784 |
+
|
785 |
+
bert = bert.to(device).unsqueeze(0)
|
786 |
+
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
787 |
+
|
788 |
+
t2 = ttime()
|
789 |
+
# cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
|
790 |
+
# print(cache.keys(),if_freeze)
|
791 |
+
if i_text in cache and if_freeze == True:
|
792 |
+
pred_semantic = cache[i_text]
|
793 |
+
else:
|
794 |
+
with torch.no_grad():
|
795 |
+
pred_semantic, idx = t2s_model.model.infer_panel(
|
796 |
+
all_phoneme_ids,
|
797 |
+
all_phoneme_len,
|
798 |
+
None if ref_free else prompt,
|
799 |
+
bert,
|
800 |
+
# prompt_phone_len=ph_offset,
|
801 |
+
top_k=top_k,
|
802 |
+
top_p=top_p,
|
803 |
+
temperature=temperature,
|
804 |
+
early_stop_num=hz * max_sec,
|
805 |
+
)
|
806 |
+
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
807 |
+
cache[i_text] = pred_semantic
|
808 |
+
t3 = ttime()
|
809 |
+
###v3不存在以下逻辑和inp_refs
|
810 |
+
if model_version not in v3v4set:
|
811 |
+
refers = []
|
812 |
+
if inp_refs:
|
813 |
+
for path in inp_refs:
|
814 |
+
try:
|
815 |
+
refer = get_spepc(hps, path.name).to(dtype).to(device)
|
816 |
+
refers.append(refer)
|
817 |
+
except:
|
818 |
+
traceback.print_exc()
|
819 |
+
if len(refers) == 0:
|
820 |
+
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
821 |
+
audio = vq_model.decode(
|
822 |
+
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
|
823 |
+
)[0][0] # .cpu().detach().numpy()
|
824 |
+
else:
|
825 |
+
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
826 |
+
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
827 |
+
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
828 |
+
# print(11111111, phoneme_ids0, phoneme_ids1)
|
829 |
+
fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
|
830 |
+
ref_audio, sr = torchaudio.load(ref_wav_path)
|
831 |
+
ref_audio = ref_audio.to(device).float()
|
832 |
+
if ref_audio.shape[0] == 2:
|
833 |
+
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
834 |
+
tgt_sr=24000 if model_version=="v3"else 32000
|
835 |
+
if sr != tgt_sr:
|
836 |
+
ref_audio = resample(ref_audio, sr,tgt_sr)
|
837 |
+
# print("ref_audio",ref_audio.abs().mean())
|
838 |
+
mel2 = mel_fn(ref_audio)if model_version=="v3"else mel_fn_v4(ref_audio)
|
839 |
+
mel2 = norm_spec(mel2)
|
840 |
+
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
841 |
+
mel2 = mel2[:, :, :T_min]
|
842 |
+
fea_ref = fea_ref[:, :, :T_min]
|
843 |
+
Tref=468 if model_version=="v3"else 500
|
844 |
+
Tchunk=934 if model_version=="v3"else 1000
|
845 |
+
if T_min > Tref:
|
846 |
+
mel2 = mel2[:, :, -Tref:]
|
847 |
+
fea_ref = fea_ref[:, :, -Tref:]
|
848 |
+
T_min = Tref
|
849 |
+
chunk_len = Tchunk - T_min
|
850 |
+
mel2 = mel2.to(dtype)
|
851 |
+
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
|
852 |
+
cfm_resss = []
|
853 |
+
idx = 0
|
854 |
+
while 1:
|
855 |
+
fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
|
856 |
+
if fea_todo_chunk.shape[-1] == 0:
|
857 |
+
break
|
858 |
+
idx += chunk_len
|
859 |
+
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
860 |
+
cfm_res = vq_model.cfm.inference(
|
861 |
+
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
|
862 |
+
)
|
863 |
+
cfm_res = cfm_res[:, :, mel2.shape[2] :]
|
864 |
+
mel2 = cfm_res[:, :, -T_min:]
|
865 |
+
fea_ref = fea_todo_chunk[:, :, -T_min:]
|
866 |
+
cfm_resss.append(cfm_res)
|
867 |
+
cfm_res = torch.cat(cfm_resss, 2)
|
868 |
+
cfm_res = denorm_spec(cfm_res)
|
869 |
+
if model_version=="v3":
|
870 |
+
if bigvgan_model == None:
|
871 |
+
init_bigvgan()
|
872 |
+
else:#v4
|
873 |
+
if hifigan_model == None:
|
874 |
+
init_hifigan()
|
875 |
+
vocoder_model=bigvgan_model if model_version=="v3"else hifigan_model
|
876 |
+
with torch.inference_mode():
|
877 |
+
wav_gen = vocoder_model(cfm_res)
|
878 |
+
audio = wav_gen[0][0] # .cpu().detach().numpy()
|
879 |
+
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
|
880 |
+
if max_audio > 1:
|
881 |
+
audio = audio / max_audio
|
882 |
+
audio_opt.append(audio)
|
883 |
+
audio_opt.append(zero_wav_torch) # zero_wav
|
884 |
+
t4 = ttime()
|
885 |
+
t.extend([t2 - t1, t3 - t2, t4 - t3])
|
886 |
+
t1 = ttime()
|
887 |
+
print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
|
888 |
+
audio_opt = torch.cat(audio_opt, 0) # np.concatenate
|
889 |
+
if model_version in {"v1","v2"}:opt_sr=32000
|
890 |
+
elif model_version=="v3":opt_sr=24000
|
891 |
+
else:opt_sr=48000#v4
|
892 |
+
if if_sr == True and opt_sr == 24000:
|
893 |
+
print(i18n("音频超分中"))
|
894 |
+
audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
|
895 |
+
max_audio = np.abs(audio_opt).max()
|
896 |
+
if max_audio > 1:
|
897 |
+
audio_opt /= max_audio
|
898 |
+
else:
|
899 |
+
audio_opt = audio_opt.cpu().detach().numpy()
|
900 |
+
yield opt_sr, (audio_opt * 32767).astype(np.int16)
|
901 |
+
|
902 |
+
|
903 |
+
def split(todo_text):
|
904 |
+
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
905 |
+
if todo_text[-1] not in splits:
|
906 |
+
todo_text += "。"
|
907 |
+
i_split_head = i_split_tail = 0
|
908 |
+
len_text = len(todo_text)
|
909 |
+
todo_texts = []
|
910 |
+
while 1:
|
911 |
+
if i_split_head >= len_text:
|
912 |
+
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
913 |
+
if todo_text[i_split_head] in splits:
|
914 |
+
i_split_head += 1
|
915 |
+
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
916 |
+
i_split_tail = i_split_head
|
917 |
+
else:
|
918 |
+
i_split_head += 1
|
919 |
+
return todo_texts
|
920 |
+
|
921 |
+
|
922 |
+
def cut1(inp):
|
923 |
+
inp = inp.strip("\n")
|
924 |
+
inps = split(inp)
|
925 |
+
split_idx = list(range(0, len(inps), 4))
|
926 |
+
split_idx[-1] = None
|
927 |
+
if len(split_idx) > 1:
|
928 |
+
opts = []
|
929 |
+
for idx in range(len(split_idx) - 1):
|
930 |
+
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
|
931 |
+
else:
|
932 |
+
opts = [inp]
|
933 |
+
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
934 |
+
return "\n".join(opts)
|
935 |
+
|
936 |
+
|
937 |
+
def cut2(inp):
|
938 |
+
inp = inp.strip("\n")
|
939 |
+
inps = split(inp)
|
940 |
+
if len(inps) < 2:
|
941 |
+
return inp
|
942 |
+
opts = []
|
943 |
+
summ = 0
|
944 |
+
tmp_str = ""
|
945 |
+
for i in range(len(inps)):
|
946 |
+
summ += len(inps[i])
|
947 |
+
tmp_str += inps[i]
|
948 |
+
if summ > 50:
|
949 |
+
summ = 0
|
950 |
+
opts.append(tmp_str)
|
951 |
+
tmp_str = ""
|
952 |
+
if tmp_str != "":
|
953 |
+
opts.append(tmp_str)
|
954 |
+
# print(opts)
|
955 |
+
if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
|
956 |
+
opts[-2] = opts[-2] + opts[-1]
|
957 |
+
opts = opts[:-1]
|
958 |
+
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
959 |
+
return "\n".join(opts)
|
960 |
+
|
961 |
+
|
962 |
+
def cut3(inp):
|
963 |
+
inp = inp.strip("\n")
|
964 |
+
opts = ["%s" % item for item in inp.strip("。").split("。")]
|
965 |
+
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
966 |
+
return "\n".join(opts)
|
967 |
+
|
968 |
+
|
969 |
+
def cut4(inp):
|
970 |
+
inp = inp.strip("\n")
|
971 |
+
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
|
972 |
+
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
973 |
+
return "\n".join(opts)
|
974 |
+
|
975 |
+
|
976 |
+
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
977 |
+
def cut5(inp):
|
978 |
+
inp = inp.strip("\n")
|
979 |
+
punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
|
980 |
+
mergeitems = []
|
981 |
+
items = []
|
982 |
+
|
983 |
+
for i, char in enumerate(inp):
|
984 |
+
if char in punds:
|
985 |
+
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
986 |
+
items.append(char)
|
987 |
+
else:
|
988 |
+
items.append(char)
|
989 |
+
mergeitems.append("".join(items))
|
990 |
+
items = []
|
991 |
+
else:
|
992 |
+
items.append(char)
|
993 |
+
|
994 |
+
if items:
|
995 |
+
mergeitems.append("".join(items))
|
996 |
+
|
997 |
+
opt = [item for item in mergeitems if not set(item).issubset(punds)]
|
998 |
+
return "\n".join(opt)
|
999 |
+
|
1000 |
+
|
1001 |
+
def custom_sort_key(s):
|
1002 |
+
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
1003 |
+
parts = re.split("(\d+)", s)
|
1004 |
+
# 将数字部分转换为整数,非数字部分保持不变
|
1005 |
+
parts = [int(part) if part.isdigit() else part for part in parts]
|
1006 |
+
return parts
|
1007 |
+
|
1008 |
+
|
1009 |
+
def process_text(texts):
|
1010 |
+
_text = []
|
1011 |
+
if all(text in [None, " ", "\n", ""] for text in texts):
|
1012 |
+
raise ValueError(i18n("请输入有效文本"))
|
1013 |
+
for text in texts:
|
1014 |
+
if text in [None, " ", ""]:
|
1015 |
+
pass
|
1016 |
+
else:
|
1017 |
+
_text.append(text)
|
1018 |
+
return _text
|
1019 |
+
|
1020 |
+
|
1021 |
+
def change_choices():
|
1022 |
+
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
1023 |
+
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
|
1024 |
+
"choices": sorted(GPT_names, key=custom_sort_key),
|
1025 |
+
"__type__": "update",
|
1026 |
+
}
|
1027 |
+
|
1028 |
+
|
1029 |
+
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
|
1030 |
+
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
|
1031 |
+
for path in SoVITS_weight_root + GPT_weight_root:
|
1032 |
+
os.makedirs(path, exist_ok=True)
|
1033 |
+
|
1034 |
+
|
1035 |
+
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
|
1036 |
+
SoVITS_names = [i for i in pretrained_sovits_name]
|
1037 |
+
for path in SoVITS_weight_root:
|
1038 |
+
for name in os.listdir(path):
|
1039 |
+
if name.endswith(".pth"):
|
1040 |
+
SoVITS_names.append("%s/%s" % (path, name))
|
1041 |
+
GPT_names = [i for i in pretrained_gpt_name]
|
1042 |
+
for path in GPT_weight_root:
|
1043 |
+
for name in os.listdir(path):
|
1044 |
+
if name.endswith(".ckpt"):
|
1045 |
+
GPT_names.append("%s/%s" % (path, name))
|
1046 |
+
return SoVITS_names, GPT_names
|
1047 |
+
|
1048 |
+
|
1049 |
+
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
1050 |
+
|
1051 |
+
|
1052 |
+
def html_center(text, label="p"):
|
1053 |
+
return f"""<div style="text-align: center; margin: 100; padding: 50;">
|
1054 |
+
<{label} style="margin: 0; padding: 0;">{text}</{label}>
|
1055 |
+
</div>"""
|
1056 |
+
|
1057 |
+
|
1058 |
+
def html_left(text, label="p"):
|
1059 |
+
return f"""<div style="text-align: left; margin: 0; padding: 0;">
|
1060 |
+
<{label} style="margin: 0; padding: 0;">{text}</{label}>
|
1061 |
+
</div>"""
|
1062 |
+
|
1063 |
+
|
1064 |
+
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
1065 |
+
gr.Markdown(
|
1066 |
+
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
|
1067 |
+
+ "<br>"
|
1068 |
+
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
1069 |
+
)
|
1070 |
+
with gr.Group():
|
1071 |
+
gr.Markdown(html_center(i18n("模型切换"), "h3"))
|
1072 |
+
with gr.Row():
|
1073 |
+
GPT_dropdown = gr.Dropdown(
|
1074 |
+
label=i18n("GPT模型列表"),
|
1075 |
+
choices=sorted(GPT_names, key=custom_sort_key),
|
1076 |
+
value=gpt_path,
|
1077 |
+
interactive=True,
|
1078 |
+
scale=14,
|
1079 |
+
)
|
1080 |
+
SoVITS_dropdown = gr.Dropdown(
|
1081 |
+
label=i18n("SoVITS模型列表"),
|
1082 |
+
choices=sorted(SoVITS_names, key=custom_sort_key),
|
1083 |
+
value=sovits_path,
|
1084 |
+
interactive=True,
|
1085 |
+
scale=14,
|
1086 |
+
)
|
1087 |
+
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14)
|
1088 |
+
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
1089 |
+
gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
|
1090 |
+
with gr.Row():
|
1091 |
+
inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath", scale=13)
|
1092 |
+
with gr.Column(scale=13):
|
1093 |
+
ref_text_free = gr.Checkbox(
|
1094 |
+
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")
|
1095 |
+
+ i18n("v3暂不支持该模式,使用了会报错。"),
|
1096 |
+
value=False,
|
1097 |
+
interactive=True if model_version not in v3v4set else False,
|
1098 |
+
show_label=True,
|
1099 |
+
scale=1,
|
1100 |
+
)
|
1101 |
+
gr.Markdown(
|
1102 |
+
html_left(
|
1103 |
+
i18n("使用无参考文本模式时建议使用微调的GPT")
|
1104 |
+
+ "<br>"
|
1105 |
+
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
|
1106 |
+
)
|
1107 |
+
)
|
1108 |
+
prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=5, max_lines=5, scale=1)
|
1109 |
+
with gr.Column(scale=14):
|
1110 |
+
prompt_language = gr.Dropdown(
|
1111 |
+
label=i18n("参考音频的语种"),
|
1112 |
+
choices=list(dict_language.keys()),
|
1113 |
+
value=i18n("中文"),
|
1114 |
+
)
|
1115 |
+
inp_refs = (
|
1116 |
+
gr.File(
|
1117 |
+
label=i18n(
|
1118 |
+
"可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
|
1119 |
+
),
|
1120 |
+
file_count="multiple",
|
1121 |
+
)
|
1122 |
+
if model_version not in v3v4set
|
1123 |
+
else gr.File(
|
1124 |
+
label=i18n(
|
1125 |
+
"可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
|
1126 |
+
),
|
1127 |
+
file_count="multiple",
|
1128 |
+
visible=False,
|
1129 |
+
)
|
1130 |
+
)
|
1131 |
+
sample_steps = (
|
1132 |
+
gr.Radio(
|
1133 |
+
label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
|
1134 |
+
value=32 if model_version=="v3"else 8,
|
1135 |
+
choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
|
1136 |
+
visible=True,
|
1137 |
+
)
|
1138 |
+
if model_version in v3v4set
|
1139 |
+
else gr.Radio(
|
1140 |
+
label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
|
1141 |
+
choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
|
1142 |
+
visible=False,
|
1143 |
+
value=32 if model_version=="v3"else 8,
|
1144 |
+
)
|
1145 |
+
)
|
1146 |
+
if_sr_Checkbox = gr.Checkbox(
|
1147 |
+
label=i18n("v3输出如果觉得闷可以试试开超分"),
|
1148 |
+
value=False,
|
1149 |
+
interactive=True,
|
1150 |
+
show_label=True,
|
1151 |
+
visible=False if model_version !="v3" else True,
|
1152 |
+
)
|
1153 |
+
gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
|
1154 |
+
with gr.Row():
|
1155 |
+
with gr.Column(scale=13):
|
1156 |
+
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
|
1157 |
+
with gr.Column(scale=7):
|
1158 |
+
text_language = gr.Dropdown(
|
1159 |
+
label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
|
1160 |
+
choices=list(dict_language.keys()),
|
1161 |
+
value=i18n("中文"),
|
1162 |
+
scale=1,
|
1163 |
+
)
|
1164 |
+
how_to_cut = gr.Dropdown(
|
1165 |
+
label=i18n("怎么切"),
|
1166 |
+
choices=[
|
1167 |
+
i18n("不切"),
|
1168 |
+
i18n("凑四句一切"),
|
1169 |
+
i18n("凑50字一切"),
|
1170 |
+
i18n("按中文句号。切"),
|
1171 |
+
i18n("按英文句号.切"),
|
1172 |
+
i18n("按标点符号切"),
|
1173 |
+
],
|
1174 |
+
value=i18n("凑四句一切"),
|
1175 |
+
interactive=True,
|
1176 |
+
scale=1,
|
1177 |
+
)
|
1178 |
+
gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
|
1179 |
+
if_freeze = gr.Checkbox(
|
1180 |
+
label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"),
|
1181 |
+
value=False,
|
1182 |
+
interactive=True,
|
1183 |
+
show_label=True,
|
1184 |
+
scale=1,
|
1185 |
+
)
|
1186 |
+
with gr.Row():
|
1187 |
+
speed = gr.Slider(
|
1188 |
+
minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True, scale=1
|
1189 |
+
)
|
1190 |
+
pause_second_slider = gr.Slider(
|
1191 |
+
minimum=0.1,
|
1192 |
+
maximum=0.5,
|
1193 |
+
step=0.01,
|
1194 |
+
label=i18n("句间停顿秒数"),
|
1195 |
+
value=0.3,
|
1196 |
+
interactive=True,
|
1197 |
+
scale=1,
|
1198 |
+
)
|
1199 |
+
gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
|
1200 |
+
top_k = gr.Slider(
|
1201 |
+
minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True, scale=1
|
1202 |
+
)
|
1203 |
+
top_p = gr.Slider(
|
1204 |
+
minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True, scale=1
|
1205 |
+
)
|
1206 |
+
temperature = gr.Slider(
|
1207 |
+
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True, scale=1
|
1208 |
+
)
|
1209 |
+
# with gr.Column():
|
1210 |
+
# gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。"))
|
1211 |
+
# phoneme=gr.Textbox(label=i18n("音素框"), value="")
|
1212 |
+
# get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
|
1213 |
+
with gr.Row():
|
1214 |
+
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
|
1215 |
+
output = gr.Audio(label=i18n("输出的语音"), scale=14)
|
1216 |
+
|
1217 |
+
inference_button.click(
|
1218 |
+
get_tts_wav,
|
1219 |
+
[
|
1220 |
+
inp_ref,
|
1221 |
+
prompt_text,
|
1222 |
+
prompt_language,
|
1223 |
+
text,
|
1224 |
+
text_language,
|
1225 |
+
how_to_cut,
|
1226 |
+
top_k,
|
1227 |
+
top_p,
|
1228 |
+
temperature,
|
1229 |
+
ref_text_free,
|
1230 |
+
speed,
|
1231 |
+
if_freeze,
|
1232 |
+
inp_refs,
|
1233 |
+
sample_steps,
|
1234 |
+
if_sr_Checkbox,
|
1235 |
+
pause_second_slider,
|
1236 |
+
],
|
1237 |
+
[output],
|
1238 |
+
)
|
1239 |
+
SoVITS_dropdown.change(
|
1240 |
+
change_sovits_weights,
|
1241 |
+
[SoVITS_dropdown, prompt_language, text_language],
|
1242 |
+
[
|
1243 |
+
prompt_language,
|
1244 |
+
text_language,
|
1245 |
+
prompt_text,
|
1246 |
+
prompt_language,
|
1247 |
+
text,
|
1248 |
+
text_language,
|
1249 |
+
sample_steps,
|
1250 |
+
inp_refs,
|
1251 |
+
ref_text_free,
|
1252 |
+
if_sr_Checkbox,
|
1253 |
+
inference_button,
|
1254 |
+
],
|
1255 |
+
)
|
1256 |
+
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
|
1257 |
+
|
1258 |
+
# gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
|
1259 |
+
# with gr.Row():
|
1260 |
+
# text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
|
1261 |
+
# button1 = gr.Button(i18n("凑四句一切"), variant="primary")
|
1262 |
+
# button2 = gr.Button(i18n("凑50字一切"), variant="primary")
|
1263 |
+
# button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
|
1264 |
+
# button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
|
1265 |
+
# button5 = gr.Button(i18n("按标点符号切"), variant="primary")
|
1266 |
+
# text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
|
1267 |
+
# button1.click(cut1, [text_inp], [text_opt])
|
1268 |
+
# button2.click(cut2, [text_inp], [text_opt])
|
1269 |
+
# button3.click(cut3, [text_inp], [text_opt])
|
1270 |
+
# button4.click(cut4, [text_inp], [text_opt])
|
1271 |
+
# button5.click(cut5, [text_inp], [text_opt])
|
1272 |
+
# gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")))
|
1273 |
+
|
1274 |
+
if __name__ == "__main__":
|
1275 |
+
app.queue().launch( # concurrency_count=511, max_size=1022
|
1276 |
+
server_name="0.0.0.0",
|
1277 |
+
inbrowser=True,
|
1278 |
+
share=is_share,
|
1279 |
+
server_port=infer_ttswebui,
|
1280 |
+
# quiet=True,
|
1281 |
+
)
|
GPT_SoVITS/inference_webui_fast.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
按中英混合识别
|
3 |
+
按日英混合识别
|
4 |
+
多语种启动切分识别语种
|
5 |
+
全部按中文识别
|
6 |
+
全部按英文识别
|
7 |
+
全部按日文识别
|
8 |
+
"""
|
9 |
+
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import re
|
15 |
+
import sys
|
16 |
+
|
17 |
+
now_dir = os.getcwd()
|
18 |
+
sys.path.append(now_dir)
|
19 |
+
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
20 |
+
|
21 |
+
logging.getLogger("markdown_it").setLevel(logging.ERROR)
|
22 |
+
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
23 |
+
logging.getLogger("httpcore").setLevel(logging.ERROR)
|
24 |
+
logging.getLogger("httpx").setLevel(logging.ERROR)
|
25 |
+
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
26 |
+
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
27 |
+
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
28 |
+
import torch
|
29 |
+
|
30 |
+
try:
|
31 |
+
import gradio.analytics as analytics
|
32 |
+
|
33 |
+
analytics.version_check = lambda: None
|
34 |
+
except:
|
35 |
+
...
|
36 |
+
|
37 |
+
|
38 |
+
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
39 |
+
infer_ttswebui = int(infer_ttswebui)
|
40 |
+
is_share = os.environ.get("is_share", "False")
|
41 |
+
is_share = eval(is_share)
|
42 |
+
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
43 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
44 |
+
|
45 |
+
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
46 |
+
gpt_path = os.environ.get("gpt_path", None)
|
47 |
+
sovits_path = os.environ.get("sovits_path", None)
|
48 |
+
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
|
49 |
+
bert_path = os.environ.get("bert_path", None)
|
50 |
+
version = model_version = os.environ.get("version", "v2")
|
51 |
+
|
52 |
+
import gradio as gr
|
53 |
+
from TTS_infer_pack.text_segmentation_method import get_method
|
54 |
+
from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
|
55 |
+
|
56 |
+
from tools.i18n.i18n import I18nAuto, scan_language_list
|
57 |
+
|
58 |
+
language = os.environ.get("language", "Auto")
|
59 |
+
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
60 |
+
i18n = I18nAuto(language=language)
|
61 |
+
|
62 |
+
|
63 |
+
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
|
64 |
+
|
65 |
+
if torch.cuda.is_available():
|
66 |
+
device = "cuda"
|
67 |
+
# elif torch.backends.mps.is_available():
|
68 |
+
# device = "mps"
|
69 |
+
else:
|
70 |
+
device = "cpu"
|
71 |
+
|
72 |
+
# is_half = False
|
73 |
+
# device = "cpu"
|
74 |
+
|
75 |
+
dict_language_v1 = {
|
76 |
+
i18n("中文"): "all_zh", # 全部按中文识别
|
77 |
+
i18n("英文"): "en", # 全部按英文识别#######不变
|
78 |
+
i18n("日文"): "all_ja", # 全部按日文识别
|
79 |
+
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
80 |
+
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
81 |
+
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
82 |
+
}
|
83 |
+
dict_language_v2 = {
|
84 |
+
i18n("中文"): "all_zh", # 全部按中文识别
|
85 |
+
i18n("英文"): "en", # 全部按英文识别#######不变
|
86 |
+
i18n("日文"): "all_ja", # 全部按日文识别
|
87 |
+
i18n("粤语"): "all_yue", # 全部按中文识别
|
88 |
+
i18n("韩文"): "all_ko", # 全部按韩文识别
|
89 |
+
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
90 |
+
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
91 |
+
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
|
92 |
+
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
|
93 |
+
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
94 |
+
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
|
95 |
+
}
|
96 |
+
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
97 |
+
|
98 |
+
cut_method = {
|
99 |
+
i18n("不切"): "cut0",
|
100 |
+
i18n("凑四句一切"): "cut1",
|
101 |
+
i18n("凑50字一切"): "cut2",
|
102 |
+
i18n("按中文句号。切"): "cut3",
|
103 |
+
i18n("按英文句号.切"): "cut4",
|
104 |
+
i18n("按标点符号切"): "cut5",
|
105 |
+
}
|
106 |
+
|
107 |
+
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
|
108 |
+
tts_config.device = device
|
109 |
+
tts_config.is_half = is_half
|
110 |
+
tts_config.version = version
|
111 |
+
if gpt_path is not None:
|
112 |
+
tts_config.t2s_weights_path = gpt_path
|
113 |
+
if sovits_path is not None:
|
114 |
+
tts_config.vits_weights_path = sovits_path
|
115 |
+
if cnhubert_base_path is not None:
|
116 |
+
tts_config.cnhuhbert_base_path = cnhubert_base_path
|
117 |
+
if bert_path is not None:
|
118 |
+
tts_config.bert_base_path = bert_path
|
119 |
+
|
120 |
+
print(tts_config)
|
121 |
+
tts_pipeline = TTS(tts_config)
|
122 |
+
gpt_path = tts_config.t2s_weights_path
|
123 |
+
sovits_path = tts_config.vits_weights_path
|
124 |
+
version = tts_config.version
|
125 |
+
|
126 |
+
|
127 |
+
def inference(
|
128 |
+
text,
|
129 |
+
text_lang,
|
130 |
+
ref_audio_path,
|
131 |
+
aux_ref_audio_paths,
|
132 |
+
prompt_text,
|
133 |
+
prompt_lang,
|
134 |
+
top_k,
|
135 |
+
top_p,
|
136 |
+
temperature,
|
137 |
+
text_split_method,
|
138 |
+
batch_size,
|
139 |
+
speed_factor,
|
140 |
+
ref_text_free,
|
141 |
+
split_bucket,
|
142 |
+
fragment_interval,
|
143 |
+
seed,
|
144 |
+
keep_random,
|
145 |
+
parallel_infer,
|
146 |
+
repetition_penalty,
|
147 |
+
sample_steps,
|
148 |
+
super_sampling,
|
149 |
+
):
|
150 |
+
seed = -1 if keep_random else seed
|
151 |
+
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
|
152 |
+
inputs = {
|
153 |
+
"text": text,
|
154 |
+
"text_lang": dict_language[text_lang],
|
155 |
+
"ref_audio_path": ref_audio_path,
|
156 |
+
"aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [],
|
157 |
+
"prompt_text": prompt_text if not ref_text_free else "",
|
158 |
+
"prompt_lang": dict_language[prompt_lang],
|
159 |
+
"top_k": top_k,
|
160 |
+
"top_p": top_p,
|
161 |
+
"temperature": temperature,
|
162 |
+
"text_split_method": cut_method[text_split_method],
|
163 |
+
"batch_size": int(batch_size),
|
164 |
+
"speed_factor": float(speed_factor),
|
165 |
+
"split_bucket": split_bucket,
|
166 |
+
"return_fragment": False,
|
167 |
+
"fragment_interval": fragment_interval,
|
168 |
+
"seed": actual_seed,
|
169 |
+
"parallel_infer": parallel_infer,
|
170 |
+
"repetition_penalty": repetition_penalty,
|
171 |
+
"sample_steps": int(sample_steps),
|
172 |
+
"super_sampling": super_sampling,
|
173 |
+
}
|
174 |
+
try:
|
175 |
+
for item in tts_pipeline.run(inputs):
|
176 |
+
yield item, actual_seed
|
177 |
+
except NO_PROMPT_ERROR:
|
178 |
+
gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
|
179 |
+
|
180 |
+
|
181 |
+
def custom_sort_key(s):
|
182 |
+
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
183 |
+
parts = re.split("(\d+)", s)
|
184 |
+
# 将数字部分转换为整数,非数字部分保持不变
|
185 |
+
parts = [int(part) if part.isdigit() else part for part in parts]
|
186 |
+
return parts
|
187 |
+
|
188 |
+
|
189 |
+
def change_choices():
|
190 |
+
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
191 |
+
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
|
192 |
+
"choices": sorted(GPT_names, key=custom_sort_key),
|
193 |
+
"__type__": "update",
|
194 |
+
}
|
195 |
+
|
196 |
+
|
197 |
+
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
198 |
+
path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
|
199 |
+
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
200 |
+
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
201 |
+
pretrained_sovits_name = [
|
202 |
+
"GPT_SoVITS/pretrained_models/s2G488k.pth",
|
203 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
|
204 |
+
"GPT_SoVITS/pretrained_models/s2Gv3.pth",
|
205 |
+
"GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
|
206 |
+
]
|
207 |
+
pretrained_gpt_name = [
|
208 |
+
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
209 |
+
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
|
210 |
+
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
211 |
+
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
212 |
+
]
|
213 |
+
|
214 |
+
|
215 |
+
_ = [[], []]
|
216 |
+
for i in range(4):
|
217 |
+
if os.path.exists(pretrained_gpt_name[i]):
|
218 |
+
_[0].append(pretrained_gpt_name[i])
|
219 |
+
if os.path.exists(pretrained_sovits_name[i]):
|
220 |
+
_[-1].append(pretrained_sovits_name[i])
|
221 |
+
pretrained_gpt_name, pretrained_sovits_name = _
|
222 |
+
|
223 |
+
if os.path.exists("./weight.json"):
|
224 |
+
pass
|
225 |
+
else:
|
226 |
+
with open("./weight.json", "w", encoding="utf-8") as file:
|
227 |
+
json.dump({"GPT": {}, "SoVITS": {}}, file)
|
228 |
+
|
229 |
+
with open("./weight.json", "r", encoding="utf-8") as file:
|
230 |
+
weight_data = file.read()
|
231 |
+
weight_data = json.loads(weight_data)
|
232 |
+
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
|
233 |
+
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
|
234 |
+
if isinstance(gpt_path, list):
|
235 |
+
gpt_path = gpt_path[0]
|
236 |
+
if isinstance(sovits_path, list):
|
237 |
+
sovits_path = sovits_path[0]
|
238 |
+
|
239 |
+
|
240 |
+
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
|
241 |
+
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
|
242 |
+
for path in SoVITS_weight_root + GPT_weight_root:
|
243 |
+
os.makedirs(path, exist_ok=True)
|
244 |
+
|
245 |
+
|
246 |
+
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
|
247 |
+
SoVITS_names = [i for i in pretrained_sovits_name]
|
248 |
+
for path in SoVITS_weight_root:
|
249 |
+
for name in os.listdir(path):
|
250 |
+
if name.endswith(".pth"):
|
251 |
+
SoVITS_names.append("%s/%s" % (path, name))
|
252 |
+
GPT_names = [i for i in pretrained_gpt_name]
|
253 |
+
for path in GPT_weight_root:
|
254 |
+
for name in os.listdir(path):
|
255 |
+
if name.endswith(".ckpt"):
|
256 |
+
GPT_names.append("%s/%s" % (path, name))
|
257 |
+
return SoVITS_names, GPT_names
|
258 |
+
|
259 |
+
|
260 |
+
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
261 |
+
|
262 |
+
|
263 |
+
from process_ckpt import get_sovits_version_from_path_fast
|
264 |
+
|
265 |
+
v3v4set={"v3","v4"}
|
266 |
+
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
267 |
+
global version, model_version, dict_language, if_lora_v3
|
268 |
+
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
269 |
+
# print(sovits_path,version, model_version, if_lora_v3)
|
270 |
+
is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4
|
271 |
+
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
272 |
+
if if_lora_v3 == True and is_exist == False:
|
273 |
+
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
|
274 |
+
gr.Warning(info)
|
275 |
+
raise FileExistsError(info)
|
276 |
+
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
277 |
+
if prompt_language is not None and text_language is not None:
|
278 |
+
if prompt_language in list(dict_language.keys()):
|
279 |
+
prompt_text_update, prompt_language_update = (
|
280 |
+
{"__type__": "update"},
|
281 |
+
{"__type__": "update", "value": prompt_language},
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
prompt_text_update = {"__type__": "update", "value": ""}
|
285 |
+
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
|
286 |
+
if text_language in list(dict_language.keys()):
|
287 |
+
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
|
288 |
+
else:
|
289 |
+
text_update = {"__type__": "update", "value": ""}
|
290 |
+
text_language_update = {"__type__": "update", "value": i18n("中文")}
|
291 |
+
if model_version in v3v4set:
|
292 |
+
visible_sample_steps = True
|
293 |
+
visible_inp_refs = False
|
294 |
+
else:
|
295 |
+
visible_sample_steps = False
|
296 |
+
visible_inp_refs = True
|
297 |
+
yield (
|
298 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
299 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
300 |
+
prompt_text_update,
|
301 |
+
prompt_language_update,
|
302 |
+
text_update,
|
303 |
+
text_language_update,
|
304 |
+
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
305 |
+
{"__type__": "update", "visible": visible_inp_refs},
|
306 |
+
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
307 |
+
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
|
308 |
+
)
|
309 |
+
|
310 |
+
tts_pipeline.init_vits_weights(sovits_path)
|
311 |
+
yield (
|
312 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
313 |
+
{"__type__": "update", "choices": list(dict_language.keys())},
|
314 |
+
prompt_text_update,
|
315 |
+
prompt_language_update,
|
316 |
+
text_update,
|
317 |
+
text_language_update,
|
318 |
+
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
319 |
+
{"__type__": "update", "visible": visible_inp_refs},
|
320 |
+
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
321 |
+
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
|
322 |
+
)
|
323 |
+
with open("./weight.json") as f:
|
324 |
+
data = f.read()
|
325 |
+
data = json.loads(data)
|
326 |
+
data["SoVITS"][version] = sovits_path
|
327 |
+
with open("./weight.json", "w") as f:
|
328 |
+
f.write(json.dumps(data))
|
329 |
+
|
330 |
+
|
331 |
+
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
332 |
+
gr.Markdown(
|
333 |
+
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
|
334 |
+
+ "<br>"
|
335 |
+
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
336 |
+
)
|
337 |
+
|
338 |
+
with gr.Column():
|
339 |
+
# with gr.Group():
|
340 |
+
gr.Markdown(value=i18n("模型切换"))
|
341 |
+
with gr.Row():
|
342 |
+
GPT_dropdown = gr.Dropdown(
|
343 |
+
label=i18n("GPT模型列表"),
|
344 |
+
choices=sorted(GPT_names, key=custom_sort_key),
|
345 |
+
value=gpt_path,
|
346 |
+
interactive=True,
|
347 |
+
)
|
348 |
+
SoVITS_dropdown = gr.Dropdown(
|
349 |
+
label=i18n("SoVITS模型列表"),
|
350 |
+
choices=sorted(SoVITS_names, key=custom_sort_key),
|
351 |
+
value=sovits_path,
|
352 |
+
interactive=True,
|
353 |
+
)
|
354 |
+
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
355 |
+
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
356 |
+
|
357 |
+
with gr.Row():
|
358 |
+
with gr.Column():
|
359 |
+
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
360 |
+
with gr.Row():
|
361 |
+
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
|
362 |
+
inp_refs = gr.File(
|
363 |
+
label=i18n("辅参考音频(可选多个,或不选)"),
|
364 |
+
file_count="multiple",
|
365 |
+
visible=True if model_version != "v3" else False,
|
366 |
+
)
|
367 |
+
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
|
368 |
+
with gr.Row():
|
369 |
+
prompt_language = gr.Dropdown(
|
370 |
+
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
371 |
+
)
|
372 |
+
with gr.Column():
|
373 |
+
ref_text_free = gr.Checkbox(
|
374 |
+
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
|
375 |
+
value=False,
|
376 |
+
interactive=True if model_version != "v3" else False,
|
377 |
+
show_label=True,
|
378 |
+
)
|
379 |
+
gr.Markdown(
|
380 |
+
i18n("使用无参考文本模式时建议使用微调的GPT")
|
381 |
+
+ "<br>"
|
382 |
+
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
|
383 |
+
)
|
384 |
+
|
385 |
+
with gr.Column():
|
386 |
+
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
|
387 |
+
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
|
388 |
+
text_language = gr.Dropdown(
|
389 |
+
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
390 |
+
)
|
391 |
+
|
392 |
+
with gr.Group():
|
393 |
+
gr.Markdown(value=i18n("推理设置"))
|
394 |
+
with gr.Row():
|
395 |
+
with gr.Column():
|
396 |
+
with gr.Row():
|
397 |
+
batch_size = gr.Slider(
|
398 |
+
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
|
399 |
+
)
|
400 |
+
sample_steps = gr.Radio(
|
401 |
+
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
|
402 |
+
)
|
403 |
+
with gr.Row():
|
404 |
+
fragment_interval = gr.Slider(
|
405 |
+
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
|
406 |
+
)
|
407 |
+
speed_factor = gr.Slider(
|
408 |
+
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
409 |
+
)
|
410 |
+
with gr.Row():
|
411 |
+
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
412 |
+
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
413 |
+
with gr.Row():
|
414 |
+
temperature = gr.Slider(
|
415 |
+
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
|
416 |
+
)
|
417 |
+
repetition_penalty = gr.Slider(
|
418 |
+
minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
|
419 |
+
)
|
420 |
+
|
421 |
+
with gr.Column():
|
422 |
+
with gr.Row():
|
423 |
+
how_to_cut = gr.Dropdown(
|
424 |
+
label=i18n("怎么切"),
|
425 |
+
choices=[
|
426 |
+
i18n("不切"),
|
427 |
+
i18n("凑四句一切"),
|
428 |
+
i18n("凑50字一切"),
|
429 |
+
i18n("按中文句号。切"),
|
430 |
+
i18n("按英文句号.切"),
|
431 |
+
i18n("按标点符号切"),
|
432 |
+
],
|
433 |
+
value=i18n("凑四句一切"),
|
434 |
+
interactive=True,
|
435 |
+
scale=1,
|
436 |
+
)
|
437 |
+
super_sampling = gr.Checkbox(
|
438 |
+
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
|
439 |
+
)
|
440 |
+
|
441 |
+
with gr.Row():
|
442 |
+
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
443 |
+
split_bucket = gr.Checkbox(
|
444 |
+
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
|
445 |
+
value=True,
|
446 |
+
interactive=True,
|
447 |
+
show_label=True,
|
448 |
+
)
|
449 |
+
|
450 |
+
with gr.Row():
|
451 |
+
seed = gr.Number(label=i18n("随机种子"), value=-1)
|
452 |
+
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
453 |
+
|
454 |
+
output = gr.Audio(label=i18n("输出的语音"))
|
455 |
+
with gr.Row():
|
456 |
+
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
457 |
+
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
|
458 |
+
|
459 |
+
inference_button.click(
|
460 |
+
inference,
|
461 |
+
[
|
462 |
+
text,
|
463 |
+
text_language,
|
464 |
+
inp_ref,
|
465 |
+
inp_refs,
|
466 |
+
prompt_text,
|
467 |
+
prompt_language,
|
468 |
+
top_k,
|
469 |
+
top_p,
|
470 |
+
temperature,
|
471 |
+
how_to_cut,
|
472 |
+
batch_size,
|
473 |
+
speed_factor,
|
474 |
+
ref_text_free,
|
475 |
+
split_bucket,
|
476 |
+
fragment_interval,
|
477 |
+
seed,
|
478 |
+
keep_random,
|
479 |
+
parallel_infer,
|
480 |
+
repetition_penalty,
|
481 |
+
sample_steps,
|
482 |
+
super_sampling,
|
483 |
+
],
|
484 |
+
[output, seed],
|
485 |
+
)
|
486 |
+
stop_infer.click(tts_pipeline.stop, [], [])
|
487 |
+
SoVITS_dropdown.change(
|
488 |
+
change_sovits_weights,
|
489 |
+
[SoVITS_dropdown, prompt_language, text_language],
|
490 |
+
[
|
491 |
+
prompt_language,
|
492 |
+
text_language,
|
493 |
+
prompt_text,
|
494 |
+
prompt_language,
|
495 |
+
text,
|
496 |
+
text_language,
|
497 |
+
sample_steps,
|
498 |
+
inp_refs,
|
499 |
+
ref_text_free,
|
500 |
+
inference_button,
|
501 |
+
],
|
502 |
+
) #
|
503 |
+
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
|
504 |
+
|
505 |
+
with gr.Group():
|
506 |
+
gr.Markdown(
|
507 |
+
value=i18n(
|
508 |
+
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
|
509 |
+
)
|
510 |
+
)
|
511 |
+
with gr.Row():
|
512 |
+
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
|
513 |
+
with gr.Column():
|
514 |
+
_how_to_cut = gr.Radio(
|
515 |
+
label=i18n("怎么切"),
|
516 |
+
choices=[
|
517 |
+
i18n("不切"),
|
518 |
+
i18n("凑四句一切"),
|
519 |
+
i18n("凑50字一切"),
|
520 |
+
i18n("按中文句号。切"),
|
521 |
+
i18n("按英文句号.切"),
|
522 |
+
i18n("按标点符号切"),
|
523 |
+
],
|
524 |
+
value=i18n("凑四句一切"),
|
525 |
+
interactive=True,
|
526 |
+
)
|
527 |
+
cut_text = gr.Button(i18n("切分"), variant="primary")
|
528 |
+
|
529 |
+
def to_cut(text_inp, how_to_cut):
|
530 |
+
if len(text_inp.strip()) == 0 or text_inp == []:
|
531 |
+
return ""
|
532 |
+
method = get_method(cut_method[how_to_cut])
|
533 |
+
return method(text_inp)
|
534 |
+
|
535 |
+
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
|
536 |
+
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
|
537 |
+
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
|
538 |
+
|
539 |
+
if __name__ == "__main__":
|
540 |
+
app.queue().launch( # concurrency_count=511, max_size=1022
|
541 |
+
server_name="0.0.0.0",
|
542 |
+
inbrowser=True,
|
543 |
+
share=is_share,
|
544 |
+
server_port=infer_ttswebui,
|
545 |
+
# quiet=True,
|
546 |
+
)
|
GPT_SoVITS/onnx_export.py
ADDED
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
|
4 |
+
from feature_extractor import cnhubert
|
5 |
+
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
9 |
+
cnhubert.cnhubert_base_path = cnhubert_base_path
|
10 |
+
ssl_model = cnhubert.get_model()
|
11 |
+
import json
|
12 |
+
import os
|
13 |
+
|
14 |
+
import soundfile
|
15 |
+
from text import cleaned_text_to_sequence
|
16 |
+
|
17 |
+
|
18 |
+
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
19 |
+
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
20 |
+
y = torch.nn.functional.pad(
|
21 |
+
y.unsqueeze(1),
|
22 |
+
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
23 |
+
mode="reflect",
|
24 |
+
)
|
25 |
+
y = y.squeeze(1)
|
26 |
+
spec = torch.stft(
|
27 |
+
y,
|
28 |
+
n_fft,
|
29 |
+
hop_length=hop_size,
|
30 |
+
win_length=win_size,
|
31 |
+
window=hann_window,
|
32 |
+
center=center,
|
33 |
+
pad_mode="reflect",
|
34 |
+
normalized=False,
|
35 |
+
onesided=True,
|
36 |
+
return_complex=False,
|
37 |
+
)
|
38 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
39 |
+
return spec
|
40 |
+
|
41 |
+
|
42 |
+
class DictToAttrRecursive(dict):
|
43 |
+
def __init__(self, input_dict):
|
44 |
+
super().__init__(input_dict)
|
45 |
+
for key, value in input_dict.items():
|
46 |
+
if isinstance(value, dict):
|
47 |
+
value = DictToAttrRecursive(value)
|
48 |
+
self[key] = value
|
49 |
+
setattr(self, key, value)
|
50 |
+
|
51 |
+
def __getattr__(self, item):
|
52 |
+
try:
|
53 |
+
return self[item]
|
54 |
+
except KeyError:
|
55 |
+
raise AttributeError(f"Attribute {item} not found")
|
56 |
+
|
57 |
+
def __setattr__(self, key, value):
|
58 |
+
if isinstance(value, dict):
|
59 |
+
value = DictToAttrRecursive(value)
|
60 |
+
super(DictToAttrRecursive, self).__setitem__(key, value)
|
61 |
+
super().__setattr__(key, value)
|
62 |
+
|
63 |
+
def __delattr__(self, item):
|
64 |
+
try:
|
65 |
+
del self[item]
|
66 |
+
except KeyError:
|
67 |
+
raise AttributeError(f"Attribute {item} not found")
|
68 |
+
|
69 |
+
|
70 |
+
class T2SEncoder(nn.Module):
|
71 |
+
def __init__(self, t2s, vits):
|
72 |
+
super().__init__()
|
73 |
+
self.encoder = t2s.onnx_encoder
|
74 |
+
self.vits = vits
|
75 |
+
|
76 |
+
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
77 |
+
codes = self.vits.extract_latent(ssl_content)
|
78 |
+
prompt_semantic = codes[0, 0]
|
79 |
+
bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
|
80 |
+
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
81 |
+
bert = bert.unsqueeze(0)
|
82 |
+
prompt = prompt_semantic.unsqueeze(0)
|
83 |
+
return self.encoder(all_phoneme_ids, bert), prompt
|
84 |
+
|
85 |
+
|
86 |
+
class T2SModel(nn.Module):
|
87 |
+
def __init__(self, t2s_path, vits_model):
|
88 |
+
super().__init__()
|
89 |
+
dict_s1 = torch.load(t2s_path, map_location="cpu")
|
90 |
+
self.config = dict_s1["config"]
|
91 |
+
self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
|
92 |
+
self.t2s_model.load_state_dict(dict_s1["weight"])
|
93 |
+
self.t2s_model.eval()
|
94 |
+
self.vits_model = vits_model.vq_model
|
95 |
+
self.hz = 50
|
96 |
+
self.max_sec = self.config["data"]["max_sec"]
|
97 |
+
self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
|
98 |
+
self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
99 |
+
self.t2s_model = self.t2s_model.model
|
100 |
+
self.t2s_model.init_onnx()
|
101 |
+
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
|
102 |
+
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
103 |
+
self.stage_decoder = self.t2s_model.stage_decoder
|
104 |
+
# self.t2s_model = torch.jit.script(self.t2s_model)
|
105 |
+
|
106 |
+
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
107 |
+
early_stop_num = self.t2s_model.early_stop_num
|
108 |
+
|
109 |
+
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
110 |
+
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
111 |
+
|
112 |
+
prefix_len = prompts.shape[1]
|
113 |
+
|
114 |
+
# [1,N,512] [1,N]
|
115 |
+
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
116 |
+
|
117 |
+
stop = False
|
118 |
+
for idx in range(1, 1500):
|
119 |
+
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
120 |
+
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
121 |
+
y, k, v, y_emb, logits, samples = enco
|
122 |
+
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
123 |
+
stop = True
|
124 |
+
if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
|
125 |
+
stop = True
|
126 |
+
if stop:
|
127 |
+
break
|
128 |
+
y[0, -1] = 0
|
129 |
+
|
130 |
+
return y[:, -idx:].unsqueeze(0)
|
131 |
+
|
132 |
+
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
|
133 |
+
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
|
134 |
+
if dynamo:
|
135 |
+
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
136 |
+
onnx_encoder_export_output = torch.onnx.dynamo_export(
|
137 |
+
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
|
138 |
+
)
|
139 |
+
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
|
140 |
+
return
|
141 |
+
|
142 |
+
torch.onnx.export(
|
143 |
+
self.onnx_encoder,
|
144 |
+
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
145 |
+
f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
|
146 |
+
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
|
147 |
+
output_names=["x", "prompts"],
|
148 |
+
dynamic_axes={
|
149 |
+
"ref_seq": {1: "ref_length"},
|
150 |
+
"text_seq": {1: "text_length"},
|
151 |
+
"ref_bert": {0: "ref_length"},
|
152 |
+
"text_bert": {0: "text_length"},
|
153 |
+
"ssl_content": {2: "ssl_length"},
|
154 |
+
},
|
155 |
+
opset_version=16,
|
156 |
+
)
|
157 |
+
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
158 |
+
|
159 |
+
torch.onnx.export(
|
160 |
+
self.first_stage_decoder,
|
161 |
+
(x, prompts),
|
162 |
+
f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
|
163 |
+
input_names=["x", "prompts"],
|
164 |
+
output_names=["y", "k", "v", "y_emb", "x_example"],
|
165 |
+
dynamic_axes={
|
166 |
+
"x": {1: "x_length"},
|
167 |
+
"prompts": {1: "prompts_length"},
|
168 |
+
},
|
169 |
+
verbose=False,
|
170 |
+
opset_version=16,
|
171 |
+
)
|
172 |
+
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
173 |
+
|
174 |
+
torch.onnx.export(
|
175 |
+
self.stage_decoder,
|
176 |
+
(y, k, v, y_emb, x_example),
|
177 |
+
f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
|
178 |
+
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
179 |
+
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
180 |
+
dynamic_axes={
|
181 |
+
"iy": {1: "iy_length"},
|
182 |
+
"ik": {1: "ik_length"},
|
183 |
+
"iv": {1: "iv_length"},
|
184 |
+
"iy_emb": {1: "iy_emb_length"},
|
185 |
+
"ix_example": {1: "ix_example_length"},
|
186 |
+
},
|
187 |
+
verbose=False,
|
188 |
+
opset_version=16,
|
189 |
+
)
|
190 |
+
|
191 |
+
|
192 |
+
class VitsModel(nn.Module):
|
193 |
+
def __init__(self, vits_path):
|
194 |
+
super().__init__()
|
195 |
+
dict_s2 = torch.load(vits_path, map_location="cpu")
|
196 |
+
self.hps = dict_s2["config"]
|
197 |
+
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
198 |
+
self.hps["model"]["version"] = "v1"
|
199 |
+
else:
|
200 |
+
self.hps["model"]["version"] = "v2"
|
201 |
+
|
202 |
+
self.hps = DictToAttrRecursive(self.hps)
|
203 |
+
self.hps.model.semantic_frame_rate = "25hz"
|
204 |
+
self.vq_model = SynthesizerTrn(
|
205 |
+
self.hps.data.filter_length // 2 + 1,
|
206 |
+
self.hps.train.segment_size // self.hps.data.hop_length,
|
207 |
+
n_speakers=self.hps.data.n_speakers,
|
208 |
+
**self.hps.model,
|
209 |
+
)
|
210 |
+
self.vq_model.eval()
|
211 |
+
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
212 |
+
|
213 |
+
def forward(self, text_seq, pred_semantic, ref_audio):
|
214 |
+
refer = spectrogram_torch(
|
215 |
+
ref_audio,
|
216 |
+
self.hps.data.filter_length,
|
217 |
+
self.hps.data.sampling_rate,
|
218 |
+
self.hps.data.hop_length,
|
219 |
+
self.hps.data.win_length,
|
220 |
+
center=False,
|
221 |
+
)
|
222 |
+
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
|
223 |
+
|
224 |
+
|
225 |
+
class GptSoVits(nn.Module):
|
226 |
+
def __init__(self, vits, t2s):
|
227 |
+
super().__init__()
|
228 |
+
self.vits = vits
|
229 |
+
self.t2s = t2s
|
230 |
+
|
231 |
+
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
|
232 |
+
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
233 |
+
audio = self.vits(text_seq, pred_semantic, ref_audio)
|
234 |
+
if debug:
|
235 |
+
import onnxruntime
|
236 |
+
|
237 |
+
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
|
238 |
+
audio1 = sess.run(
|
239 |
+
None,
|
240 |
+
{
|
241 |
+
"text_seq": text_seq.detach().cpu().numpy(),
|
242 |
+
"pred_semantic": pred_semantic.detach().cpu().numpy(),
|
243 |
+
"ref_audio": ref_audio.detach().cpu().numpy(),
|
244 |
+
},
|
245 |
+
)
|
246 |
+
return audio, audio1
|
247 |
+
return audio
|
248 |
+
|
249 |
+
def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
|
250 |
+
self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
|
251 |
+
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
252 |
+
torch.onnx.export(
|
253 |
+
self.vits,
|
254 |
+
(text_seq, pred_semantic, ref_audio),
|
255 |
+
f"onnx/{project_name}/{project_name}_vits.onnx",
|
256 |
+
input_names=["text_seq", "pred_semantic", "ref_audio"],
|
257 |
+
output_names=["audio"],
|
258 |
+
dynamic_axes={
|
259 |
+
"text_seq": {1: "text_length"},
|
260 |
+
"pred_semantic": {2: "pred_length"},
|
261 |
+
"ref_audio": {1: "audio_length"},
|
262 |
+
},
|
263 |
+
opset_version=17,
|
264 |
+
verbose=False,
|
265 |
+
)
|
266 |
+
|
267 |
+
|
268 |
+
class SSLModel(nn.Module):
|
269 |
+
def __init__(self):
|
270 |
+
super().__init__()
|
271 |
+
self.ssl = ssl_model
|
272 |
+
|
273 |
+
def forward(self, ref_audio_16k):
|
274 |
+
return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
275 |
+
|
276 |
+
|
277 |
+
def export(vits_path, gpt_path, project_name, vits_model="v2"):
|
278 |
+
vits = VitsModel(vits_path)
|
279 |
+
gpt = T2SModel(gpt_path, vits)
|
280 |
+
gpt_sovits = GptSoVits(vits, gpt)
|
281 |
+
ssl = SSLModel()
|
282 |
+
ref_seq = torch.LongTensor(
|
283 |
+
[
|
284 |
+
cleaned_text_to_sequence(
|
285 |
+
[
|
286 |
+
"n",
|
287 |
+
"i2",
|
288 |
+
"h",
|
289 |
+
"ao3",
|
290 |
+
",",
|
291 |
+
"w",
|
292 |
+
"o3",
|
293 |
+
"sh",
|
294 |
+
"i4",
|
295 |
+
"b",
|
296 |
+
"ai2",
|
297 |
+
"y",
|
298 |
+
"e4",
|
299 |
+
],
|
300 |
+
version=vits_model,
|
301 |
+
)
|
302 |
+
]
|
303 |
+
)
|
304 |
+
text_seq = torch.LongTensor(
|
305 |
+
[
|
306 |
+
cleaned_text_to_sequence(
|
307 |
+
[
|
308 |
+
"w",
|
309 |
+
"o3",
|
310 |
+
"sh",
|
311 |
+
"i4",
|
312 |
+
"b",
|
313 |
+
"ai2",
|
314 |
+
"y",
|
315 |
+
"e4",
|
316 |
+
"w",
|
317 |
+
"o3",
|
318 |
+
"sh",
|
319 |
+
"i4",
|
320 |
+
"b",
|
321 |
+
"ai2",
|
322 |
+
"y",
|
323 |
+
"e4",
|
324 |
+
"w",
|
325 |
+
"o3",
|
326 |
+
"sh",
|
327 |
+
"i4",
|
328 |
+
"b",
|
329 |
+
"ai2",
|
330 |
+
"y",
|
331 |
+
"e4",
|
332 |
+
],
|
333 |
+
version=vits_model,
|
334 |
+
)
|
335 |
+
]
|
336 |
+
)
|
337 |
+
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
338 |
+
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
339 |
+
ref_audio = torch.randn((1, 48000 * 5)).float()
|
340 |
+
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
|
341 |
+
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
|
342 |
+
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
|
343 |
+
|
344 |
+
try:
|
345 |
+
os.mkdir(f"onnx/{project_name}")
|
346 |
+
except:
|
347 |
+
pass
|
348 |
+
|
349 |
+
ssl_content = ssl(ref_audio_16k).float()
|
350 |
+
|
351 |
+
# debug = False
|
352 |
+
debug = True
|
353 |
+
|
354 |
+
# gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
|
355 |
+
|
356 |
+
if debug:
|
357 |
+
a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
|
358 |
+
soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
|
359 |
+
soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
|
360 |
+
else:
|
361 |
+
a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
|
362 |
+
soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
363 |
+
|
364 |
+
if vits_model == "v1":
|
365 |
+
symbols = symbols_v1
|
366 |
+
else:
|
367 |
+
symbols = symbols_v2
|
368 |
+
|
369 |
+
MoeVSConf = {
|
370 |
+
"Folder": f"{project_name}",
|
371 |
+
"Name": f"{project_name}",
|
372 |
+
"Type": "GPT-SoVits",
|
373 |
+
"Rate": vits.hps.data.sampling_rate,
|
374 |
+
"NumLayers": gpt.t2s_model.num_layers,
|
375 |
+
"EmbeddingDim": gpt.t2s_model.embedding_dim,
|
376 |
+
"Dict": "BasicDict",
|
377 |
+
"BertPath": "chinese-roberta-wwm-ext-large",
|
378 |
+
# "Symbol": symbols,
|
379 |
+
"AddBlank": False,
|
380 |
+
}
|
381 |
+
|
382 |
+
MoeVSConfJson = json.dumps(MoeVSConf)
|
383 |
+
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
|
384 |
+
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
|
385 |
+
|
386 |
+
|
387 |
+
if __name__ == "__main__":
|
388 |
+
try:
|
389 |
+
os.mkdir("onnx")
|
390 |
+
except:
|
391 |
+
pass
|
392 |
+
|
393 |
+
gpt_path = "GPT_weights/nahida-e25.ckpt"
|
394 |
+
vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
|
395 |
+
exp_path = "nahida"
|
396 |
+
export(vits_path, gpt_path, exp_path)
|
397 |
+
|
398 |
+
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
GPT_SoVITS/prepare_datasets/1-get-text.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
inp_text = os.environ.get("inp_text")
|
6 |
+
inp_wav_dir = os.environ.get("inp_wav_dir")
|
7 |
+
exp_name = os.environ.get("exp_name")
|
8 |
+
i_part = os.environ.get("i_part")
|
9 |
+
all_parts = os.environ.get("all_parts")
|
10 |
+
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
11 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
12 |
+
opt_dir = os.environ.get("opt_dir")
|
13 |
+
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
|
14 |
+
import torch
|
15 |
+
|
16 |
+
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
17 |
+
version = os.environ.get("version", None)
|
18 |
+
import traceback
|
19 |
+
import os.path
|
20 |
+
from text.cleaner import clean_text
|
21 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
22 |
+
from tools.my_utils import clean_path
|
23 |
+
|
24 |
+
# inp_text=sys.argv[1]
|
25 |
+
# inp_wav_dir=sys.argv[2]
|
26 |
+
# exp_name=sys.argv[3]
|
27 |
+
# i_part=sys.argv[4]
|
28 |
+
# all_parts=sys.argv[5]
|
29 |
+
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
|
30 |
+
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
|
31 |
+
# bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
|
32 |
+
|
33 |
+
from time import time as ttime
|
34 |
+
import shutil
|
35 |
+
|
36 |
+
|
37 |
+
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
38 |
+
dir = os.path.dirname(path)
|
39 |
+
name = os.path.basename(path)
|
40 |
+
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
41 |
+
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
42 |
+
torch.save(fea, tmp_path)
|
43 |
+
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
44 |
+
|
45 |
+
|
46 |
+
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
|
47 |
+
if os.path.exists(txt_path) == False:
|
48 |
+
bert_dir = "%s/3-bert" % (opt_dir)
|
49 |
+
os.makedirs(opt_dir, exist_ok=True)
|
50 |
+
os.makedirs(bert_dir, exist_ok=True)
|
51 |
+
if torch.cuda.is_available():
|
52 |
+
device = "cuda:0"
|
53 |
+
# elif torch.backends.mps.is_available():
|
54 |
+
# device = "mps"
|
55 |
+
else:
|
56 |
+
device = "cpu"
|
57 |
+
if os.path.exists(bert_pretrained_dir):
|
58 |
+
...
|
59 |
+
else:
|
60 |
+
raise FileNotFoundError(bert_pretrained_dir)
|
61 |
+
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
|
62 |
+
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
|
63 |
+
if is_half == True:
|
64 |
+
bert_model = bert_model.half().to(device)
|
65 |
+
else:
|
66 |
+
bert_model = bert_model.to(device)
|
67 |
+
|
68 |
+
def get_bert_feature(text, word2ph):
|
69 |
+
with torch.no_grad():
|
70 |
+
inputs = tokenizer(text, return_tensors="pt")
|
71 |
+
for i in inputs:
|
72 |
+
inputs[i] = inputs[i].to(device)
|
73 |
+
res = bert_model(**inputs, output_hidden_states=True)
|
74 |
+
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
75 |
+
|
76 |
+
assert len(word2ph) == len(text)
|
77 |
+
phone_level_feature = []
|
78 |
+
for i in range(len(word2ph)):
|
79 |
+
repeat_feature = res[i].repeat(word2ph[i], 1)
|
80 |
+
phone_level_feature.append(repeat_feature)
|
81 |
+
|
82 |
+
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
83 |
+
|
84 |
+
return phone_level_feature.T
|
85 |
+
|
86 |
+
def process(data, res):
|
87 |
+
for name, text, lan in data:
|
88 |
+
try:
|
89 |
+
name = clean_path(name)
|
90 |
+
name = os.path.basename(name)
|
91 |
+
print(name)
|
92 |
+
phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
|
93 |
+
path_bert = "%s/%s.pt" % (bert_dir, name)
|
94 |
+
if os.path.exists(path_bert) == False and lan == "zh":
|
95 |
+
bert_feature = get_bert_feature(norm_text, word2ph)
|
96 |
+
assert bert_feature.shape[-1] == len(phones)
|
97 |
+
# torch.save(bert_feature, path_bert)
|
98 |
+
my_save(bert_feature, path_bert)
|
99 |
+
phones = " ".join(phones)
|
100 |
+
# res.append([name,phones])
|
101 |
+
res.append([name, phones, word2ph, norm_text])
|
102 |
+
except:
|
103 |
+
print(name, text, traceback.format_exc())
|
104 |
+
|
105 |
+
todo = []
|
106 |
+
res = []
|
107 |
+
with open(inp_text, "r", encoding="utf8") as f:
|
108 |
+
lines = f.read().strip("\n").split("\n")
|
109 |
+
|
110 |
+
language_v1_to_language_v2 = {
|
111 |
+
"ZH": "zh",
|
112 |
+
"zh": "zh",
|
113 |
+
"JP": "ja",
|
114 |
+
"jp": "ja",
|
115 |
+
"JA": "ja",
|
116 |
+
"ja": "ja",
|
117 |
+
"EN": "en",
|
118 |
+
"en": "en",
|
119 |
+
"En": "en",
|
120 |
+
"KO": "ko",
|
121 |
+
"Ko": "ko",
|
122 |
+
"ko": "ko",
|
123 |
+
"yue": "yue",
|
124 |
+
"YUE": "yue",
|
125 |
+
"Yue": "yue",
|
126 |
+
}
|
127 |
+
for line in lines[int(i_part) :: int(all_parts)]:
|
128 |
+
try:
|
129 |
+
wav_name, spk_name, language, text = line.split("|")
|
130 |
+
# todo.append([name,text,"zh"])
|
131 |
+
if language in language_v1_to_language_v2.keys():
|
132 |
+
todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
|
133 |
+
else:
|
134 |
+
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
|
135 |
+
except:
|
136 |
+
print(line, traceback.format_exc())
|
137 |
+
|
138 |
+
process(todo, res)
|
139 |
+
opt = []
|
140 |
+
for name, phones, word2ph, norm_text in res:
|
141 |
+
opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
|
142 |
+
with open(txt_path, "w", encoding="utf8") as f:
|
143 |
+
f.write("\n".join(opt) + "\n")
|
GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
|
6 |
+
inp_text = os.environ.get("inp_text")
|
7 |
+
inp_wav_dir = os.environ.get("inp_wav_dir")
|
8 |
+
exp_name = os.environ.get("exp_name")
|
9 |
+
i_part = os.environ.get("i_part")
|
10 |
+
all_parts = os.environ.get("all_parts")
|
11 |
+
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
12 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
13 |
+
from feature_extractor import cnhubert
|
14 |
+
|
15 |
+
opt_dir = os.environ.get("opt_dir")
|
16 |
+
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
|
17 |
+
import torch
|
18 |
+
|
19 |
+
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
20 |
+
|
21 |
+
import traceback
|
22 |
+
import numpy as np
|
23 |
+
from scipy.io import wavfile
|
24 |
+
import librosa
|
25 |
+
|
26 |
+
now_dir = os.getcwd()
|
27 |
+
sys.path.append(now_dir)
|
28 |
+
from tools.my_utils import load_audio, clean_path
|
29 |
+
|
30 |
+
# from config import cnhubert_base_path
|
31 |
+
# cnhubert.cnhubert_base_path=cnhubert_base_path
|
32 |
+
# inp_text=sys.argv[1]
|
33 |
+
# inp_wav_dir=sys.argv[2]
|
34 |
+
# exp_name=sys.argv[3]
|
35 |
+
# i_part=sys.argv[4]
|
36 |
+
# all_parts=sys.argv[5]
|
37 |
+
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
|
38 |
+
# cnhubert.cnhubert_base_path=sys.argv[7]
|
39 |
+
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
|
40 |
+
|
41 |
+
from time import time as ttime
|
42 |
+
import shutil
|
43 |
+
|
44 |
+
|
45 |
+
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
46 |
+
dir = os.path.dirname(path)
|
47 |
+
name = os.path.basename(path)
|
48 |
+
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
49 |
+
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
50 |
+
torch.save(fea, tmp_path)
|
51 |
+
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
52 |
+
|
53 |
+
|
54 |
+
hubert_dir = "%s/4-cnhubert" % (opt_dir)
|
55 |
+
wav32dir = "%s/5-wav32k" % (opt_dir)
|
56 |
+
os.makedirs(opt_dir, exist_ok=True)
|
57 |
+
os.makedirs(hubert_dir, exist_ok=True)
|
58 |
+
os.makedirs(wav32dir, exist_ok=True)
|
59 |
+
|
60 |
+
maxx = 0.95
|
61 |
+
alpha = 0.5
|
62 |
+
if torch.cuda.is_available():
|
63 |
+
device = "cuda:0"
|
64 |
+
# elif torch.backends.mps.is_available():
|
65 |
+
# device = "mps"
|
66 |
+
else:
|
67 |
+
device = "cpu"
|
68 |
+
model = cnhubert.get_model()
|
69 |
+
# is_half=False
|
70 |
+
if is_half == True:
|
71 |
+
model = model.half().to(device)
|
72 |
+
else:
|
73 |
+
model = model.to(device)
|
74 |
+
|
75 |
+
nan_fails = []
|
76 |
+
|
77 |
+
|
78 |
+
def name2go(wav_name, wav_path):
|
79 |
+
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
|
80 |
+
if os.path.exists(hubert_path):
|
81 |
+
return
|
82 |
+
tmp_audio = load_audio(wav_path, 32000)
|
83 |
+
tmp_max = np.abs(tmp_audio).max()
|
84 |
+
if tmp_max > 2.2:
|
85 |
+
print("%s-filtered,%s" % (wav_name, tmp_max))
|
86 |
+
return
|
87 |
+
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
|
88 |
+
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
|
89 |
+
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
|
90 |
+
tensor_wav16 = torch.from_numpy(tmp_audio)
|
91 |
+
if is_half == True:
|
92 |
+
tensor_wav16 = tensor_wav16.half().to(device)
|
93 |
+
else:
|
94 |
+
tensor_wav16 = tensor_wav16.to(device)
|
95 |
+
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
|
96 |
+
if np.isnan(ssl.detach().numpy()).sum() != 0:
|
97 |
+
nan_fails.append((wav_name, wav_path))
|
98 |
+
print("nan filtered:%s" % wav_name)
|
99 |
+
return
|
100 |
+
wavfile.write(
|
101 |
+
"%s/%s" % (wav32dir, wav_name),
|
102 |
+
32000,
|
103 |
+
tmp_audio32.astype("int16"),
|
104 |
+
)
|
105 |
+
my_save(ssl, hubert_path)
|
106 |
+
|
107 |
+
|
108 |
+
with open(inp_text, "r", encoding="utf8") as f:
|
109 |
+
lines = f.read().strip("\n").split("\n")
|
110 |
+
|
111 |
+
for line in lines[int(i_part) :: int(all_parts)]:
|
112 |
+
try:
|
113 |
+
# wav_name,text=line.split("\t")
|
114 |
+
wav_name, spk_name, language, text = line.split("|")
|
115 |
+
wav_name = clean_path(wav_name)
|
116 |
+
if inp_wav_dir != "" and inp_wav_dir != None:
|
117 |
+
wav_name = os.path.basename(wav_name)
|
118 |
+
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
|
119 |
+
|
120 |
+
else:
|
121 |
+
wav_path = wav_name
|
122 |
+
wav_name = os.path.basename(wav_name)
|
123 |
+
name2go(wav_name, wav_path)
|
124 |
+
except:
|
125 |
+
print(line, traceback.format_exc())
|
126 |
+
|
127 |
+
if len(nan_fails) > 0 and is_half == True:
|
128 |
+
is_half = False
|
129 |
+
model = model.float()
|
130 |
+
for wav in nan_fails:
|
131 |
+
try:
|
132 |
+
name2go(wav[0], wav[1])
|
133 |
+
except:
|
134 |
+
print(wav_name, traceback.format_exc())
|
GPT_SoVITS/prepare_datasets/3-get-semantic.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
inp_text = os.environ.get("inp_text")
|
4 |
+
exp_name = os.environ.get("exp_name")
|
5 |
+
i_part = os.environ.get("i_part")
|
6 |
+
all_parts = os.environ.get("all_parts")
|
7 |
+
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
8 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
9 |
+
opt_dir = os.environ.get("opt_dir")
|
10 |
+
pretrained_s2G = os.environ.get("pretrained_s2G")
|
11 |
+
s2config_path = os.environ.get("s2config_path")
|
12 |
+
|
13 |
+
if os.path.exists(pretrained_s2G):
|
14 |
+
...
|
15 |
+
else:
|
16 |
+
raise FileNotFoundError(pretrained_s2G)
|
17 |
+
# version=os.environ.get("version","v2")
|
18 |
+
size = os.path.getsize(pretrained_s2G)
|
19 |
+
if size < 82978 * 1024:
|
20 |
+
version = "v1"
|
21 |
+
elif size < 100 * 1024 * 1024:
|
22 |
+
version = "v2"
|
23 |
+
elif size < 103520 * 1024:
|
24 |
+
version = "v1"
|
25 |
+
elif size < 700 * 1024 * 1024:
|
26 |
+
version = "v2"
|
27 |
+
else:
|
28 |
+
version = "v3"
|
29 |
+
import torch
|
30 |
+
|
31 |
+
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
32 |
+
import traceback
|
33 |
+
import sys
|
34 |
+
|
35 |
+
now_dir = os.getcwd()
|
36 |
+
sys.path.append(now_dir)
|
37 |
+
import logging
|
38 |
+
import utils
|
39 |
+
|
40 |
+
if version != "v3":
|
41 |
+
from module.models import SynthesizerTrn
|
42 |
+
else:
|
43 |
+
from module.models import SynthesizerTrnV3 as SynthesizerTrn
|
44 |
+
from tools.my_utils import clean_path
|
45 |
+
|
46 |
+
logging.getLogger("numba").setLevel(logging.WARNING)
|
47 |
+
# from config import pretrained_s2G
|
48 |
+
|
49 |
+
# inp_text=sys.argv[1]
|
50 |
+
# exp_name=sys.argv[2]
|
51 |
+
# i_part=sys.argv[3]
|
52 |
+
# all_parts=sys.argv[4]
|
53 |
+
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
|
54 |
+
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
|
55 |
+
|
56 |
+
|
57 |
+
hubert_dir = "%s/4-cnhubert" % (opt_dir)
|
58 |
+
semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
|
59 |
+
if os.path.exists(semantic_path) == False:
|
60 |
+
os.makedirs(opt_dir, exist_ok=True)
|
61 |
+
|
62 |
+
if torch.cuda.is_available():
|
63 |
+
device = "cuda"
|
64 |
+
# elif torch.backends.mps.is_available():
|
65 |
+
# device = "mps"
|
66 |
+
else:
|
67 |
+
device = "cpu"
|
68 |
+
hps = utils.get_hparams_from_file(s2config_path)
|
69 |
+
vq_model = SynthesizerTrn(
|
70 |
+
hps.data.filter_length // 2 + 1,
|
71 |
+
hps.train.segment_size // hps.data.hop_length,
|
72 |
+
n_speakers=hps.data.n_speakers,
|
73 |
+
version=version,
|
74 |
+
**hps.model,
|
75 |
+
)
|
76 |
+
if is_half == True:
|
77 |
+
vq_model = vq_model.half().to(device)
|
78 |
+
else:
|
79 |
+
vq_model = vq_model.to(device)
|
80 |
+
vq_model.eval()
|
81 |
+
# utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
|
82 |
+
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
|
83 |
+
print(
|
84 |
+
vq_model.load_state_dict(
|
85 |
+
torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
|
86 |
+
)
|
87 |
+
)
|
88 |
+
|
89 |
+
def name2go(wav_name, lines):
|
90 |
+
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
|
91 |
+
if os.path.exists(hubert_path) == False:
|
92 |
+
return
|
93 |
+
ssl_content = torch.load(hubert_path, map_location="cpu")
|
94 |
+
if is_half == True:
|
95 |
+
ssl_content = ssl_content.half().to(device)
|
96 |
+
else:
|
97 |
+
ssl_content = ssl_content.to(device)
|
98 |
+
codes = vq_model.extract_latent(ssl_content)
|
99 |
+
semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
|
100 |
+
lines.append("%s\t%s" % (wav_name, semantic))
|
101 |
+
|
102 |
+
with open(inp_text, "r", encoding="utf8") as f:
|
103 |
+
lines = f.read().strip("\n").split("\n")
|
104 |
+
|
105 |
+
lines1 = []
|
106 |
+
for line in lines[int(i_part) :: int(all_parts)]:
|
107 |
+
# print(line)
|
108 |
+
try:
|
109 |
+
# wav_name,text=line.split("\t")
|
110 |
+
wav_name, spk_name, language, text = line.split("|")
|
111 |
+
wav_name = clean_path(wav_name)
|
112 |
+
wav_name = os.path.basename(wav_name)
|
113 |
+
# name2go(name,lines1)
|
114 |
+
name2go(wav_name, lines1)
|
115 |
+
except:
|
116 |
+
print(line, traceback.format_exc())
|
117 |
+
with open(semantic_path, "w", encoding="utf8") as f:
|
118 |
+
f.write("\n".join(lines1))
|
GPT_SoVITS/pretrained_models/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1
|
3 |
+
size 155093966
|
GPT_SoVITS/pretrained_models/s1v3.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:87133414860ea14ff6620c483a3db5ed07b44be42e2c3fcdad65523a729a745a
|
3 |
+
size 155284856
|
GPT_SoVITS/process_ckpt.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from collections import OrderedDict
|
3 |
+
from time import time as ttime
|
4 |
+
import shutil
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from tools.i18n.i18n import I18nAuto
|
8 |
+
|
9 |
+
i18n = I18nAuto()
|
10 |
+
|
11 |
+
|
12 |
+
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
13 |
+
dir = os.path.dirname(path)
|
14 |
+
name = os.path.basename(path)
|
15 |
+
tmp_path = "%s.pth" % (ttime())
|
16 |
+
torch.save(fea, tmp_path)
|
17 |
+
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
18 |
+
|
19 |
+
|
20 |
+
"""
|
21 |
+
00:v1
|
22 |
+
01:v2
|
23 |
+
02:v3
|
24 |
+
03:v3lora
|
25 |
+
04:v4lora
|
26 |
+
|
27 |
+
"""
|
28 |
+
from io import BytesIO
|
29 |
+
|
30 |
+
|
31 |
+
def my_save2(fea, path,cfm_version):
|
32 |
+
bio = BytesIO()
|
33 |
+
torch.save(fea, bio)
|
34 |
+
bio.seek(0)
|
35 |
+
data = bio.getvalue()
|
36 |
+
byte=b"03" if cfm_version=="v3"else b"04"
|
37 |
+
data = byte + data[2:]
|
38 |
+
with open(path, "wb") as f:
|
39 |
+
f.write(data)
|
40 |
+
|
41 |
+
|
42 |
+
def savee(ckpt, name, epoch, steps, hps, cfm_version=None,lora_rank=None):
|
43 |
+
try:
|
44 |
+
opt = OrderedDict()
|
45 |
+
opt["weight"] = {}
|
46 |
+
for key in ckpt.keys():
|
47 |
+
if "enc_q" in key:
|
48 |
+
continue
|
49 |
+
opt["weight"][key] = ckpt[key].half()
|
50 |
+
opt["config"] = hps
|
51 |
+
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
|
52 |
+
if lora_rank:
|
53 |
+
opt["lora_rank"] = lora_rank
|
54 |
+
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name),cfm_version)
|
55 |
+
else:
|
56 |
+
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
57 |
+
return "Success."
|
58 |
+
except:
|
59 |
+
return traceback.format_exc()
|
60 |
+
|
61 |
+
|
62 |
+
head2version = {
|
63 |
+
b"00": ["v1", "v1", False],
|
64 |
+
b"01": ["v2", "v2", False],
|
65 |
+
b"02": ["v2", "v3", False],
|
66 |
+
b"03": ["v2", "v3", True],
|
67 |
+
b"04": ["v2", "v4", True],
|
68 |
+
}
|
69 |
+
hash_pretrained_dict = {
|
70 |
+
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
|
71 |
+
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
|
72 |
+
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
|
73 |
+
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
|
74 |
+
}
|
75 |
+
import hashlib
|
76 |
+
|
77 |
+
|
78 |
+
def get_hash_from_file(sovits_path):
|
79 |
+
with open(sovits_path, "rb") as f:
|
80 |
+
data = f.read(8192)
|
81 |
+
hash_md5 = hashlib.md5()
|
82 |
+
hash_md5.update(data)
|
83 |
+
return hash_md5.hexdigest()
|
84 |
+
|
85 |
+
|
86 |
+
def get_sovits_version_from_path_fast(sovits_path):
|
87 |
+
###1-if it is pretrained sovits models, by hash
|
88 |
+
hash = get_hash_from_file(sovits_path)
|
89 |
+
if hash in hash_pretrained_dict:
|
90 |
+
return hash_pretrained_dict[hash]
|
91 |
+
###2-new weights, by head
|
92 |
+
with open(sovits_path, "rb") as f:
|
93 |
+
version = f.read(2)
|
94 |
+
if version != b"PK":
|
95 |
+
return head2version[version]
|
96 |
+
###3-old weights, by file size
|
97 |
+
if_lora_v3 = False
|
98 |
+
size = os.path.getsize(sovits_path)
|
99 |
+
"""
|
100 |
+
v1weights:about 82942KB
|
101 |
+
half thr:82978KB
|
102 |
+
v2weights:about 83014KB
|
103 |
+
v3weights:about 750MB
|
104 |
+
"""
|
105 |
+
if size < 82978 * 1024:
|
106 |
+
model_version = version = "v1"
|
107 |
+
elif size < 700 * 1024 * 1024:
|
108 |
+
model_version = version = "v2"
|
109 |
+
else:
|
110 |
+
version = "v2"
|
111 |
+
model_version = "v3"
|
112 |
+
return version, model_version, if_lora_v3
|
113 |
+
|
114 |
+
|
115 |
+
def load_sovits_new(sovits_path):
|
116 |
+
f = open(sovits_path, "rb")
|
117 |
+
meta = f.read(2)
|
118 |
+
if meta != "PK":
|
119 |
+
data = b"PK" + f.read()
|
120 |
+
bio = BytesIO()
|
121 |
+
bio.write(data)
|
122 |
+
bio.seek(0)
|
123 |
+
return torch.load(bio, map_location="cpu", weights_only=False)
|
124 |
+
return torch.load(sovits_path, map_location="cpu", weights_only=False)
|
GPT_SoVITS/s1_train.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
2 |
+
import os
|
3 |
+
|
4 |
+
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
5 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
6 |
+
import argparse
|
7 |
+
import logging
|
8 |
+
import platform
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from AR.data.data_module import Text2SemanticDataModule
|
13 |
+
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
14 |
+
from AR.utils.io import load_yaml_config
|
15 |
+
from pytorch_lightning import Trainer, seed_everything
|
16 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
17 |
+
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
|
18 |
+
from pytorch_lightning.strategies import DDPStrategy
|
19 |
+
|
20 |
+
logging.getLogger("numba").setLevel(logging.WARNING)
|
21 |
+
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
22 |
+
torch.set_float32_matmul_precision("high")
|
23 |
+
from collections import OrderedDict
|
24 |
+
|
25 |
+
from AR.utils import get_newest_ckpt
|
26 |
+
from process_ckpt import my_save
|
27 |
+
|
28 |
+
|
29 |
+
class my_model_ckpt(ModelCheckpoint):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
config,
|
33 |
+
if_save_latest,
|
34 |
+
if_save_every_weights,
|
35 |
+
half_weights_save_dir,
|
36 |
+
exp_name,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
super().__init__(**kwargs)
|
40 |
+
self.if_save_latest = if_save_latest
|
41 |
+
self.if_save_every_weights = if_save_every_weights
|
42 |
+
self.half_weights_save_dir = half_weights_save_dir
|
43 |
+
self.exp_name = exp_name
|
44 |
+
self.config = config
|
45 |
+
|
46 |
+
def on_train_epoch_end(self, trainer, pl_module):
|
47 |
+
# if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
|
48 |
+
if self._should_save_on_train_epoch_end(trainer):
|
49 |
+
monitor_candidates = self._monitor_candidates(trainer)
|
50 |
+
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
51 |
+
if (
|
52 |
+
self.if_save_latest == True
|
53 |
+
): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
|
54 |
+
to_clean = list(os.listdir(self.dirpath))
|
55 |
+
self._save_topk_checkpoint(trainer, monitor_candidates)
|
56 |
+
if self.if_save_latest == True:
|
57 |
+
for name in to_clean:
|
58 |
+
try:
|
59 |
+
os.remove("%s/%s" % (self.dirpath, name))
|
60 |
+
except:
|
61 |
+
pass
|
62 |
+
if self.if_save_every_weights == True:
|
63 |
+
to_save_od = OrderedDict()
|
64 |
+
to_save_od["weight"] = OrderedDict()
|
65 |
+
dictt = trainer.strategy._lightning_module.state_dict()
|
66 |
+
for key in dictt:
|
67 |
+
to_save_od["weight"][key] = dictt[key].half()
|
68 |
+
to_save_od["config"] = self.config
|
69 |
+
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
|
70 |
+
# torch.save(
|
71 |
+
# print(os.environ)
|
72 |
+
if os.environ.get("LOCAL_RANK", "0") == "0":
|
73 |
+
my_save(
|
74 |
+
to_save_od,
|
75 |
+
"%s/%s-e%s.ckpt"
|
76 |
+
% (
|
77 |
+
self.half_weights_save_dir,
|
78 |
+
self.exp_name,
|
79 |
+
trainer.current_epoch + 1,
|
80 |
+
),
|
81 |
+
)
|
82 |
+
self._save_last_checkpoint(trainer, monitor_candidates)
|
83 |
+
|
84 |
+
|
85 |
+
def main(args):
|
86 |
+
config = load_yaml_config(args.config_file)
|
87 |
+
|
88 |
+
output_dir = Path(config["output_dir"])
|
89 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
90 |
+
|
91 |
+
ckpt_dir = output_dir / "ckpt"
|
92 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
93 |
+
|
94 |
+
seed_everything(config["train"]["seed"], workers=True)
|
95 |
+
ckpt_callback: ModelCheckpoint = my_model_ckpt(
|
96 |
+
config=config,
|
97 |
+
if_save_latest=config["train"]["if_save_latest"],
|
98 |
+
if_save_every_weights=config["train"]["if_save_every_weights"],
|
99 |
+
half_weights_save_dir=config["train"]["half_weights_save_dir"],
|
100 |
+
exp_name=config["train"]["exp_name"],
|
101 |
+
save_top_k=-1,
|
102 |
+
monitor="top_3_acc",
|
103 |
+
mode="max",
|
104 |
+
save_on_train_epoch_end=True,
|
105 |
+
every_n_epochs=config["train"]["save_every_n_epoch"],
|
106 |
+
dirpath=ckpt_dir,
|
107 |
+
)
|
108 |
+
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
|
109 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
110 |
+
os.environ["USE_LIBUV"] = "0"
|
111 |
+
trainer: Trainer = Trainer(
|
112 |
+
max_epochs=config["train"]["epochs"],
|
113 |
+
accelerator="gpu" if torch.cuda.is_available() else "cpu",
|
114 |
+
# val_check_interval=9999999999999999999999,###不要验证
|
115 |
+
# check_val_every_n_epoch=None,
|
116 |
+
limit_val_batches=0,
|
117 |
+
devices=-1 if torch.cuda.is_available() else 1,
|
118 |
+
benchmark=False,
|
119 |
+
fast_dev_run=False,
|
120 |
+
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
|
121 |
+
if torch.cuda.is_available()
|
122 |
+
else "auto",
|
123 |
+
precision=config["train"]["precision"],
|
124 |
+
logger=logger,
|
125 |
+
num_sanity_val_steps=0,
|
126 |
+
callbacks=[ckpt_callback],
|
127 |
+
use_distributed_sampler=False, # 非常简单的修改,但解决了采用自定义的 bucket_sampler 下训练步数不一致的问题!
|
128 |
+
)
|
129 |
+
|
130 |
+
model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)
|
131 |
+
|
132 |
+
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
|
133 |
+
config,
|
134 |
+
train_semantic_path=config["train_semantic_path"],
|
135 |
+
train_phoneme_path=config["train_phoneme_path"],
|
136 |
+
# dev_semantic_path=args.dev_semantic_path,
|
137 |
+
# dev_phoneme_path=args.dev_phoneme_path
|
138 |
+
)
|
139 |
+
|
140 |
+
try:
|
141 |
+
# 使用正则表达式匹配文件名中的数字部分,并按数字大小进行排序
|
142 |
+
newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
|
143 |
+
ckpt_path = ckpt_dir / newest_ckpt_name
|
144 |
+
except Exception:
|
145 |
+
ckpt_path = None
|
146 |
+
print("ckpt_path:", ckpt_path)
|
147 |
+
trainer.fit(model, data_module, ckpt_path=ckpt_path)
|
148 |
+
|
149 |
+
|
150 |
+
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
|
151 |
+
if __name__ == "__main__":
|
152 |
+
parser = argparse.ArgumentParser()
|
153 |
+
parser.add_argument(
|
154 |
+
"-c",
|
155 |
+
"--config_file",
|
156 |
+
type=str,
|
157 |
+
default="configs/s1longer.yaml",
|
158 |
+
help="path of config file",
|
159 |
+
)
|
160 |
+
# args for dataset
|
161 |
+
# parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
|
162 |
+
# parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
|
163 |
+
|
164 |
+
# parser.add_argument('--dev_semantic_path', type=str, default='dump_mix/semantic_dev.tsv')
|
165 |
+
# parser.add_argument('--dev_phoneme_path', type=str, default='dump_mix/phoneme_dev.npy')
|
166 |
+
# parser.add_argument('--output_dir',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/logs_s1',help='directory to save the results')
|
167 |
+
# parser.add_argument('--output_dir',type=str,default='/liujing04/gpt_logs/s1/xuangou_ft',help='directory to save the results')
|
168 |
+
|
169 |
+
args = parser.parse_args()
|
170 |
+
logging.info(str(args))
|
171 |
+
main(args)
|
GPT_SoVITS/s2_train.py
ADDED
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
warnings.filterwarnings("ignore")
|
4 |
+
import os
|
5 |
+
|
6 |
+
import utils
|
7 |
+
|
8 |
+
hps = utils.get_hparams(stage=2)
|
9 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
|
10 |
+
import logging
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
import torch.multiprocessing as mp
|
15 |
+
from torch.cuda.amp import GradScaler, autocast
|
16 |
+
from torch.nn import functional as F
|
17 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
from torch.utils.tensorboard import SummaryWriter
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
logging.getLogger("matplotlib").setLevel(logging.INFO)
|
23 |
+
logging.getLogger("h5py").setLevel(logging.INFO)
|
24 |
+
logging.getLogger("numba").setLevel(logging.INFO)
|
25 |
+
from random import randint
|
26 |
+
|
27 |
+
from module import commons
|
28 |
+
from module.data_utils import (
|
29 |
+
DistributedBucketSampler,
|
30 |
+
TextAudioSpeakerCollate,
|
31 |
+
TextAudioSpeakerLoader,
|
32 |
+
)
|
33 |
+
from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
|
34 |
+
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
|
35 |
+
from module.models import (
|
36 |
+
MultiPeriodDiscriminator,
|
37 |
+
SynthesizerTrn,
|
38 |
+
)
|
39 |
+
from process_ckpt import savee
|
40 |
+
|
41 |
+
torch.backends.cudnn.benchmark = False
|
42 |
+
torch.backends.cudnn.deterministic = False
|
43 |
+
###反正A100fp32更快,那试试tf32吧
|
44 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
45 |
+
torch.backends.cudnn.allow_tf32 = True
|
46 |
+
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
|
47 |
+
# from config import pretrained_s2G,pretrained_s2D
|
48 |
+
global_step = 0
|
49 |
+
|
50 |
+
device = "cpu" # cuda以外的设备,等mps优化后加入
|
51 |
+
|
52 |
+
|
53 |
+
def main():
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
n_gpus = torch.cuda.device_count()
|
56 |
+
else:
|
57 |
+
n_gpus = 1
|
58 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
59 |
+
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
|
60 |
+
|
61 |
+
mp.spawn(
|
62 |
+
run,
|
63 |
+
nprocs=n_gpus,
|
64 |
+
args=(
|
65 |
+
n_gpus,
|
66 |
+
hps,
|
67 |
+
),
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def run(rank, n_gpus, hps):
|
72 |
+
global global_step
|
73 |
+
if rank == 0:
|
74 |
+
logger = utils.get_logger(hps.data.exp_dir)
|
75 |
+
logger.info(hps)
|
76 |
+
# utils.check_git_hash(hps.s2_ckpt_dir)
|
77 |
+
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
|
78 |
+
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
79 |
+
|
80 |
+
dist.init_process_group(
|
81 |
+
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
82 |
+
init_method="env://?use_libuv=False",
|
83 |
+
world_size=n_gpus,
|
84 |
+
rank=rank,
|
85 |
+
)
|
86 |
+
torch.manual_seed(hps.train.seed)
|
87 |
+
if torch.cuda.is_available():
|
88 |
+
torch.cuda.set_device(rank)
|
89 |
+
|
90 |
+
train_dataset = TextAudioSpeakerLoader(hps.data) ########
|
91 |
+
train_sampler = DistributedBucketSampler(
|
92 |
+
train_dataset,
|
93 |
+
hps.train.batch_size,
|
94 |
+
[
|
95 |
+
32,
|
96 |
+
300,
|
97 |
+
400,
|
98 |
+
500,
|
99 |
+
600,
|
100 |
+
700,
|
101 |
+
800,
|
102 |
+
900,
|
103 |
+
1000,
|
104 |
+
1100,
|
105 |
+
1200,
|
106 |
+
1300,
|
107 |
+
1400,
|
108 |
+
1500,
|
109 |
+
1600,
|
110 |
+
1700,
|
111 |
+
1800,
|
112 |
+
1900,
|
113 |
+
],
|
114 |
+
num_replicas=n_gpus,
|
115 |
+
rank=rank,
|
116 |
+
shuffle=True,
|
117 |
+
)
|
118 |
+
collate_fn = TextAudioSpeakerCollate()
|
119 |
+
train_loader = DataLoader(
|
120 |
+
train_dataset,
|
121 |
+
num_workers=6,
|
122 |
+
shuffle=False,
|
123 |
+
pin_memory=True,
|
124 |
+
collate_fn=collate_fn,
|
125 |
+
batch_sampler=train_sampler,
|
126 |
+
persistent_workers=True,
|
127 |
+
prefetch_factor=4,
|
128 |
+
)
|
129 |
+
# if rank == 0:
|
130 |
+
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
|
131 |
+
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
|
132 |
+
# batch_size=1, pin_memory=True,
|
133 |
+
# drop_last=False, collate_fn=collate_fn)
|
134 |
+
|
135 |
+
net_g = (
|
136 |
+
SynthesizerTrn(
|
137 |
+
hps.data.filter_length // 2 + 1,
|
138 |
+
hps.train.segment_size // hps.data.hop_length,
|
139 |
+
n_speakers=hps.data.n_speakers,
|
140 |
+
**hps.model,
|
141 |
+
).cuda(rank)
|
142 |
+
if torch.cuda.is_available()
|
143 |
+
else SynthesizerTrn(
|
144 |
+
hps.data.filter_length // 2 + 1,
|
145 |
+
hps.train.segment_size // hps.data.hop_length,
|
146 |
+
n_speakers=hps.data.n_speakers,
|
147 |
+
**hps.model,
|
148 |
+
).to(device)
|
149 |
+
)
|
150 |
+
|
151 |
+
net_d = (
|
152 |
+
MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
|
153 |
+
if torch.cuda.is_available()
|
154 |
+
else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
|
155 |
+
)
|
156 |
+
for name, param in net_g.named_parameters():
|
157 |
+
if not param.requires_grad:
|
158 |
+
print(name, "not requires_grad")
|
159 |
+
|
160 |
+
te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
|
161 |
+
et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
|
162 |
+
mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
|
163 |
+
base_params = filter(
|
164 |
+
lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
|
165 |
+
net_g.parameters(),
|
166 |
+
)
|
167 |
+
|
168 |
+
# te_p=net_g.enc_p.text_embedding.parameters()
|
169 |
+
# et_p=net_g.enc_p.encoder_text.parameters()
|
170 |
+
# mrte_p=net_g.enc_p.mrte.parameters()
|
171 |
+
|
172 |
+
optim_g = torch.optim.AdamW(
|
173 |
+
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
|
174 |
+
[
|
175 |
+
{"params": base_params, "lr": hps.train.learning_rate},
|
176 |
+
{
|
177 |
+
"params": net_g.enc_p.text_embedding.parameters(),
|
178 |
+
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"params": net_g.enc_p.encoder_text.parameters(),
|
182 |
+
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"params": net_g.enc_p.mrte.parameters(),
|
186 |
+
"lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
|
187 |
+
},
|
188 |
+
],
|
189 |
+
hps.train.learning_rate,
|
190 |
+
betas=hps.train.betas,
|
191 |
+
eps=hps.train.eps,
|
192 |
+
)
|
193 |
+
optim_d = torch.optim.AdamW(
|
194 |
+
net_d.parameters(),
|
195 |
+
hps.train.learning_rate,
|
196 |
+
betas=hps.train.betas,
|
197 |
+
eps=hps.train.eps,
|
198 |
+
)
|
199 |
+
if torch.cuda.is_available():
|
200 |
+
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
|
201 |
+
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
|
202 |
+
else:
|
203 |
+
net_g = net_g.to(device)
|
204 |
+
net_d = net_d.to(device)
|
205 |
+
|
206 |
+
try: # 如果能加载自动resume
|
207 |
+
_, _, _, epoch_str = utils.load_checkpoint(
|
208 |
+
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
|
209 |
+
net_d,
|
210 |
+
optim_d,
|
211 |
+
) # D多半加载没事
|
212 |
+
if rank == 0:
|
213 |
+
logger.info("loaded D")
|
214 |
+
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
|
215 |
+
_, _, _, epoch_str = utils.load_checkpoint(
|
216 |
+
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
|
217 |
+
net_g,
|
218 |
+
optim_g,
|
219 |
+
)
|
220 |
+
epoch_str += 1
|
221 |
+
global_step = (epoch_str - 1) * len(train_loader)
|
222 |
+
# epoch_str = 1
|
223 |
+
# global_step = 0
|
224 |
+
except: # 如果首次不能加载,加载pretrain
|
225 |
+
# traceback.print_exc()
|
226 |
+
epoch_str = 1
|
227 |
+
global_step = 0
|
228 |
+
if (
|
229 |
+
hps.train.pretrained_s2G != ""
|
230 |
+
and hps.train.pretrained_s2G != None
|
231 |
+
and os.path.exists(hps.train.pretrained_s2G)
|
232 |
+
):
|
233 |
+
if rank == 0:
|
234 |
+
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
|
235 |
+
print(
|
236 |
+
"loaded pretrained %s" % hps.train.pretrained_s2G,
|
237 |
+
net_g.module.load_state_dict(
|
238 |
+
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
239 |
+
strict=False,
|
240 |
+
)
|
241 |
+
if torch.cuda.is_available()
|
242 |
+
else net_g.load_state_dict(
|
243 |
+
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
|
244 |
+
strict=False,
|
245 |
+
),
|
246 |
+
) ##测试不加载优化器
|
247 |
+
if (
|
248 |
+
hps.train.pretrained_s2D != ""
|
249 |
+
and hps.train.pretrained_s2D != None
|
250 |
+
and os.path.exists(hps.train.pretrained_s2D)
|
251 |
+
):
|
252 |
+
if rank == 0:
|
253 |
+
logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
|
254 |
+
print(
|
255 |
+
"loaded pretrained %s" % hps.train.pretrained_s2D,
|
256 |
+
net_d.module.load_state_dict(
|
257 |
+
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
|
258 |
+
)
|
259 |
+
if torch.cuda.is_available()
|
260 |
+
else net_d.load_state_dict(
|
261 |
+
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
|
262 |
+
),
|
263 |
+
)
|
264 |
+
|
265 |
+
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
266 |
+
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
|
267 |
+
|
268 |
+
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
269 |
+
optim_g,
|
270 |
+
gamma=hps.train.lr_decay,
|
271 |
+
last_epoch=-1,
|
272 |
+
)
|
273 |
+
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
274 |
+
optim_d,
|
275 |
+
gamma=hps.train.lr_decay,
|
276 |
+
last_epoch=-1,
|
277 |
+
)
|
278 |
+
for _ in range(epoch_str):
|
279 |
+
scheduler_g.step()
|
280 |
+
scheduler_d.step()
|
281 |
+
|
282 |
+
scaler = GradScaler(enabled=hps.train.fp16_run)
|
283 |
+
|
284 |
+
print("start training from epoch %s" % epoch_str)
|
285 |
+
for epoch in range(epoch_str, hps.train.epochs + 1):
|
286 |
+
if rank == 0:
|
287 |
+
train_and_evaluate(
|
288 |
+
rank,
|
289 |
+
epoch,
|
290 |
+
hps,
|
291 |
+
[net_g, net_d],
|
292 |
+
[optim_g, optim_d],
|
293 |
+
[scheduler_g, scheduler_d],
|
294 |
+
scaler,
|
295 |
+
# [train_loader, eval_loader], logger, [writer, writer_eval])
|
296 |
+
[train_loader, None],
|
297 |
+
logger,
|
298 |
+
[writer, writer_eval],
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
train_and_evaluate(
|
302 |
+
rank,
|
303 |
+
epoch,
|
304 |
+
hps,
|
305 |
+
[net_g, net_d],
|
306 |
+
[optim_g, optim_d],
|
307 |
+
[scheduler_g, scheduler_d],
|
308 |
+
scaler,
|
309 |
+
[train_loader, None],
|
310 |
+
None,
|
311 |
+
None,
|
312 |
+
)
|
313 |
+
scheduler_g.step()
|
314 |
+
scheduler_d.step()
|
315 |
+
print("training done")
|
316 |
+
|
317 |
+
|
318 |
+
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
|
319 |
+
net_g, net_d = nets
|
320 |
+
optim_g, optim_d = optims
|
321 |
+
# scheduler_g, scheduler_d = schedulers
|
322 |
+
train_loader, eval_loader = loaders
|
323 |
+
if writers is not None:
|
324 |
+
writer, writer_eval = writers
|
325 |
+
|
326 |
+
train_loader.batch_sampler.set_epoch(epoch)
|
327 |
+
global global_step
|
328 |
+
|
329 |
+
net_g.train()
|
330 |
+
net_d.train()
|
331 |
+
for batch_idx, (
|
332 |
+
ssl,
|
333 |
+
ssl_lengths,
|
334 |
+
spec,
|
335 |
+
spec_lengths,
|
336 |
+
y,
|
337 |
+
y_lengths,
|
338 |
+
text,
|
339 |
+
text_lengths,
|
340 |
+
) in enumerate(tqdm(train_loader)):
|
341 |
+
if torch.cuda.is_available():
|
342 |
+
spec, spec_lengths = (
|
343 |
+
spec.cuda(
|
344 |
+
rank,
|
345 |
+
non_blocking=True,
|
346 |
+
),
|
347 |
+
spec_lengths.cuda(
|
348 |
+
rank,
|
349 |
+
non_blocking=True,
|
350 |
+
),
|
351 |
+
)
|
352 |
+
y, y_lengths = (
|
353 |
+
y.cuda(
|
354 |
+
rank,
|
355 |
+
non_blocking=True,
|
356 |
+
),
|
357 |
+
y_lengths.cuda(
|
358 |
+
rank,
|
359 |
+
non_blocking=True,
|
360 |
+
),
|
361 |
+
)
|
362 |
+
ssl = ssl.cuda(rank, non_blocking=True)
|
363 |
+
ssl.requires_grad = False
|
364 |
+
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
365 |
+
text, text_lengths = (
|
366 |
+
text.cuda(
|
367 |
+
rank,
|
368 |
+
non_blocking=True,
|
369 |
+
),
|
370 |
+
text_lengths.cuda(
|
371 |
+
rank,
|
372 |
+
non_blocking=True,
|
373 |
+
),
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
377 |
+
y, y_lengths = y.to(device), y_lengths.to(device)
|
378 |
+
ssl = ssl.to(device)
|
379 |
+
ssl.requires_grad = False
|
380 |
+
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
381 |
+
text, text_lengths = text.to(device), text_lengths.to(device)
|
382 |
+
|
383 |
+
with autocast(enabled=hps.train.fp16_run):
|
384 |
+
(
|
385 |
+
y_hat,
|
386 |
+
kl_ssl,
|
387 |
+
ids_slice,
|
388 |
+
x_mask,
|
389 |
+
z_mask,
|
390 |
+
(z, z_p, m_p, logs_p, m_q, logs_q),
|
391 |
+
stats_ssl,
|
392 |
+
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
393 |
+
|
394 |
+
mel = spec_to_mel_torch(
|
395 |
+
spec,
|
396 |
+
hps.data.filter_length,
|
397 |
+
hps.data.n_mel_channels,
|
398 |
+
hps.data.sampling_rate,
|
399 |
+
hps.data.mel_fmin,
|
400 |
+
hps.data.mel_fmax,
|
401 |
+
)
|
402 |
+
y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
|
403 |
+
y_hat_mel = mel_spectrogram_torch(
|
404 |
+
y_hat.squeeze(1),
|
405 |
+
hps.data.filter_length,
|
406 |
+
hps.data.n_mel_channels,
|
407 |
+
hps.data.sampling_rate,
|
408 |
+
hps.data.hop_length,
|
409 |
+
hps.data.win_length,
|
410 |
+
hps.data.mel_fmin,
|
411 |
+
hps.data.mel_fmax,
|
412 |
+
)
|
413 |
+
|
414 |
+
y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
|
415 |
+
|
416 |
+
# Discriminator
|
417 |
+
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
|
418 |
+
with autocast(enabled=False):
|
419 |
+
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
|
420 |
+
y_d_hat_r,
|
421 |
+
y_d_hat_g,
|
422 |
+
)
|
423 |
+
loss_disc_all = loss_disc
|
424 |
+
optim_d.zero_grad()
|
425 |
+
scaler.scale(loss_disc_all).backward()
|
426 |
+
scaler.unscale_(optim_d)
|
427 |
+
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
|
428 |
+
scaler.step(optim_d)
|
429 |
+
|
430 |
+
with autocast(enabled=hps.train.fp16_run):
|
431 |
+
# Generator
|
432 |
+
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
|
433 |
+
with autocast(enabled=False):
|
434 |
+
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
|
435 |
+
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
|
436 |
+
|
437 |
+
loss_fm = feature_loss(fmap_r, fmap_g)
|
438 |
+
loss_gen, losses_gen = generator_loss(y_d_hat_g)
|
439 |
+
loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
|
440 |
+
|
441 |
+
optim_g.zero_grad()
|
442 |
+
scaler.scale(loss_gen_all).backward()
|
443 |
+
scaler.unscale_(optim_g)
|
444 |
+
grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
|
445 |
+
scaler.step(optim_g)
|
446 |
+
scaler.update()
|
447 |
+
|
448 |
+
if rank == 0:
|
449 |
+
if global_step % hps.train.log_interval == 0:
|
450 |
+
lr = optim_g.param_groups[0]["lr"]
|
451 |
+
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
|
452 |
+
logger.info(
|
453 |
+
"Train Epoch: {} [{:.0f}%]".format(
|
454 |
+
epoch,
|
455 |
+
100.0 * batch_idx / len(train_loader),
|
456 |
+
)
|
457 |
+
)
|
458 |
+
logger.info([x.item() for x in losses] + [global_step, lr])
|
459 |
+
|
460 |
+
scalar_dict = {
|
461 |
+
"loss/g/total": loss_gen_all,
|
462 |
+
"loss/d/total": loss_disc_all,
|
463 |
+
"learning_rate": lr,
|
464 |
+
"grad_norm_d": grad_norm_d,
|
465 |
+
"grad_norm_g": grad_norm_g,
|
466 |
+
}
|
467 |
+
scalar_dict.update(
|
468 |
+
{
|
469 |
+
"loss/g/fm": loss_fm,
|
470 |
+
"loss/g/mel": loss_mel,
|
471 |
+
"loss/g/kl_ssl": kl_ssl,
|
472 |
+
"loss/g/kl": loss_kl,
|
473 |
+
}
|
474 |
+
)
|
475 |
+
|
476 |
+
# scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
|
477 |
+
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
|
478 |
+
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
|
479 |
+
image_dict = None
|
480 |
+
try: ###Some people installed the wrong version of matplotlib.
|
481 |
+
image_dict = {
|
482 |
+
"slice/mel_org": utils.plot_spectrogram_to_numpy(
|
483 |
+
y_mel[0].data.cpu().numpy(),
|
484 |
+
),
|
485 |
+
"slice/mel_gen": utils.plot_spectrogram_to_numpy(
|
486 |
+
y_hat_mel[0].data.cpu().numpy(),
|
487 |
+
),
|
488 |
+
"all/mel": utils.plot_spectrogram_to_numpy(
|
489 |
+
mel[0].data.cpu().numpy(),
|
490 |
+
),
|
491 |
+
"all/stats_ssl": utils.plot_spectrogram_to_numpy(
|
492 |
+
stats_ssl[0].data.cpu().numpy(),
|
493 |
+
),
|
494 |
+
}
|
495 |
+
except:
|
496 |
+
pass
|
497 |
+
if image_dict:
|
498 |
+
utils.summarize(
|
499 |
+
writer=writer,
|
500 |
+
global_step=global_step,
|
501 |
+
images=image_dict,
|
502 |
+
scalars=scalar_dict,
|
503 |
+
)
|
504 |
+
else:
|
505 |
+
utils.summarize(
|
506 |
+
writer=writer,
|
507 |
+
global_step=global_step,
|
508 |
+
scalars=scalar_dict,
|
509 |
+
)
|
510 |
+
global_step += 1
|
511 |
+
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
|
512 |
+
if hps.train.if_save_latest == 0:
|
513 |
+
utils.save_checkpoint(
|
514 |
+
net_g,
|
515 |
+
optim_g,
|
516 |
+
hps.train.learning_rate,
|
517 |
+
epoch,
|
518 |
+
os.path.join(
|
519 |
+
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
520 |
+
"G_{}.pth".format(global_step),
|
521 |
+
),
|
522 |
+
)
|
523 |
+
utils.save_checkpoint(
|
524 |
+
net_d,
|
525 |
+
optim_d,
|
526 |
+
hps.train.learning_rate,
|
527 |
+
epoch,
|
528 |
+
os.path.join(
|
529 |
+
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
530 |
+
"D_{}.pth".format(global_step),
|
531 |
+
),
|
532 |
+
)
|
533 |
+
else:
|
534 |
+
utils.save_checkpoint(
|
535 |
+
net_g,
|
536 |
+
optim_g,
|
537 |
+
hps.train.learning_rate,
|
538 |
+
epoch,
|
539 |
+
os.path.join(
|
540 |
+
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
541 |
+
"G_{}.pth".format(233333333333),
|
542 |
+
),
|
543 |
+
)
|
544 |
+
utils.save_checkpoint(
|
545 |
+
net_d,
|
546 |
+
optim_d,
|
547 |
+
hps.train.learning_rate,
|
548 |
+
epoch,
|
549 |
+
os.path.join(
|
550 |
+
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
|
551 |
+
"D_{}.pth".format(233333333333),
|
552 |
+
),
|
553 |
+
)
|
554 |
+
if rank == 0 and hps.train.if_save_every_weights == True:
|
555 |
+
if hasattr(net_g, "module"):
|
556 |
+
ckpt = net_g.module.state_dict()
|
557 |
+
else:
|
558 |
+
ckpt = net_g.state_dict()
|
559 |
+
logger.info(
|
560 |
+
"saving ckpt %s_e%s:%s"
|
561 |
+
% (
|
562 |
+
hps.name,
|
563 |
+
epoch,
|
564 |
+
savee(
|
565 |
+
ckpt,
|
566 |
+
hps.name + "_e%s_s%s" % (epoch, global_step),
|
567 |
+
epoch,
|
568 |
+
global_step,
|
569 |
+
hps,
|
570 |
+
),
|
571 |
+
)
|
572 |
+
)
|
573 |
+
|
574 |
+
if rank == 0:
|
575 |
+
logger.info("====> Epoch: {}".format(epoch))
|
576 |
+
|
577 |
+
|
578 |
+
def evaluate(hps, generator, eval_loader, writer_eval):
|
579 |
+
generator.eval()
|
580 |
+
image_dict = {}
|
581 |
+
audio_dict = {}
|
582 |
+
print("Evaluating ...")
|
583 |
+
with torch.no_grad():
|
584 |
+
for batch_idx, (
|
585 |
+
ssl,
|
586 |
+
ssl_lengths,
|
587 |
+
spec,
|
588 |
+
spec_lengths,
|
589 |
+
y,
|
590 |
+
y_lengths,
|
591 |
+
text,
|
592 |
+
text_lengths,
|
593 |
+
) in enumerate(eval_loader):
|
594 |
+
print(111)
|
595 |
+
if torch.cuda.is_available():
|
596 |
+
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
597 |
+
y, y_lengths = y.cuda(), y_lengths.cuda()
|
598 |
+
ssl = ssl.cuda()
|
599 |
+
text, text_lengths = text.cuda(), text_lengths.cuda()
|
600 |
+
else:
|
601 |
+
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
602 |
+
y, y_lengths = y.to(device), y_lengths.to(device)
|
603 |
+
ssl = ssl.to(device)
|
604 |
+
text, text_lengths = text.to(device), text_lengths.to(device)
|
605 |
+
for test in [0, 1]:
|
606 |
+
y_hat, mask, *_ = (
|
607 |
+
generator.module.infer(
|
608 |
+
ssl,
|
609 |
+
spec,
|
610 |
+
spec_lengths,
|
611 |
+
text,
|
612 |
+
text_lengths,
|
613 |
+
test=test,
|
614 |
+
)
|
615 |
+
if torch.cuda.is_available()
|
616 |
+
else generator.infer(
|
617 |
+
ssl,
|
618 |
+
spec,
|
619 |
+
spec_lengths,
|
620 |
+
text,
|
621 |
+
text_lengths,
|
622 |
+
test=test,
|
623 |
+
)
|
624 |
+
)
|
625 |
+
y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
|
626 |
+
|
627 |
+
mel = spec_to_mel_torch(
|
628 |
+
spec,
|
629 |
+
hps.data.filter_length,
|
630 |
+
hps.data.n_mel_channels,
|
631 |
+
hps.data.sampling_rate,
|
632 |
+
hps.data.mel_fmin,
|
633 |
+
hps.data.mel_fmax,
|
634 |
+
)
|
635 |
+
y_hat_mel = mel_spectrogram_torch(
|
636 |
+
y_hat.squeeze(1).float(),
|
637 |
+
hps.data.filter_length,
|
638 |
+
hps.data.n_mel_channels,
|
639 |
+
hps.data.sampling_rate,
|
640 |
+
hps.data.hop_length,
|
641 |
+
hps.data.win_length,
|
642 |
+
hps.data.mel_fmin,
|
643 |
+
hps.data.mel_fmax,
|
644 |
+
)
|
645 |
+
image_dict.update(
|
646 |
+
{
|
647 |
+
f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
|
648 |
+
y_hat_mel[0].cpu().numpy(),
|
649 |
+
),
|
650 |
+
}
|
651 |
+
)
|
652 |
+
audio_dict.update(
|
653 |
+
{
|
654 |
+
f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
|
655 |
+
},
|
656 |
+
)
|
657 |
+
image_dict.update(
|
658 |
+
{
|
659 |
+
f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
|
660 |
+
},
|
661 |
+
)
|
662 |
+
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
|
663 |
+
|
664 |
+
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
|
665 |
+
# audio_dict.update({
|
666 |
+
# f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
|
667 |
+
# })
|
668 |
+
|
669 |
+
utils.summarize(
|
670 |
+
writer=writer_eval,
|
671 |
+
global_step=global_step,
|
672 |
+
images=image_dict,
|
673 |
+
audios=audio_dict,
|
674 |
+
audio_sampling_rate=hps.data.sampling_rate,
|
675 |
+
)
|
676 |
+
generator.train()
|
677 |
+
|
678 |
+
|
679 |
+
if __name__ == "__main__":
|
680 |
+
main()
|