kevinwang676 committed · verified
Commit 1b03734 · 1 Parent(s): 8eb38c9

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .dockerignore +8 -0
  2. .gitignore +200 -0
  3. Colab-Inference.ipynb +184 -0
  4. Dockerfile +42 -0
  5. GPT_SoVITS/AR/__init__.py +0 -0
  6. GPT_SoVITS/AR/data/__init__.py +0 -0
  7. GPT_SoVITS/AR/data/bucket_sampler.py +149 -0
  8. GPT_SoVITS/AR/data/data_module.py +81 -0
  9. GPT_SoVITS/AR/data/dataset.py +320 -0
  10. GPT_SoVITS/AR/models/__init__.py +0 -0
  11. GPT_SoVITS/AR/models/t2s_lightning_module.py +145 -0
  12. GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py +110 -0
  13. GPT_SoVITS/AR/models/t2s_model.py +935 -0
  14. GPT_SoVITS/AR/models/t2s_model_onnx.py +394 -0
  15. GPT_SoVITS/AR/models/utils.py +282 -0
  16. GPT_SoVITS/AR/modules/__init__.py +0 -0
  17. GPT_SoVITS/AR/modules/activation.py +413 -0
  18. GPT_SoVITS/AR/modules/activation_onnx.py +188 -0
  19. GPT_SoVITS/AR/modules/embedding.py +78 -0
  20. GPT_SoVITS/AR/modules/embedding_onnx.py +63 -0
  21. GPT_SoVITS/AR/modules/lr_schedulers.py +85 -0
  22. GPT_SoVITS/AR/modules/optim.py +593 -0
  23. GPT_SoVITS/AR/modules/patched_mha_with_cache.py +428 -0
  24. GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py +85 -0
  25. GPT_SoVITS/AR/modules/scaling.py +320 -0
  26. GPT_SoVITS/AR/modules/transformer.py +362 -0
  27. GPT_SoVITS/AR/modules/transformer_onnx.py +281 -0
  28. GPT_SoVITS/AR/text_processing/__init__.py +0 -0
  29. GPT_SoVITS/AR/text_processing/phonemizer.py +72 -0
  30. GPT_SoVITS/AR/text_processing/symbols.py +12 -0
  31. GPT_SoVITS/AR/utils/__init__.py +36 -0
  32. GPT_SoVITS/AR/utils/initialize.py +39 -0
  33. GPT_SoVITS/AR/utils/io.py +30 -0
  34. GPT_SoVITS/download.py +13 -0
  35. GPT_SoVITS/export_torch_script.py +861 -0
  36. GPT_SoVITS/export_torch_script_v3.py +1035 -0
  37. GPT_SoVITS/inference_cli.py +86 -0
  38. GPT_SoVITS/inference_gui.py +316 -0
  39. GPT_SoVITS/inference_webui.py +1281 -0
  40. GPT_SoVITS/inference_webui_fast.py +546 -0
  41. GPT_SoVITS/onnx_export.py +398 -0
  42. GPT_SoVITS/prepare_datasets/1-get-text.py +143 -0
  43. GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py +134 -0
  44. GPT_SoVITS/prepare_datasets/3-get-semantic.py +118 -0
  45. GPT_SoVITS/pretrained_models/.gitignore +2 -0
  46. GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt +3 -0
  47. GPT_SoVITS/pretrained_models/s1v3.ckpt +3 -0
  48. GPT_SoVITS/process_ckpt.py +124 -0
  49. GPT_SoVITS/s1_train.py +171 -0
  50. GPT_SoVITS/s2_train.py +680 -0
.dockerignore ADDED
@@ -0,0 +1,8 @@
+ docs
+ logs
+ output
+ reference
+ SoVITS_weights
+ GPT_weights
+ TEMP
+ .git
.gitignore ADDED
@@ -0,0 +1,200 @@
1
+ .DS_Store
2
+ .vscode
3
+ __pycache__
4
+ *.pyc
5
+ env
6
+ runtime
7
+ .idea
8
+ output
9
+ logs
10
+ reference
11
+ GPT_weights
12
+ SoVITS_weights
13
+ GPT_weights_v2
14
+ SoVITS_weights_v2
15
+ GPT_weights_v3
16
+ SoVITS_weights_v3
17
+ TEMP
18
+ weight.json
19
+ ffmpeg*
20
+ ffprobe*
21
+ cfg.json
22
+ speakers.json
23
+ ref_audios
24
+ tools/AP_BWE_main/24kto48k/*
25
+ !tools/AP_BWE_main/24kto48k/readme.txt
26
+
27
+ # Byte-compiled / optimized / DLL files
28
+ __pycache__/
29
+ *.py[cod]
30
+ *$py.class
31
+
32
+ # C extensions
33
+ *.so
34
+
35
+ # Distribution / packaging
36
+ .Python
37
+ build/
38
+ develop-eggs/
39
+ dist/
40
+ downloads/
41
+ eggs/
42
+ .eggs/
43
+ lib/
44
+ lib64/
45
+ parts/
46
+ sdist/
47
+ var/
48
+ wheels/
49
+ share/python-wheels/
50
+ *.egg-info/
51
+ .installed.cfg
52
+ *.egg
53
+ MANIFEST
54
+
55
+ # PyInstaller
56
+ # Usually these files are written by a python script from a template
57
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
58
+ *.manifest
59
+ *.spec
60
+
61
+ # Installer logs
62
+ pip-log.txt
63
+ pip-delete-this-directory.txt
64
+
65
+ # Unit test / coverage reports
66
+ htmlcov/
67
+ .tox/
68
+ .nox/
69
+ .coverage
70
+ .coverage.*
71
+ .cache
72
+ nosetests.xml
73
+ coverage.xml
74
+ *.cover
75
+ *.py,cover
76
+ .hypothesis/
77
+ .pytest_cache/
78
+ cover/
79
+
80
+ # Translations
81
+ *.mo
82
+ *.pot
83
+
84
+ # Django stuff:
85
+ *.log
86
+ local_settings.py
87
+ db.sqlite3
88
+ db.sqlite3-journal
89
+
90
+ # Flask stuff:
91
+ instance/
92
+ .webassets-cache
93
+
94
+ # Scrapy stuff:
95
+ .scrapy
96
+
97
+ # Sphinx documentation
98
+ docs/_build/
99
+
100
+ # PyBuilder
101
+ .pybuilder/
102
+ target/
103
+
104
+ # Jupyter Notebook
105
+ .ipynb_checkpoints
106
+
107
+ # IPython
108
+ profile_default/
109
+ ipython_config.py
110
+
111
+ # pyenv
112
+ # For a library or package, you might want to ignore these files since the code is
113
+ # intended to run in multiple environments; otherwise, check them in:
114
+ # .python-version
115
+
116
+ # pipenv
117
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
118
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
119
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
120
+ # install all needed dependencies.
121
+ #Pipfile.lock
122
+
123
+ # UV
124
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
125
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
126
+ # commonly ignored for libraries.
127
+ #uv.lock
128
+
129
+ # poetry
130
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
131
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
132
+ # commonly ignored for libraries.
133
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
134
+ #poetry.lock
135
+
136
+ # pdm
137
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
138
+ #pdm.lock
139
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
140
+ # in version control.
141
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
142
+ .pdm.toml
143
+ .pdm-python
144
+ .pdm-build/
145
+
146
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
147
+ __pypackages__/
148
+
149
+ # Celery stuff
150
+ celerybeat-schedule
151
+ celerybeat.pid
152
+
153
+ # SageMath parsed files
154
+ *.sage.py
155
+
156
+ # Environments
157
+ .env
158
+ .venv
159
+ env/
160
+ venv/
161
+ ENV/
162
+ env.bak/
163
+ venv.bak/
164
+
165
+ # Spyder project settings
166
+ .spyderproject
167
+ .spyproject
168
+
169
+ # Rope project settings
170
+ .ropeproject
171
+
172
+ # mkdocs documentation
173
+ /site
174
+
175
+ # mypy
176
+ .mypy_cache/
177
+ .dmypy.json
178
+ dmypy.json
179
+
180
+ # Pyre type checker
181
+ .pyre/
182
+
183
+ # pytype static type analyzer
184
+ .pytype/
185
+
186
+ # Cython debug symbols
187
+ cython_debug/
188
+
189
+ # PyCharm
190
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
191
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
192
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
193
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
194
+ #.idea/
195
+
196
+ # Ruff stuff:
197
+ .ruff_cache/
198
+
199
+ # PyPI configuration file
200
+ .pypirc
Colab-Inference.ipynb ADDED
@@ -0,0 +1,184 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# GPT-SoVITS Infer"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "## Env Setup (Run Once Only)\n",
15
+ "## 环境配置, 只需运行一次"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {},
21
+ "source": [
22
+ "### 1."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {
29
+ "id": "e9b7iFV3dm1f"
30
+ },
31
+ "outputs": [],
32
+ "source": [
33
+ "%%writefile /content/setup.sh\n",
34
+ "set -e\n",
35
+ "\n",
36
+ "cd /content\n",
37
+ "\n",
38
+ "git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
39
+ "\n",
40
+ "cd GPT-SoVITS\n",
41
+ "\n",
42
+ "mkdir GPT_weights\n",
43
+ "\n",
44
+ "mkdir SoVITS_weights\n",
45
+ "\n",
46
+ "if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
47
+ " :\n",
48
+ "else\n",
49
+ " conda create -n GPTSoVITS python=3.10 -y\n",
50
+ "fi\n",
51
+ "\n",
52
+ "source activate GPTSoVITS\n",
53
+ "\n",
54
+ "pip install ipykernel\n",
55
+ "\n",
56
+ "bash install.sh --source HF"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {},
62
+ "source": [
63
+ "### 2."
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "metadata": {
70
+ "cellView": "form",
71
+ "id": "0NgxXg5sjv7z"
72
+ },
73
+ "outputs": [],
74
+ "source": [
75
+ "%pip install -q condacolab\n",
76
+ "import condacolab\n",
77
+ "condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
78
+ "!cd /content && bash setup.sh"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "metadata": {},
84
+ "source": [
85
+ "# Download Model"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "markdown",
90
+ "metadata": {},
91
+ "source": [
92
+ "### Download From HuggingFace"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "metadata": {
99
+ "cellView": "form",
100
+ "id": "vbZY-LnM0tzq"
101
+ },
102
+ "outputs": [],
103
+ "source": [
104
+ "# Modify These\n",
105
+ "USER_ID = \"AkitoP\"\n",
106
+ "REPO_NAME = \"GPT-SoVITS-v2-aegi\"\n",
107
+ "BRANCH = \"main\"\n",
108
+ "GPT_PATH = \"new_aegigoe-e100.ckpt\"\n",
109
+ "SOVITS_PATH = \"new_aegigoe_e60_s32220.pth\"\n",
110
+ "\n",
111
+ "# Do Not Modify\n",
112
+ "HF_BASE = \"https://huggingface.co\"\n",
113
+ "REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
114
+ "GPT_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{GPT_PATH}\"\n",
115
+ "SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{SOVITS_PATH}\"\n",
116
+ "\n",
117
+ "!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
118
+ "!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\"\n"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {},
124
+ "source": [
125
+ "### Download From ModelScope"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "# Modify These\n",
135
+ "USER_ID = \"aihobbyist\"\n",
136
+ "REPO_NAME = \"GPT-SoVits-V2-models\"\n",
137
+ "BRANCH = \"master\"\n",
138
+ "GPT_PATH = \"Genshin_Impact/EN/GPT_GenshinImpact_EN_5.1.ckpt\"\n",
139
+ "SOVITS_PATH = \"Wuthering_Waves/CN/SV_WutheringWaves_CN_1.3.pth\"\n",
140
+ "\n",
141
+ "# Do Not Modify\n",
142
+ "HF_BASE = \"https://www.modelscope.cn/models\"\n",
143
+ "REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
144
+ "GPT_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{GPT_PATH}\"\n",
145
+ "SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{SOVITS_PATH}\"\n",
146
+ "\n",
147
+ "!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
148
+ "!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\""
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "metadata": {},
154
+ "source": [
155
+ "# Launch WebUI\n",
156
+ "# 启动 WebUI"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "metadata": {
163
+ "cellView": "form",
164
+ "id": "4oRGUzkrk8C7"
165
+ },
166
+ "outputs": [],
167
+ "source": [
168
+ "!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
169
+ ]
170
+ }
171
+ ],
172
+ "metadata": {
173
+ "accelerator": "GPU",
174
+ "colab": {
175
+ "provenance": []
176
+ },
177
+ "kernelspec": {
178
+ "display_name": "Python 3",
179
+ "name": "python3"
180
+ }
181
+ },
182
+ "nbformat": 4,
183
+ "nbformat_minor": 0
184
+ }
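One caveat in the HuggingFace download cell above: it builds `/blob/` URLs, which point at the HTML file page rather than the raw file, whereas the ModelScope cell correctly uses `/resolve/`. A hedged alternative (not part of the commit) that sidesteps URL construction entirely via `huggingface_hub`, reusing the example repo and file names from the notebook:

```python
# Sketch only (not part of this commit): fetching the same checkpoints with
# huggingface_hub instead of wget, avoiding the /blob/-vs-/resolve/ pitfall.
# Repo and file names below are the example values from the notebook cell.
from huggingface_hub import hf_hub_download

REPO_ID = "AkitoP/GPT-SoVITS-v2-aegi"
GPT_PATH = "new_aegigoe-e100.ckpt"
SOVITS_PATH = "new_aegigoe_e60_s32220.pth"

gpt_ckpt = hf_hub_download(
    repo_id=REPO_ID,
    filename=GPT_PATH,
    local_dir="/content/GPT-SoVITS/GPT_weights",
)
sovits_ckpt = hf_hub_download(
    repo_id=REPO_ID,
    filename=SOVITS_PATH,
    local_dir="/content/GPT-SoVITS/SoVITS_weights",
)
print(gpt_ckpt, sovits_ckpt)
```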
Dockerfile ADDED
@@ -0,0 +1,42 @@
+ # Base CUDA image
+ FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
+
+ LABEL maintainer="[email protected]"
+ LABEL version="dev-20240209"
+ LABEL description="Docker image for GPT-SoVITS"
+
+
+ # Install 3rd party apps
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV TZ=Etc/UTC
+ RUN apt-get update && \
+ apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
+ git lfs install && \
+ rm -rf /var/lib/apt/lists/*
+
+ # Copy only requirements.txt initially to leverage Docker cache
+ WORKDIR /workspace
+ COPY requirements.txt /workspace/
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Define a build-time argument for image type
+ ARG IMAGE_TYPE=full
+
+ # Conditional logic based on the IMAGE_TYPE argument
+ # Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
+ COPY ./Docker /workspace/Docker
+ # elite 类型的镜像里面不包含额外的模型
+ RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
+ chmod +x /workspace/Docker/download.sh && \
+ /workspace/Docker/download.sh && \
+ python /workspace/Docker/download.py && \
+ python -m nltk.downloader averaged_perceptron_tagger cmudict; \
+ fi
+
+
+ # Copy the rest of the application
+ COPY . /workspace
+
+ EXPOSE 9871 9872 9873 9874 9880
+
+ CMD ["python", "webui.py"]
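The IMAGE_TYPE build argument switches between the default full image, which downloads the pretrained models at build time, and the elite image, which skips that step. A sketch (not part of the commit) of driving both builds from Python with the docker SDK; the tags are arbitrary and a reachable Docker daemon is assumed:

```python
# Sketch only: building the two image variants via docker-py.
# Assumes `pip install docker` and a running Docker daemon; the tags are arbitrary.
import docker

client = docker.from_env()

# Default build: IMAGE_TYPE=full, pretrained models are baked into the image.
full_image, _ = client.images.build(path=".", tag="gpt-sovits:full")

# Elite build: skips the model-download step guarded by the RUN condition above.
elite_image, _ = client.images.build(
    path=".",
    tag="gpt-sovits:elite",
    buildargs={"IMAGE_TYPE": "elite"},
)
print(full_image.tags, elite_image.tags)
```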
GPT_SoVITS/AR/__init__.py ADDED
File without changes
GPT_SoVITS/AR/data/__init__.py ADDED
File without changes
GPT_SoVITS/AR/data/bucket_sampler.py ADDED
@@ -0,0 +1,149 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/bucket_sampler.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import itertools
4
+ import math
5
+ import random
6
+ from random import shuffle
7
+ from typing import Iterator, Optional, TypeVar
8
+
9
+ import torch
10
+ import torch.distributed as dist
11
+ from torch.utils.data import Dataset, Sampler
12
+
13
+ __all__ = [
14
+ "DistributedBucketSampler",
15
+ ]
16
+
17
+ T_co = TypeVar("T_co", covariant=True)
18
+
19
+
20
+ class DistributedBucketSampler(Sampler[T_co]):
21
+ r"""
22
+ sort the dataset wrt. input length
23
+ divide samples into buckets
24
+ sort within buckets
25
+ divide buckets into batches
26
+ sort batches
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ dataset: Dataset,
32
+ num_replicas: Optional[int] = None,
33
+ rank: Optional[int] = None,
34
+ shuffle: bool = True,
35
+ seed: int = 0,
36
+ drop_last: bool = False,
37
+ batch_size: int = 32,
38
+ ) -> None:
39
+ if num_replicas is None:
40
+ if not dist.is_available():
41
+ raise RuntimeError("Requires distributed package to be available")
42
+ num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
43
+ if rank is None:
44
+ if not dist.is_available():
45
+ raise RuntimeError("Requires distributed package to be available")
46
+ rank = dist.get_rank() if torch.cuda.is_available() else 0
47
+ if torch.cuda.is_available():
48
+ torch.cuda.set_device(rank)
49
+ if rank >= num_replicas or rank < 0:
50
+ raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
51
+ self.dataset = dataset
52
+ self.num_replicas = num_replicas
53
+ self.rank = rank
54
+ self.epoch = 0
55
+ self.drop_last = drop_last
56
+ # If the dataset length is evenly divisible by # of replicas, then there
57
+ # is no need to drop any data, since the dataset will be split equally.
58
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
59
+ # Split to nearest available length that is evenly divisible.
60
+ # This is to ensure each rank receives the same amount of data when
61
+ # using this Sampler.
62
+ self.num_samples = math.ceil(
63
+ (len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
64
+ )
65
+ else:
66
+ self.num_samples = math.ceil(
67
+ len(self.dataset) / self.num_replicas,
68
+ ) # type: ignore[arg-type]
69
+ self.total_size = self.num_samples * self.num_replicas
70
+ self.shuffle = shuffle
71
+ self.seed = seed
72
+ self.batch_size = batch_size
73
+ self.id_with_length = self._get_sample_lengths()
74
+ self.id_buckets = self.make_buckets(bucket_width=2.0)
75
+
76
+ def _get_sample_lengths(self):
77
+ id_with_lengths = []
78
+ for i in range(len(self.dataset)):
79
+ id_with_lengths.append((i, self.dataset.get_sample_length(i)))
80
+ id_with_lengths.sort(key=lambda x: x[1])
81
+ return id_with_lengths
82
+
83
+ def make_buckets(self, bucket_width: float = 2.0):
84
+ buckets = []
85
+ cur = []
86
+ max_sec = bucket_width
87
+ for id, sec in self.id_with_length:
88
+ if sec < max_sec:
89
+ cur.append(id)
90
+ else:
91
+ buckets.append(cur)
92
+ cur = [id]
93
+ max_sec += bucket_width
94
+ if len(cur) > 0:
95
+ buckets.append(cur)
96
+ return buckets
97
+
98
+ def __iter__(self) -> Iterator[T_co]:
99
+ if self.shuffle:
100
+ # deterministically shuffle based on epoch and seed
101
+ g = torch.Generator()
102
+ g.manual_seed(self.seed + self.epoch)
103
+ random.seed(self.epoch + self.seed)
104
+ shuffled_bucket = []
105
+ for buc in self.id_buckets:
106
+ buc_copy = buc.copy()
107
+ shuffle(buc_copy)
108
+ shuffled_bucket.append(buc_copy)
109
+ grouped_batch_size = self.batch_size * self.num_replicas
110
+ shuffled_bucket = list(itertools.chain(*shuffled_bucket))
111
+ n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
112
+ batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
113
+ shuffle(batches)
114
+ indices = list(itertools.chain(*batches))
115
+ else:
116
+ # type: ignore[arg-type]
117
+ indices = list(range(len(self.dataset)))
118
+
119
+ if not self.drop_last:
120
+ # add extra samples to make it evenly divisible
121
+ padding_size = self.total_size - len(indices)
122
+ if padding_size <= len(indices):
123
+ indices += indices[:padding_size]
124
+ else:
125
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
126
+ else:
127
+ # remove tail of data to make it evenly divisible.
128
+ indices = indices[: self.total_size]
129
+ assert len(indices) == self.total_size
130
+
131
+ # subsample
132
+ indices = indices[self.rank : self.total_size : self.num_replicas]
133
+ assert len(indices) == self.num_samples
134
+
135
+ return iter(indices)
136
+
137
+ def __len__(self) -> int:
138
+ return self.num_samples
139
+
140
+ def set_epoch(self, epoch: int) -> None:
141
+ r"""
142
+ Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
143
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
144
+ sampler will yield the same ordering.
145
+
146
+ Args:
147
+ epoch (int): Epoch number.
148
+ """
149
+ self.epoch = epoch
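DistributedBucketSampler sorts the dataset by sample duration via get_sample_length, slices it into 2-second buckets, shuffles within each bucket, and then shuffles whole batches, so every batch groups similarly long samples. A minimal single-process sketch (not part of the commit; ToyDataset is hypothetical, and it assumes the GPT_SoVITS directory is on sys.path) of the interface the sampler expects:

```python
# Toy illustration of the dataset contract expected by DistributedBucketSampler:
# the dataset only needs __len__, __getitem__ and get_sample_length(idx) in seconds.
from torch.utils.data import Dataset
from AR.data.bucket_sampler import DistributedBucketSampler


class ToyDataset(Dataset):
    """Hypothetical dataset where each item is just its duration in seconds."""

    def __init__(self, durations):
        self.durations = durations

    def __len__(self):
        return len(self.durations)

    def __getitem__(self, idx):
        return self.durations[idx]

    def get_sample_length(self, idx):  # used by the sampler to sort and bucket
        return self.durations[idx]


ds = ToyDataset([0.5, 3.2, 1.1, 7.8, 2.4, 4.0] * 10)
# num_replicas/rank are passed explicitly so no torch.distributed setup is needed.
sampler = DistributedBucketSampler(ds, num_replicas=1, rank=0, batch_size=4)
sampler.set_epoch(0)  # deterministic reshuffle per epoch
print(len(sampler), list(sampler)[:8])
```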
GPT_SoVITS/AR/data/data_module.py ADDED
@@ -0,0 +1,81 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ from pytorch_lightning import LightningDataModule
4
+ from torch.utils.data import DataLoader
5
+
6
+ from AR.data.bucket_sampler import DistributedBucketSampler
7
+ from AR.data.dataset import Text2SemanticDataset
8
+
9
+
10
+ class Text2SemanticDataModule(LightningDataModule):
11
+ def __init__(
12
+ self,
13
+ config,
14
+ train_semantic_path,
15
+ train_phoneme_path,
16
+ dev_semantic_path=None,
17
+ dev_phoneme_path=None,
18
+ ):
19
+ super().__init__()
20
+ self.config = config
21
+ self.train_semantic_path = train_semantic_path
22
+ self.train_phoneme_path = train_phoneme_path
23
+ self.dev_semantic_path = dev_semantic_path
24
+ self.dev_phoneme_path = dev_phoneme_path
25
+ self.num_workers = self.config["data"]["num_workers"]
26
+
27
+ def prepare_data(self):
28
+ pass
29
+
30
+ def setup(self, stage=None, output_logs=False):
31
+ self._train_dataset = Text2SemanticDataset(
32
+ phoneme_path=self.train_phoneme_path,
33
+ semantic_path=self.train_semantic_path,
34
+ max_sec=self.config["data"]["max_sec"],
35
+ pad_val=self.config["data"]["pad_val"],
36
+ )
37
+ self._dev_dataset = self._train_dataset
38
+ # self._dev_dataset = Text2SemanticDataset(
39
+ # phoneme_path=self.dev_phoneme_path,
40
+ # semantic_path=self.dev_semantic_path,
41
+ # max_sample=self.config['data']['max_eval_sample'],
42
+ # max_sec=self.config['data']['max_sec'],
43
+ # pad_val=self.config['data']['pad_val'])
44
+
45
+ def train_dataloader(self):
46
+ batch_size = (
47
+ self.config["train"]["batch_size"] // 2
48
+ if self.config["train"].get("if_dpo", False) is True
49
+ else self.config["train"]["batch_size"]
50
+ )
51
+ batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
52
+ sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
53
+ return DataLoader(
54
+ self._train_dataset,
55
+ batch_size=batch_size,
56
+ sampler=sampler,
57
+ collate_fn=self._train_dataset.collate,
58
+ num_workers=self.num_workers,
59
+ persistent_workers=True,
60
+ prefetch_factor=16,
61
+ )
62
+
63
+ def val_dataloader(self):
64
+ return DataLoader(
65
+ self._dev_dataset,
66
+ batch_size=1,
67
+ shuffle=False,
68
+ collate_fn=self._train_dataset.collate,
69
+ num_workers=max(self.num_workers, 12),
70
+ persistent_workers=True,
71
+ prefetch_factor=16,
72
+ )
73
+
74
+ # 这个会使用到嘛?
75
+ def test_dataloader(self):
76
+ return DataLoader(
77
+ self._dev_dataset,
78
+ batch_size=1,
79
+ shuffle=False,
80
+ collate_fn=self._train_dataset.collate,
81
+ )
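Text2SemanticDataModule reads only a handful of config keys: data.num_workers, data.max_sec, data.pad_val, train.batch_size, and optionally train.if_dpo, which halves the batch size to make room for the rejected DPO branch. A hedged sketch with an illustrative config dict follows; in the repo these values come from the s1 training YAML, and the dataset paths are placeholders for the 2-name2text.txt / 6-name2semantic.tsv files produced by prepare_datasets:

```python
# Illustrative only: a minimal config with just the keys the data module reads.
# The paths are placeholders for the prepare_datasets outputs of one experiment.
from AR.data.data_module import Text2SemanticDataModule

config = {
    "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
    "train": {"batch_size": 8, "if_dpo": True},  # if_dpo=True halves batch_size
}

dm = Text2SemanticDataModule(
    config,
    train_semantic_path="logs/my_exp/6-name2semantic.tsv",
    train_phoneme_path="logs/my_exp/2-name2text.txt",
)
dm.setup()
# A pytorch_lightning.Trainer would normally call dm.train_dataloader() itself;
# note that its DistributedBucketSampler expects the usual Lightning DDP setup.
```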
GPT_SoVITS/AR/data/dataset.py ADDED
@@ -0,0 +1,320 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+
4
+ # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
5
+ import os
6
+ import traceback
7
+ from typing import Dict, List
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+ from torch.utils.data import DataLoader, Dataset
13
+
14
+ version = os.environ.get("version", None)
15
+
16
+ from text import cleaned_text_to_sequence
17
+
18
+ # from config import exp_dir
19
+
20
+
21
+ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0):
22
+ seq = sequences[0]
23
+ ndim = seq.ndim
24
+ if axis < 0:
25
+ axis += ndim
26
+ dtype = seq.dtype
27
+ pad_value = dtype.type(pad_value)
28
+ seq_lengths = [seq.shape[axis] for seq in sequences]
29
+ max_length = np.max(seq_lengths)
30
+
31
+ padded_sequences = []
32
+ for seq, length in zip(sequences, seq_lengths):
33
+ padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
34
+ padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
35
+ padded_sequences.append(padded_seq)
36
+ batch = np.stack(padded_sequences)
37
+ return batch
38
+
39
+
40
+ class Text2SemanticDataset(Dataset):
41
+ """dataset class for text tokens to semantic model training."""
42
+
43
+ def __init__(
44
+ self,
45
+ phoneme_path: str,
46
+ semantic_path: str,
47
+ max_sample: int = None,
48
+ max_sec: int = 100,
49
+ pad_val: int = 1024,
50
+ # min value of phoneme/sec
51
+ min_ps_ratio: int = 3,
52
+ # max value of phoneme/sec
53
+ max_ps_ratio: int = 25,
54
+ ) -> None:
55
+ super().__init__()
56
+
57
+ self.semantic_data = pd.read_csv(
58
+ semantic_path,
59
+ delimiter="\t",
60
+ encoding="utf-8",
61
+ )
62
+ # get dict
63
+ self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
64
+ self.path3 = "%s/3-bert" % (
65
+ os.path.dirname(
66
+ phoneme_path,
67
+ )
68
+ ) # "%s/3-bert"%exp_dir#bert_dir
69
+ self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
70
+ assert os.path.exists(self.path2)
71
+ assert os.path.exists(self.path6)
72
+ self.phoneme_data = {}
73
+ with open(self.path2, "r", encoding="utf8") as f:
74
+ lines = f.read().strip("\n").split("\n")
75
+
76
+ for line in lines:
77
+ tmp = line.split("\t")
78
+ if len(tmp) != 4:
79
+ continue
80
+ self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]]
81
+
82
+ # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item()
83
+ # pad for semantic tokens
84
+ self.PAD: int = pad_val
85
+ # self.hz = 25
86
+ # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read()
87
+ # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz
88
+ # self.hz=int(data[:-2])#
89
+ self.hz = int(os.environ.get("hz", "25hz")[:-2])
90
+
91
+ # max seconds of semantic token
92
+ self.max_sec = max_sec
93
+ self.min_ps_ratio = min_ps_ratio
94
+ self.max_ps_ratio = max_ps_ratio
95
+
96
+ if max_sample is not None:
97
+ self.semantic_data = self.semantic_data[:max_sample]
98
+
99
+ # {idx: (semantic, phoneme)}
100
+ # semantic list, phoneme list
101
+ self.semantic_phoneme = []
102
+ self.item_names = []
103
+
104
+ self.inited = False
105
+
106
+ if not self.inited:
107
+ # 调用初始化函数
108
+ self.init_batch()
109
+ self.inited = True
110
+ del self.semantic_data
111
+ del self.phoneme_data
112
+ # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large")
113
+ # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large")
114
+
115
+ def init_batch(self):
116
+ semantic_data_len = len(self.semantic_data)
117
+ phoneme_data_len = len(self.phoneme_data.keys())
118
+ print("semantic_data_len:", semantic_data_len)
119
+ print("phoneme_data_len:", phoneme_data_len)
120
+ print(self.semantic_data)
121
+ idx = 0
122
+ num_not_in = 0
123
+ num_deleted_bigger = 0
124
+ num_deleted_ps = 0
125
+ for i in range(semantic_data_len):
126
+ # 先依次遍历
127
+ # get str
128
+ item_name = self.semantic_data.iloc[i, 0]
129
+ # print(self.phoneme_data)
130
+ try:
131
+ phoneme, word2ph, text = self.phoneme_data[item_name]
132
+ except Exception:
133
+ traceback.print_exc()
134
+ # print(f"{item_name} not in self.phoneme_data !")
135
+ num_not_in += 1
136
+ continue
137
+
138
+ semantic_str = self.semantic_data.iloc[i, 1]
139
+ # get token list
140
+ semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
141
+ # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
142
+ # 过滤掉太长的样本
143
+ if (
144
+ len(semantic_ids) > self.max_sec * self.hz
145
+ ): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k
146
+ num_deleted_bigger += 1
147
+ continue
148
+ # (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理####
149
+ phoneme = phoneme.split(" ")
150
+
151
+ try:
152
+ phoneme_ids = cleaned_text_to_sequence(phoneme, version)
153
+ except:
154
+ traceback.print_exc()
155
+ # print(f"{item_name} not in self.phoneme_data !")
156
+ num_not_in += 1
157
+ continue
158
+ # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
159
+ if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2:改为恒定限制为semantic/2.5就行
160
+ num_deleted_ps += 1
161
+ continue
162
+ # if len(semantic_ids) > 1000:###########3
163
+ # num_deleted_bigger += 1
164
+ # continue
165
+
166
+ ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
167
+
168
+ if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
169
+ num_deleted_ps += 1
170
+ # print(item_name)
171
+ continue
172
+
173
+ self.semantic_phoneme.append((semantic_ids, phoneme_ids))
174
+ idx += 1
175
+ self.item_names.append(item_name)
176
+
177
+ min_num = 100 # 20直接不补#30补了也不存ckpt
178
+ leng = len(self.semantic_phoneme)
179
+ if leng < min_num:
180
+ tmp1 = self.semantic_phoneme
181
+ tmp2 = self.item_names
182
+ self.semantic_phoneme = []
183
+ self.item_names = []
184
+ for _ in range(max(2, int(min_num / leng))):
185
+ self.semantic_phoneme += tmp1
186
+ self.item_names += tmp2
187
+ if num_not_in > 0:
188
+ print(f"there are {num_not_in} semantic datas not in phoneme datas")
189
+ if num_deleted_bigger > 0:
190
+ print(
191
+ f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
192
+ )
193
+ if num_deleted_ps > 0:
194
+ # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
195
+ print(
196
+ f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
197
+ )
198
+ """
199
+ there are 31 semantic datas not in phoneme datas
200
+ deleted 34 audios who's duration are bigger than 54 seconds
201
+ deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3
202
+ dataset.__len__(): 366463
203
+
204
+ """
205
+ # 345410 for LibriTTS
206
+ print("dataset.__len__():", self.__len__())
207
+
208
+ def __get_item_names__(self) -> List[str]:
209
+ return self.item_names
210
+
211
+ def __len__(self) -> int:
212
+ return len(self.semantic_phoneme)
213
+
214
+ def __getitem__(self, idx: int) -> Dict:
215
+ semantic_ids, phoneme_ids = self.semantic_phoneme[idx]
216
+ item_name = self.item_names[idx]
217
+ phoneme_ids_len = len(phoneme_ids)
218
+ # semantic tokens target
219
+ semantic_ids_len = len(semantic_ids)
220
+
221
+ flag = 0
222
+ path_bert = "%s/%s.pt" % (self.path3, item_name)
223
+ if os.path.exists(path_bert) == True:
224
+ bert_feature = torch.load(path_bert, map_location="cpu")
225
+ else:
226
+ flag = 1
227
+ if flag == 1:
228
+ # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
229
+ bert_feature = None
230
+ else:
231
+ assert bert_feature.shape[-1] == len(phoneme_ids)
232
+ return {
233
+ "idx": idx,
234
+ "phoneme_ids": phoneme_ids,
235
+ "phoneme_ids_len": phoneme_ids_len,
236
+ "semantic_ids": semantic_ids,
237
+ "semantic_ids_len": semantic_ids_len,
238
+ "bert_feature": bert_feature,
239
+ }
240
+
241
+ def get_sample_length(self, idx: int):
242
+ semantic_ids = self.semantic_phoneme[idx][0]
243
+ sec = 1.0 * len(semantic_ids) / self.hz
244
+ return sec
245
+
246
+ def collate(self, examples: List[Dict]) -> Dict:
247
+ sample_index: List[int] = []
248
+ phoneme_ids: List[torch.Tensor] = []
249
+ phoneme_ids_lens: List[int] = []
250
+ semantic_ids: List[torch.Tensor] = []
251
+ semantic_ids_lens: List[int] = []
252
+ # return
253
+
254
+ for item in examples:
255
+ sample_index.append(item["idx"])
256
+ phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64))
257
+ semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64))
258
+ phoneme_ids_lens.append(item["phoneme_ids_len"])
259
+ semantic_ids_lens.append(item["semantic_ids_len"])
260
+
261
+ # pad 0
262
+ phoneme_ids = batch_sequences(phoneme_ids)
263
+ semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD)
264
+
265
+ # # convert each batch to torch.tensor
266
+ phoneme_ids = torch.tensor(phoneme_ids)
267
+ semantic_ids = torch.tensor(semantic_ids)
268
+ phoneme_ids_lens = torch.tensor(phoneme_ids_lens)
269
+ semantic_ids_lens = torch.tensor(semantic_ids_lens)
270
+ bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens))
271
+ bert_padded.zero_()
272
+
273
+ for idx, item in enumerate(examples):
274
+ bert = item["bert_feature"]
275
+ if bert != None:
276
+ bert_padded[idx, :, : bert.shape[-1]] = bert
277
+
278
+ return {
279
+ # List[int]
280
+ "ids": sample_index,
281
+ # torch.Tensor (B, max_phoneme_length)
282
+ "phoneme_ids": phoneme_ids,
283
+ # torch.Tensor (B)
284
+ "phoneme_ids_len": phoneme_ids_lens,
285
+ # torch.Tensor (B, max_semantic_ids_length)
286
+ "semantic_ids": semantic_ids,
287
+ # torch.Tensor (B)
288
+ "semantic_ids_len": semantic_ids_lens,
289
+ # torch.Tensor (B, 1024, max_phoneme_length)
290
+ "bert_feature": bert_padded,
291
+ }
292
+
293
+
294
+ if __name__ == "__main__":
295
+ root_dir = "/data/docker/liujing04/gpt-vits/prepare/dump_mix/"
296
+ dataset = Text2SemanticDataset(
297
+ phoneme_path=root_dir + "phoneme_train.npy",
298
+ semantic_path=root_dir + "semantic_train.tsv",
299
+ )
300
+
301
+ batch_size = 12
302
+ dataloader = DataLoader(
303
+ dataset,
304
+ batch_size=batch_size,
305
+ collate_fn=dataset.collate,
306
+ shuffle=False,
307
+ )
308
+ for i, batch in enumerate(dataloader):
309
+ if i % 1000 == 0:
310
+ print(i)
311
+ # if i == 0:
312
+ # print('batch["ids"]:', batch["ids"])
313
+ # print('batch["phoneme_ids"]:', batch["phoneme_ids"],
314
+ # batch["phoneme_ids"].shape)
315
+ # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"],
316
+ # batch["phoneme_ids_len"].shape)
317
+ # print('batch["semantic_ids"]:', batch["semantic_ids"],
318
+ # batch["semantic_ids"].shape)
319
+ # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"],
320
+ # batch["semantic_ids_len"].shape)
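batch_sequences is the padding helper behind Text2SemanticDataset.collate: every array in the list is right-padded along the chosen axis to the longest length and the results are stacked into one batch. A standalone check (not from the repo; it assumes GPT_SoVITS is on sys.path so the module's `text` import resolves), using the default semantic pad value of 1024:

```python
# Standalone check of the padding helper used by Text2SemanticDataset.collate.
import numpy as np

from AR.data.dataset import batch_sequences

phoneme_ids = [np.array([3, 7, 11], dtype=np.int64), np.array([5, 9], dtype=np.int64)]
semantic_ids = [np.array([101, 102, 103, 104], dtype=np.int64), np.array([55], dtype=np.int64)]

print(batch_sequences(phoneme_ids))                   # [[ 3  7 11] [ 5  9  0]]
print(batch_sequences(semantic_ids, pad_value=1024))  # [[101 102 103 104] [55 1024 1024 1024]]
```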
GPT_SoVITS/AR/models/__init__.py ADDED
File without changes
GPT_SoVITS/AR/models/t2s_lightning_module.py ADDED
@@ -0,0 +1,145 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import os
4
+ import sys
5
+
6
+ now_dir = os.getcwd()
7
+ sys.path.append(now_dir)
8
+ from typing import Dict
9
+
10
+ import torch
11
+ from pytorch_lightning import LightningModule
12
+
13
+ from AR.models.t2s_model import Text2SemanticDecoder
14
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
15
+ from AR.modules.optim import ScaledAdam
16
+
17
+
18
+ class Text2SemanticLightningModule(LightningModule):
19
+ def __init__(self, config, output_dir, is_train=True):
20
+ super().__init__()
21
+ self.config = config
22
+ self.top_k = 3
23
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
24
+ pretrained_s1 = config.get("pretrained_s1")
25
+ if pretrained_s1 and is_train:
26
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
27
+ print(
28
+ self.load_state_dict(
29
+ torch.load(
30
+ pretrained_s1,
31
+ map_location="cpu",
32
+ )["weight"],
33
+ )
34
+ )
35
+ if is_train:
36
+ self.automatic_optimization = False
37
+ self.save_hyperparameters()
38
+ self.eval_dir = output_dir / "eval"
39
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
40
+
41
+ def training_step(self, batch: Dict, batch_idx: int):
42
+ opt = self.optimizers()
43
+ scheduler = self.lr_schedulers()
44
+ forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
45
+ loss, acc = forward(
46
+ batch["phoneme_ids"],
47
+ batch["phoneme_ids_len"],
48
+ batch["semantic_ids"],
49
+ batch["semantic_ids_len"],
50
+ batch["bert_feature"],
51
+ )
52
+ self.manual_backward(loss)
53
+ if batch_idx > 0 and batch_idx % 4 == 0:
54
+ opt.step()
55
+ opt.zero_grad()
56
+ scheduler.step()
57
+
58
+ self.log(
59
+ "total_loss",
60
+ loss,
61
+ on_step=True,
62
+ on_epoch=True,
63
+ prog_bar=True,
64
+ sync_dist=True,
65
+ )
66
+ self.log(
67
+ "lr",
68
+ scheduler.get_last_lr()[0],
69
+ on_epoch=True,
70
+ prog_bar=True,
71
+ sync_dist=True,
72
+ )
73
+ self.log(
74
+ f"top_{self.top_k}_acc",
75
+ acc,
76
+ on_step=True,
77
+ on_epoch=True,
78
+ prog_bar=True,
79
+ sync_dist=True,
80
+ )
81
+
82
+ def validation_step(self, batch: Dict, batch_idx: int):
83
+ return
84
+
85
+ # # get loss
86
+ # loss, acc = self.model.forward(
87
+ # batch['phoneme_ids'], batch['phoneme_ids_len'],
88
+ # batch['semantic_ids'], batch['semantic_ids_len'],
89
+ # batch['bert_feature']
90
+ # )
91
+ #
92
+ # self.log(
93
+ # "val_total_loss",
94
+ # loss,
95
+ # on_step=True,
96
+ # on_epoch=True,
97
+ # prog_bar=True,
98
+ # sync_dist=True)
99
+ # self.log(
100
+ # f"val_top_{self.top_k}_acc",
101
+ # acc,
102
+ # on_step=True,
103
+ # on_epoch=True,
104
+ # prog_bar=True,
105
+ # sync_dist=True)
106
+ #
107
+ # # get infer output
108
+ # semantic_len = batch['semantic_ids'].size(1)
109
+ # prompt_len = min(int(semantic_len * 0.5), 150)
110
+ # prompt = batch['semantic_ids'][:, :prompt_len]
111
+ # pred_semantic = self.model.infer(batch['phoneme_ids'],
112
+ # batch['phoneme_ids_len'], prompt,
113
+ # batch['bert_feature']
114
+ # )
115
+ # save_name = f'semantic_toks_{batch_idx}.pt'
116
+ # save_path = os.path.join(self.eval_dir, save_name)
117
+ # torch.save(pred_semantic.detach().cpu(), save_path)
118
+
119
+ def configure_optimizers(self):
120
+ model_parameters = self.model.parameters()
121
+ parameters_names = []
122
+ parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
123
+ lm_opt = ScaledAdam(
124
+ model_parameters,
125
+ lr=0.01,
126
+ betas=(0.9, 0.95),
127
+ clipping_scale=2.0,
128
+ parameters_names=parameters_names,
129
+ show_dominant_parameters=False,
130
+ clipping_update_period=1000,
131
+ )
132
+
133
+ return {
134
+ "optimizer": lm_opt,
135
+ "lr_scheduler": {
136
+ "scheduler": WarmupCosineLRSchedule(
137
+ lm_opt,
138
+ init_lr=self.config["optimizer"]["lr_init"],
139
+ peak_lr=self.config["optimizer"]["lr"],
140
+ end_lr=self.config["optimizer"]["lr_end"],
141
+ warmup_steps=self.config["optimizer"]["warmup_steps"],
142
+ total_steps=self.config["optimizer"]["decay_steps"],
143
+ )
144
+ },
145
+ }
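Two things are worth noting in Text2SemanticLightningModule: training uses manual optimization with gradient accumulation (the optimizer and LR scheduler step only every 4th batch, so the effective batch size is 4x the dataloader's), and checkpoints are expected to keep the model weights under a "weight" key, as the pretrained_s1 loader above shows. A hedged inference-side loading sketch; the checkpoint path is a placeholder, and reading the model config from a "config" key is an assumption based on how released GPT-SoVITS GPT checkpoints are packaged:

```python
# Sketch only: loading a trained GPT (s1) checkpoint for inference.
# The path is a placeholder; "weight" matches the key read by the pretrained_s1
# loader above, and storing the model config under "config" is an assumption.
from pathlib import Path

import torch

from AR.models.t2s_lightning_module import Text2SemanticLightningModule

ckpt = torch.load("GPT_weights/my_model-e15.ckpt", map_location="cpu")
config = ckpt["config"]

t2s = Text2SemanticLightningModule(config, Path("logs"), is_train=False)
t2s.load_state_dict(ckpt["weight"])
t2s.eval()
```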
GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py ADDED
@@ -0,0 +1,110 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import os
4
+ import sys
5
+
6
+ now_dir = os.getcwd()
7
+ sys.path.append(now_dir)
8
+ from typing import Dict
9
+
10
+ import torch
11
+ from pytorch_lightning import LightningModule
12
+
13
+ from AR.models.t2s_model_onnx import Text2SemanticDecoder
14
+ from AR.modules.lr_schedulers import WarmupCosineLRSchedule
15
+ from AR.modules.optim import ScaledAdam
16
+
17
+
18
+ class Text2SemanticLightningModule(LightningModule):
19
+ def __init__(self, config, output_dir, is_train=True):
20
+ super().__init__()
21
+ self.config = config
22
+ self.top_k = 3
23
+ self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
24
+ pretrained_s1 = config.get("pretrained_s1")
25
+ if pretrained_s1 and is_train:
26
+ # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
27
+ print(
28
+ self.load_state_dict(
29
+ torch.load(
30
+ pretrained_s1,
31
+ map_location="cpu",
32
+ )["weight"],
33
+ ),
34
+ )
35
+ if is_train:
36
+ self.automatic_optimization = False
37
+ self.save_hyperparameters()
38
+ self.eval_dir = output_dir / "eval"
39
+ self.eval_dir.mkdir(parents=True, exist_ok=True)
40
+
41
+ def training_step(self, batch: Dict, batch_idx: int):
42
+ opt = self.optimizers()
43
+ scheduler = self.lr_schedulers()
44
+ loss, acc = self.model.forward(
45
+ batch["phoneme_ids"],
46
+ batch["phoneme_ids_len"],
47
+ batch["semantic_ids"],
48
+ batch["semantic_ids_len"],
49
+ batch["bert_feature"],
50
+ )
51
+ self.manual_backward(loss)
52
+ if batch_idx > 0 and batch_idx % 4 == 0:
53
+ opt.step()
54
+ opt.zero_grad()
55
+ scheduler.step()
56
+
57
+ self.log(
58
+ "total_loss",
59
+ loss,
60
+ on_step=True,
61
+ on_epoch=True,
62
+ prog_bar=True,
63
+ sync_dist=True,
64
+ )
65
+ self.log(
66
+ "lr",
67
+ scheduler.get_last_lr()[0],
68
+ on_epoch=True,
69
+ prog_bar=True,
70
+ sync_dist=True,
71
+ )
72
+ self.log(
73
+ f"top_{self.top_k}_acc",
74
+ acc,
75
+ on_step=True,
76
+ on_epoch=True,
77
+ prog_bar=True,
78
+ sync_dist=True,
79
+ )
80
+
81
+ def validation_step(self, batch: Dict, batch_idx: int):
82
+ return
83
+
84
+ def configure_optimizers(self):
85
+ model_parameters = self.model.parameters()
86
+ parameters_names = []
87
+ parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
88
+ lm_opt = ScaledAdam(
89
+ model_parameters,
90
+ lr=0.01,
91
+ betas=(0.9, 0.95),
92
+ clipping_scale=2.0,
93
+ parameters_names=parameters_names,
94
+ show_dominant_parameters=False,
95
+ clipping_update_period=1000,
96
+ )
97
+
98
+ return {
99
+ "optimizer": lm_opt,
100
+ "lr_scheduler": {
101
+ "scheduler": WarmupCosineLRSchedule(
102
+ lm_opt,
103
+ init_lr=self.config["optimizer"]["lr_init"],
104
+ peak_lr=self.config["optimizer"]["lr"],
105
+ end_lr=self.config["optimizer"]["lr_end"],
106
+ warmup_steps=self.config["optimizer"]["warmup_steps"],
107
+ total_steps=self.config["optimizer"]["decay_steps"],
108
+ )
109
+ },
110
+ }
GPT_SoVITS/AR/models/t2s_model.py ADDED
@@ -0,0 +1,935 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import math
4
+ from typing import List, Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torchmetrics.classification import MulticlassAccuracy
10
+ from tqdm import tqdm
11
+
12
+ from AR.models.utils import (
13
+ dpo_loss,
14
+ get_batch_logps,
15
+ make_pad_mask,
16
+ make_pad_mask_left,
17
+ make_reject_y,
18
+ sample,
19
+ topk_sampling,
20
+ )
21
+ from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
22
+ from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
23
+
24
+ default_config = {
25
+ "embedding_dim": 512,
26
+ "hidden_dim": 512,
27
+ "num_head": 8,
28
+ "num_layers": 12,
29
+ "num_codebook": 8,
30
+ "p_dropout": 0.0,
31
+ "vocab_size": 1024 + 1,
32
+ "phoneme_vocab_size": 512,
33
+ "EOS": 1024,
34
+ }
35
+
36
+
37
+ # @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
38
+ # Efficient implementation equivalent to the following:
39
+ def scaled_dot_product_attention(
40
+ query: torch.Tensor,
41
+ key: torch.Tensor,
42
+ value: torch.Tensor,
43
+ attn_mask: Optional[torch.Tensor] = None,
44
+ scale: Optional[torch.Tensor] = None,
45
+ ) -> torch.Tensor:
46
+ B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
47
+ if scale is None:
48
+ scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
49
+ else:
50
+ scale_factor = scale
51
+ attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
52
+
53
+ if attn_mask is not None:
54
+ if attn_mask.dtype == torch.bool:
55
+ attn_bias.masked_fill_(attn_mask, float("-inf"))
56
+ else:
57
+ attn_bias += attn_mask
58
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
59
+ attn_weight += attn_bias
60
+ attn_weight = torch.softmax(attn_weight, dim=-1)
61
+
62
+ if attn_mask is not None:
63
+ if attn_mask.dtype == torch.bool:
64
+ attn_weight.masked_fill_(attn_mask, 0)
65
+ else:
66
+ attn_mask[attn_mask != float("-inf")] = 0
67
+ attn_mask[attn_mask == float("-inf")] = 1
68
+ attn_weight.masked_fill_(attn_mask, 0)
69
+
70
+ return attn_weight @ value
71
+
72
+
73
+ @torch.jit.script
74
+ class T2SMLP:
75
+ def __init__(self, w1, b1, w2, b2):
76
+ self.w1 = w1
77
+ self.b1 = b1
78
+ self.w2 = w2
79
+ self.b2 = b2
80
+
81
+ def forward(self, x):
82
+ x = F.relu(F.linear(x, self.w1, self.b1))
83
+ x = F.linear(x, self.w2, self.b2)
84
+ return x
85
+
86
+
87
+ @torch.jit.script
88
+ class T2SBlock:
89
+ def __init__(
90
+ self,
91
+ num_heads,
92
+ hidden_dim: int,
93
+ mlp: T2SMLP,
94
+ qkv_w,
95
+ qkv_b,
96
+ out_w,
97
+ out_b,
98
+ norm_w1,
99
+ norm_b1,
100
+ norm_eps1,
101
+ norm_w2,
102
+ norm_b2,
103
+ norm_eps2,
104
+ ):
105
+ self.num_heads = num_heads
106
+ self.mlp = mlp
107
+ self.hidden_dim: int = hidden_dim
108
+ self.qkv_w = qkv_w
109
+ self.qkv_b = qkv_b
110
+ self.out_w = out_w
111
+ self.out_b = out_b
112
+ self.norm_w1 = norm_w1
113
+ self.norm_b1 = norm_b1
114
+ self.norm_eps1 = norm_eps1
115
+ self.norm_w2 = norm_w2
116
+ self.norm_b2 = norm_b2
117
+ self.norm_eps2 = norm_eps2
118
+
119
+ self.false = torch.tensor(False, dtype=torch.bool)
120
+
121
+ @torch.jit.ignore
122
+ def to_mask(
123
+ self,
124
+ x: torch.Tensor,
125
+ padding_mask: Optional[torch.Tensor],
126
+ ):
127
+ if padding_mask is None:
128
+ return x
129
+
130
+ if padding_mask.dtype == torch.bool:
131
+ return x.masked_fill(padding_mask, 0)
132
+ else:
133
+ return x * padding_mask
134
+
135
+ def process_prompt(
136
+ self,
137
+ x: torch.Tensor,
138
+ attn_mask: torch.Tensor,
139
+ padding_mask: Optional[torch.Tensor] = None,
140
+ torch_sdpa: bool = True,
141
+ ):
142
+ q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
143
+
144
+ batch_size = q.shape[0]
145
+ q_len = q.shape[1]
146
+ kv_len = k.shape[1]
147
+
148
+ q = self.to_mask(q, padding_mask)
149
+ k_cache = self.to_mask(k, padding_mask)
150
+ v_cache = self.to_mask(v, padding_mask)
151
+
152
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
153
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
154
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
155
+
156
+ if torch_sdpa:
157
+ attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
158
+ else:
159
+ attn = scaled_dot_product_attention(q, k, v, attn_mask)
160
+
161
+ attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
162
+ attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
163
+
164
+ x = x + attn
165
+ x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
166
+ x = x + self.mlp.forward(x)
167
+ x = F.layer_norm(
168
+ x,
169
+ [self.hidden_dim],
170
+ self.norm_w2,
171
+ self.norm_b2,
172
+ self.norm_eps2,
173
+ )
174
+ return x, k_cache, v_cache
175
+
176
+ def decode_next_token(
177
+ self,
178
+ x: torch.Tensor,
179
+ k_cache: torch.Tensor,
180
+ v_cache: torch.Tensor,
181
+ attn_mask: torch.Tensor = None,
182
+ torch_sdpa: bool = True,
183
+ ):
184
+ q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
185
+
186
+ k_cache = torch.cat([k_cache, k], dim=1)
187
+ v_cache = torch.cat([v_cache, v], dim=1)
188
+
189
+ batch_size = q.shape[0]
190
+ q_len = q.shape[1]
191
+ kv_len = k_cache.shape[1]
192
+
193
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
194
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
195
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
196
+
197
+ if torch_sdpa:
198
+ attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
199
+ else:
200
+ attn = scaled_dot_product_attention(q, k, v, attn_mask)
201
+
202
+ attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
203
+ attn = F.linear(attn, self.out_w, self.out_b)
204
+
205
+ x = x + attn
206
+ x = F.layer_norm(
207
+ x,
208
+ [self.hidden_dim],
209
+ self.norm_w1,
210
+ self.norm_b1,
211
+ self.norm_eps1,
212
+ )
213
+ x = x + self.mlp.forward(x)
214
+ x = F.layer_norm(
215
+ x,
216
+ [self.hidden_dim],
217
+ self.norm_w2,
218
+ self.norm_b2,
219
+ self.norm_eps2,
220
+ )
221
+ return x, k_cache, v_cache
222
+
223
+
224
+ @torch.jit.script
225
+ class T2STransformer:
226
+ def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
227
+ self.num_blocks: int = num_blocks
228
+ self.blocks = blocks
229
+
230
+ def process_prompt(
231
+ self,
232
+ x: torch.Tensor,
233
+ attn_mask: torch.Tensor,
234
+ padding_mask: Optional[torch.Tensor] = None,
235
+ torch_sdpa: bool = True,
236
+ ):
237
+ k_cache: List[torch.Tensor] = []
238
+ v_cache: List[torch.Tensor] = []
239
+ for i in range(self.num_blocks):
240
+ x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
241
+ k_cache.append(k_cache_)
242
+ v_cache.append(v_cache_)
243
+ return x, k_cache, v_cache
244
+
245
+ def decode_next_token(
246
+ self,
247
+ x: torch.Tensor,
248
+ k_cache: List[torch.Tensor],
249
+ v_cache: List[torch.Tensor],
250
+ attn_mask: torch.Tensor = None,
251
+ torch_sdpa: bool = True,
252
+ ):
253
+ for i in range(self.num_blocks):
254
+ x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
255
+ x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
256
+ )
257
+ return x, k_cache, v_cache
258
+
259
+
260
+ class Text2SemanticDecoder(nn.Module):
261
+ def __init__(self, config, norm_first=False, top_k=3):
262
+ super(Text2SemanticDecoder, self).__init__()
263
+ self.model_dim = config["model"]["hidden_dim"]
264
+ self.embedding_dim = config["model"]["embedding_dim"]
265
+ self.num_head = config["model"]["head"]
266
+ self.num_layers = config["model"]["n_layer"]
267
+ self.norm_first = norm_first
268
+ self.vocab_size = config["model"]["vocab_size"]
269
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
270
+ self.p_dropout = config["model"]["dropout"]
271
+ self.EOS = config["model"]["EOS"]
272
+ self.norm_first = norm_first
273
+ assert self.EOS == self.vocab_size - 1
274
+ # should be same as num of kmeans bin
275
+ # assert self.EOS == 1024
276
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
277
+ self.ar_text_embedding = TokenEmbedding(
278
+ self.embedding_dim,
279
+ self.phoneme_vocab_size,
280
+ self.p_dropout,
281
+ )
282
+ self.ar_text_position = SinePositionalEmbedding(
283
+ self.embedding_dim,
284
+ dropout=0.1,
285
+ scale=False,
286
+ alpha=True,
287
+ )
288
+ self.ar_audio_embedding = TokenEmbedding(
289
+ self.embedding_dim,
290
+ self.vocab_size,
291
+ self.p_dropout,
292
+ )
293
+ self.ar_audio_position = SinePositionalEmbedding(
294
+ self.embedding_dim,
295
+ dropout=0.1,
296
+ scale=False,
297
+ alpha=True,
298
+ )
299
+
300
+ self.h = TransformerEncoder(
301
+ TransformerEncoderLayer(
302
+ d_model=self.model_dim,
303
+ nhead=self.num_head,
304
+ dim_feedforward=self.model_dim * 4,
305
+ dropout=0.1,
306
+ batch_first=True,
307
+ norm_first=norm_first,
308
+ ),
309
+ num_layers=self.num_layers,
310
+ norm=LayerNorm(self.model_dim) if norm_first else None,
311
+ )
312
+
313
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
314
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
315
+
316
+ self.ar_accuracy_metric = MulticlassAccuracy(
317
+ self.vocab_size,
318
+ top_k=top_k,
319
+ average="micro",
320
+ multidim_average="global",
321
+ ignore_index=self.EOS,
322
+ )
323
+
324
+ blocks = []
325
+
326
+ for i in range(self.num_layers):
327
+ layer = self.h.layers[i]
328
+ t2smlp = T2SMLP(
329
+ layer.linear1.weight,
330
+ layer.linear1.bias,
331
+ layer.linear2.weight,
332
+ layer.linear2.bias,
333
+ )
334
+
335
+ block = T2SBlock(
336
+ self.num_head,
337
+ self.model_dim,
338
+ t2smlp,
339
+ layer.self_attn.in_proj_weight,
340
+ layer.self_attn.in_proj_bias,
341
+ layer.self_attn.out_proj.weight,
342
+ layer.self_attn.out_proj.bias,
343
+ layer.norm1.weight,
344
+ layer.norm1.bias,
345
+ layer.norm1.eps,
346
+ layer.norm2.weight,
347
+ layer.norm2.bias,
348
+ layer.norm2.eps,
349
+ )
350
+
351
+ blocks.append(block)
352
+
353
+ self.t2s_transformer = T2STransformer(self.num_layers, blocks)
354
+
355
+ def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
356
+ x = self.ar_text_embedding(x)
357
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
358
+ x = self.ar_text_position(x)
359
+ x_mask = make_pad_mask(x_lens)
360
+
361
+ y_mask = make_pad_mask(y_lens)
362
+ y_mask_int = y_mask.type(torch.int64)
363
+ codes = y.type(torch.int64) * (1 - y_mask_int)
364
+
365
+ # Training
366
+ # AR Decoder
367
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
368
+ x_len = x_lens.max()
369
+ y_len = y_lens.max()
370
+ y_emb = self.ar_audio_embedding(y)
371
+ y_pos = self.ar_audio_position(y_emb)
372
+
373
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
374
+
375
+ ar_xy_padding_mask = xy_padding_mask
376
+
377
+ x_attn_mask = F.pad(
378
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
379
+ (0, y_len),
380
+ value=True,
381
+ )
382
+ # x_attn_mask[:, x_len]=False
383
+ y_attn_mask = F.pad(
384
+ torch.triu(
385
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
386
+ diagonal=1,
387
+ ),
388
+ (x_len, 0),
389
+ value=False,
390
+ )
391
+
392
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
393
+ bsz, src_len = x.shape[0], x_len + y_len
394
+ _xy_padding_mask = (
395
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
396
+ .expand(-1, self.num_head, -1, -1)
397
+ .reshape(bsz * self.num_head, 1, src_len)
398
+ )
399
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
400
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
401
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
402
+ xy_attn_mask = new_attn_mask
403
+ # x 和完整的 y 一次性输入模型
404
+ xy_pos = torch.concat([x, y_pos], dim=1)
405
+
406
+ return xy_pos, xy_attn_mask, targets
407
+
408
+ def forward(self, x, x_lens, y, y_lens, bert_feature):
409
+ """
410
+ x: phoneme_ids
411
+ y: semantic_ids
412
+ """
413
+
414
+ reject_y, reject_y_lens = make_reject_y(y, y_lens)
415
+
416
+ xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
417
+
418
+ xy_dec, _ = self.h(
419
+ (xy_pos, None),
420
+ mask=xy_attn_mask,
421
+ )
422
+ x_len = x_lens.max()
423
+ logits = self.ar_predict_layer(xy_dec[:, x_len:])
424
+
425
+ ###### DPO #############
426
+ reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
427
+ x, x_lens, reject_y, reject_y_lens, bert_feature
428
+ )
429
+
430
+ reject_xy_dec, _ = self.h(
431
+ (reject_xy_pos, None),
432
+ mask=reject_xy_attn_mask,
433
+ )
434
+ x_len = x_lens.max()
435
+ reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
436
+
437
+ # loss
438
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction="sum"
439
+
440
+ loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
441
+ acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
442
+
443
+ A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
444
+ loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
445
+
446
+ loss = loss_1 + loss_2
447
+
448
+ return loss, acc
449
+
450
+ def forward_old(self, x, x_lens, y, y_lens, bert_feature):
451
+ """
452
+ x: phoneme_ids
453
+ y: semantic_ids
454
+ """
455
+ x = self.ar_text_embedding(x)
456
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
457
+ x = self.ar_text_position(x)
458
+ x_mask = make_pad_mask(x_lens)
459
+
460
+ y_mask = make_pad_mask(y_lens)
461
+ y_mask_int = y_mask.type(torch.int64)
462
+ codes = y.type(torch.int64) * (1 - y_mask_int)
463
+
464
+ # Training
465
+ # AR Decoder
466
+ y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
467
+ x_len = x_lens.max()
468
+ y_len = y_lens.max()
469
+ y_emb = self.ar_audio_embedding(y)
470
+ y_pos = self.ar_audio_position(y_emb)
471
+
472
+ xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
473
+ ar_xy_padding_mask = xy_padding_mask
474
+
475
+ x_attn_mask = F.pad(
476
+ torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
477
+ (0, y_len),
478
+ value=True,
479
+ )
480
+ y_attn_mask = F.pad(
481
+ torch.triu(
482
+ torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
483
+ diagonal=1,
484
+ ),
485
+ (x_len, 0),
486
+ value=False,
487
+ )
488
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
489
+ bsz, src_len = x.shape[0], x_len + y_len
490
+ _xy_padding_mask = (
491
+ ar_xy_padding_mask.view(bsz, 1, 1, src_len)
492
+ .expand(-1, self.num_head, -1, -1)
493
+ .reshape(bsz * self.num_head, 1, src_len)
494
+ )
495
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
496
+ new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
497
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
498
+ xy_attn_mask = new_attn_mask
499
+ # feed x and the complete y into the model in a single pass
500
+ xy_pos = torch.concat([x, y_pos], dim=1)
501
+ xy_dec, _ = self.h(
502
+ (xy_pos, None),
503
+ mask=xy_attn_mask,
504
+ )
505
+ logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
506
+ # loss
507
+ # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction="sum"
508
+ loss = F.cross_entropy(logits, targets, reduction="sum")
509
+ acc = self.ar_accuracy_metric(logits.detach(), targets).item()
510
+ return loss, acc
511
+
512
+ # TODO: clarify how this function differs from forward, and what to pass as prompts when no semantic tokens are available
513
+ def infer(
514
+ self,
515
+ x,
516
+ x_lens,
517
+ prompts,
518
+ bert_feature,
519
+ top_k: int = -100,
520
+ early_stop_num: int = -1,
521
+ temperature: float = 1.0,
522
+ ):
523
+ x = self.ar_text_embedding(x)
524
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
525
+ x = self.ar_text_position(x)
526
+
527
+ # AR Decoder
528
+ y = prompts
529
+ prefix_len = y.shape[1]
530
+ x_len = x.shape[1]
531
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
532
+ stop = False
533
+ for _ in tqdm(range(1500)):
534
+ y_emb = self.ar_audio_embedding(y)
535
+ y_pos = self.ar_audio_position(y_emb)
536
+ # feed x together with the progressively growing y into the model
537
+ xy_pos = torch.concat([x, y_pos], dim=1)
538
+ y_len = y.shape[1]
539
+ x_attn_mask_pad = F.pad(
540
+ x_attn_mask,
541
+ (0, y_len),
542
+ value=True,
543
+ )
544
+ y_attn_mask = F.pad(
545
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
546
+ (x_len, 0),
547
+ value=False,
548
+ )
549
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
550
+
551
+ xy_dec, _ = self.h(
552
+ (xy_pos, None),
553
+ mask=xy_attn_mask,
554
+ )
555
+ logits = self.ar_predict_layer(xy_dec[:, -1])
556
+ samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
557
+
558
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
559
+ print("use early stop num:", early_stop_num)
560
+ stop = True
561
+
562
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
563
+ # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
564
+ stop = True
565
+ if stop:
566
+ if prompts.shape[1] == y.shape[1]:
567
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
568
+ print("bad zero prediction")
569
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
570
+ break
571
+ # the semantic_ids generated in this step are appended to the previous y to form the new y
+ # print(samples.shape)  # [1, 1]; the first 1 is the batch size
573
+ # import os
574
+ # os._exit(2333)
575
+ y = torch.concat([y, samples], dim=1)
576
+ return y
577
+
578
+ def pad_y_eos(self, y, y_mask_int, eos_id):
579
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
580
+ # shift targets by one position relative to inputs
581
+ return targets[:, :-1], targets[:, 1:]
582
+
583
+ def infer_panel_batch_infer(
584
+ self,
585
+ x: List[torch.LongTensor],  ##### all text tokens
586
+ x_lens: torch.LongTensor,
587
+ prompts: torch.LongTensor,  #### reference audio tokens
588
+ bert_feature: List[torch.LongTensor],
589
+ top_k: int = -100,
590
+ top_p: int = 100,
591
+ early_stop_num: int = -1,
592
+ temperature: float = 1.0,
593
+ repetition_penalty: float = 1.35,
594
+ **kwargs,
595
+ ):
596
+ if prompts is None:
597
+ print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
598
+ return self.infer_panel_naive_batched(
599
+ x,
600
+ x_lens,
601
+ prompts,
602
+ bert_feature,
603
+ top_k=top_k,
604
+ top_p=top_p,
605
+ early_stop_num=early_stop_num,
606
+ temperature=temperature,
607
+ **kwargs,
608
+ )
609
+
610
+ max_len = kwargs.get("max_len", x_lens.max())
611
+ x_list = []
612
+ for x_item, bert_item in zip(x, bert_feature):
613
+ # max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
614
+ x_item = self.ar_text_embedding(x_item.unsqueeze(0))
615
+ x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
616
+ x_item = self.ar_text_position(x_item).squeeze(0)
617
+ # x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
618
+ x_item = (
619
+ F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
620
+ ) ### padding left
621
+ x_list.append(x_item)
622
+ x: torch.Tensor = torch.stack(x_list, dim=0)
623
+
624
+ # AR Decoder
625
+ y = prompts
626
+
627
+ x_len = x.shape[1]
628
+ stop = False
629
+
630
+ k_cache = None
631
+ v_cache = None
632
+ ################### first step ##########################
633
+ assert y is not None, "Error: prompt-free mode is not supported by batch_infer!"
634
+ ref_free = False
635
+
636
+ y_emb = self.ar_audio_embedding(y)
637
+ y_len = y_emb.shape[1]
638
+ prefix_len = y.shape[1]
639
+ y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
640
+ y_pos = self.ar_audio_position(y_emb)
641
+ xy_pos = torch.concat([x, y_pos], dim=1)
642
+
643
+ ##### create mask #####
644
+ bsz = x.shape[0]
645
+ src_len = x_len + y_len
646
+ y_padding_mask = make_pad_mask_left(y_lens, y_len)
+ x_padding_mask = make_pad_mask_left(x_lens, max_len)
+ 
+ # (bsz, x_len + y_len)
+ padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
651
+
652
+ x_mask = F.pad(
653
+ torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
654
+ (0, y_len),
655
+ value=True,
656
+ )
657
+
658
+ y_mask = F.pad(  ### extend the upper-right causal ones of the y-y block leftwards with zeros over the x-y block, shape (y, x+y)
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
+ (x_len, 0),
+ value=False,
+ )
+ 
+ causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
+ # padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
+ ### the line above is wrong: it would let padded tokens be "seen" by the attention
+ 
+ # the correct padding_mask looks like this:
+ # | pad_len | x_len | y_len |
+ # [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], the first 3 rows should in principle be masked too, but they are kept so that attention does not produce NaNs; this does not affect the result
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
+ 
+ padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
+ 
+ attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
+ attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
+ 
+ # the correct attn_mask looks like this:
+ # | pad_len | x_len | y_len |
+ # [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], the first 3 rows should in principle be masked too, but they are kept so that attention does not produce NaNs; this does not affect the result
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
+ # [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
+ # [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
696
+
697
+ ###### decode #####
698
+ y_list = [None] * y.shape[0]
699
+ batch_idx_map = list(range(y.shape[0]))
700
+ idx_list = [None] * y.shape[0]
701
+ for idx in tqdm(range(1500)):
702
+ if idx == 0:
703
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
704
+ else:
705
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
706
+ logits = self.ar_predict_layer(xy_dec[:, -1])
707
+
708
+ if idx == 0:
709
+ attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
710
+ logits = logits[:, :-1]
711
+ else:
712
+ attn_mask = F.pad(attn_mask, (0, 1), value=False)
713
+
714
+ samples = sample(
715
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
716
+ )[0]
717
+
718
+ y = torch.concat([y, samples], dim=1)
719
+
720
+ ####### remove sequences that have finished generating from the batch to further reduce computation
721
+ tokens = torch.argmax(logits, dim=-1)
722
+ reserved_idx_of_batch_for_y = None
723
+ if (self.EOS in samples[:, 0]) or (self.EOS in tokens):  ### stop once EOS is generated
724
+ l1 = samples[:, 0] == self.EOS
725
+ l2 = tokens == self.EOS
726
+ l = l1.logical_or(l2)
727
+ removed_idx_of_batch_for_y = torch.where(l)[0].tolist()
+ reserved_idx_of_batch_for_y = torch.where(~l)[0]
729
+ # batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
730
+ for i in removed_idx_of_batch_for_y:
731
+ batch_index = batch_idx_map[i]
732
+ idx_list[batch_index] = idx
733
+ y_list[batch_index] = y[i, :-1]
734
+
735
+ batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
736
+
737
+ # keep only the sequences in the batch that have not finished generating
738
+ if reserved_idx_of_batch_for_y is not None:
739
+ # index = torch.LongTensor(batch_idx_map).to(y.device)
740
+ y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
741
+ attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
742
+ if k_cache is not None:
743
+ for i in range(len(k_cache)):
744
+ k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
745
+ v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
746
+
747
+ if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
748
+ print("use early stop num:", early_stop_num)
749
+ stop = True
750
+ for i, batch_index in enumerate(batch_idx_map):
752
+ idx_list[batch_index] = idx
753
+ y_list[batch_index] = y[i, :-1]
754
+
755
+ if None not in idx_list:
756
+ stop = True
757
+
758
+ if stop:
759
+ if y.shape[1] == 0:
760
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
761
+ print("bad zero prediction")
762
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
763
+ break
764
+
765
+ ####################### update next step ###################################
766
+ y_emb = self.ar_audio_embedding(y[:, -1:])
767
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
768
+ :, y_len + idx
769
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
770
+
771
+ if None in idx_list:
772
+ for i in range(x.shape[0]):
773
+ if idx_list[i] is None:
774
+ idx_list[i] = 1500 - 1  ### if EOS was never generated, fall back to the maximum length
775
+
776
+ if ref_free:
777
+ return y_list, [0] * x.shape[0]
778
+ # print(idx_list)
779
+ return y_list, idx_list
780
+
781
+ def infer_panel_naive_batched(
782
+ self,
783
+ x: List[torch.LongTensor],  ##### all text tokens
784
+ x_lens: torch.LongTensor,
785
+ prompts: torch.LongTensor,  #### reference audio tokens
786
+ bert_feature: List[torch.LongTensor],
787
+ top_k: int = -100,
788
+ top_p: int = 100,
789
+ early_stop_num: int = -1,
790
+ temperature: float = 1.0,
791
+ repetition_penalty: float = 1.35,
792
+ **kwargs,
793
+ ):
794
+ y_list = []
795
+ idx_list = []
796
+ for i in range(len(x)):
797
+ y, idx = self.infer_panel_naive(
798
+ x[i].unsqueeze(0),
799
+ x_lens[i],
800
+ prompts[i].unsqueeze(0) if prompts is not None else None,
801
+ bert_feature[i].unsqueeze(0),
802
+ top_k,
803
+ top_p,
804
+ early_stop_num,
805
+ temperature,
806
+ repetition_penalty,
807
+ **kwargs,
808
+ )
809
+ y_list.append(y[0])
810
+ idx_list.append(idx)
811
+
812
+ return y_list, idx_list
813
+
814
+ def infer_panel_naive(
815
+ self,
816
+ x: torch.LongTensor,  ##### all text tokens
817
+ x_lens: torch.LongTensor,
818
+ prompts: torch.LongTensor,  #### reference audio tokens
819
+ bert_feature: torch.LongTensor,
820
+ top_k: int = -100,
821
+ top_p: int = 100,
822
+ early_stop_num: int = -1,
823
+ temperature: float = 1.0,
824
+ repetition_penalty: float = 1.35,
825
+ **kwargs,
826
+ ):
827
+ x = self.ar_text_embedding(x)
828
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
829
+ x = self.ar_text_position(x)
830
+
831
+ # AR Decoder
832
+ y = prompts
833
+
834
+ x_len = x.shape[1]
835
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
836
+ stop = False
837
+ # print(1111111,self.num_layers)
838
+
839
+ k_cache = None
840
+ v_cache = None
841
+ ################### first step ##########################
842
+ if y is not None:
843
+ y_emb = self.ar_audio_embedding(y)
844
+ y_len = y_emb.shape[1]
845
+ prefix_len = y.shape[1]
846
+ y_pos = self.ar_audio_position(y_emb)
847
+ xy_pos = torch.concat([x, y_pos], dim=1)
848
+ ref_free = False
849
+ else:
850
+ y_emb = None
851
+ y_len = 0
852
+ prefix_len = 0
853
+ y_pos = None
854
+ xy_pos = x
855
+ y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
856
+ ref_free = True
857
+
858
+ bsz = x.shape[0]
859
+ src_len = x_len + y_len
860
+ x_attn_mask_pad = F.pad(
861
+ x_attn_mask,
862
+ (0, y_len),  ### extend the all-zero x-x block with an all-ones x-y block, shape (x, x+y)
863
+ value=True,
864
+ )
865
+ y_attn_mask = F.pad(  ### extend the upper-right causal ones of the y-y block leftwards with zeros over the x-y block, shape (y, x+y)
866
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
867
+ (x_len, 0),
868
+ value=False,
869
+ )
870
+ xy_attn_mask = (
871
+ torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
872
+ .unsqueeze(0)
873
+ .expand(bsz * self.num_head, -1, -1)
874
+ .view(bsz, self.num_head, src_len, src_len)
875
+ .to(device=x.device, dtype=torch.bool)
876
+ )
877
+
878
+ for idx in tqdm(range(1500)):
879
+ if xy_attn_mask is not None:
880
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
881
+ else:
882
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
883
+
884
+ logits = self.ar_predict_layer(xy_dec[:, -1])
885
+
886
+ if idx == 0:
887
+ xy_attn_mask = None
888
+ if idx < 11:  ### require at least 10 predicted tokens (about 0.4 s) before allowing EOS
889
+ logits = logits[:, :-1]
890
+
891
+ samples = sample(
892
+ logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
893
+ )[0]
894
+
895
+ y = torch.concat([y, samples], dim=1)
896
+
897
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
898
+ print("use early stop num:", early_stop_num)
899
+ stop = True
900
+
901
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
902
+ stop = True
903
+ if stop:
904
+ if y.shape[1] == 0:
905
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
906
+ print("bad zero prediction")
907
+ print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
908
+ break
909
+
910
+ ####################### update next step ###################################
911
+ y_emb = self.ar_audio_embedding(y[:, -1:])
912
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
913
+ :, y_len + idx
914
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
915
+
916
+ if ref_free:
917
+ return y[:, :-1], 0
918
+ return y[:, :-1], idx
919
+
920
+ def infer_panel(
921
+ self,
922
+ x: torch.LongTensor,  ##### all text tokens
923
+ x_lens: torch.LongTensor,
924
+ prompts: torch.LongTensor,  #### reference audio tokens
925
+ bert_feature: torch.LongTensor,
926
+ top_k: int = -100,
927
+ top_p: int = 100,
928
+ early_stop_num: int = -1,
929
+ temperature: float = 1.0,
930
+ repetition_penalty: float = 1.35,
931
+ **kwargs,
932
+ ):
933
+ return self.infer_panel_naive(
934
+ x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
935
+ )
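The batched inference path above (infer_panel_batch_infer) left-pads each text sequence and then builds one attention mask by OR-ing a causal mask with a left-padding mask. Below is a minimal standalone sketch of that mask construction with toy lengths; it re-implements make_pad_mask_left inline (mirroring the helper in GPT_SoVITS/AR/models/utils.py) so it runs on its own, and is illustrative only, not part of the committed files.

import torch
import torch.nn.functional as F

def make_pad_mask_left(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # True marks left-padding positions (same convention as utils.make_pad_mask_left)
    seq = torch.arange(max_len).unsqueeze(0).repeat(lengths.size(0), 1)
    return (seq - (max_len - lengths).unsqueeze(-1)) < 0

bsz, x_len, y_len = 2, 3, 4                      # toy batch / text / prompt lengths
x_lens = torch.tensor([2, 3])                    # real text lengths (left-padded to x_len)
y_lens = torch.tensor([4, 4])                    # prompt lengths
src_len = x_len + y_len

# causal part: text attends only to text, audio step t attends to text plus audio <= t
x_mask = F.pad(torch.zeros(x_len, x_len, dtype=torch.bool), (0, y_len), value=True)
y_mask = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), (x_len, 0), value=False)
causal_mask = torch.cat([x_mask, y_mask], dim=0).unsqueeze(0).repeat(bsz, 1, 1)

# padding part: no query may attend to the left-pad columns
padding_mask = torch.cat([make_pad_mask_left(x_lens, x_len), make_pad_mask_left(y_lens, y_len)], dim=1)
attn_mask = causal_mask | padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
print(attn_mask.shape)                           # torch.Size([2, 7, 7]); True = masked out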
GPT_SoVITS/AR/models/t2s_model_onnx.py ADDED
@@ -0,0 +1,394 @@
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torchmetrics.classification import MulticlassAccuracy
7
+
8
+ from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
9
+ from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
10
+
11
+ default_config = {
12
+ "embedding_dim": 512,
13
+ "hidden_dim": 512,
14
+ "num_head": 8,
15
+ "num_layers": 12,
16
+ "num_codebook": 8,
17
+ "p_dropout": 0.0,
18
+ "vocab_size": 1024 + 1,
19
+ "phoneme_vocab_size": 512,
20
+ "EOS": 1024,
21
+ }
22
+
23
+ inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
24
+
25
+
26
+ def logits_to_probs(
27
+ logits,
28
+ previous_tokens=None,
29
+ temperature: float = 1.0,
30
+ top_k=None,
31
+ top_p=None,
32
+ repetition_penalty: float = 1.0,
33
+ ):
34
+ previous_tokens = previous_tokens.squeeze()
35
+ if previous_tokens is not None and repetition_penalty != 1.0:
36
+ previous_tokens = previous_tokens.long()
37
+ score = torch.gather(logits, dim=0, index=previous_tokens)
38
+ score = torch.where(
39
+ score < 0,
40
+ score * repetition_penalty,
41
+ score / repetition_penalty,
42
+ )
43
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
44
+
45
+ if top_p is not None and top_p < 1.0:
46
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
47
+ cum_probs = torch.cumsum(
48
+ torch.nn.functional.softmax(
49
+ sorted_logits,
50
+ dim=-1,
51
+ ),
52
+ dim=-1,
53
+ )
54
+ sorted_indices_to_remove = cum_probs > top_p
55
+ sorted_indices_to_remove[0] = False # keep at least one option
56
+ indices_to_remove = sorted_indices_to_remove.scatter(
57
+ dim=0,
58
+ index=sorted_indices,
59
+ src=sorted_indices_to_remove,
60
+ )
61
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
62
+
63
+ logits = logits / max(temperature, 1e-5)
64
+
65
+ if top_k is not None:
66
+ v, _ = torch.topk(logits, top_k)
67
+ pivot = v.select(-1, -1).unsqueeze(-1)
68
+ logits = torch.where(logits < pivot, inf_tensor_value, logits)
69
+
70
+ probs = torch.nn.functional.softmax(logits, dim=-1)
71
+ return probs
72
+
73
+
74
+ def multinomial_sample_one_no_sync(
75
+ probs_sort,
76
+ ): # Does multinomial sampling without a cuda synchronization
77
+ q = torch.randn_like(probs_sort)
78
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
79
+
80
+
81
+ def sample(
82
+ logits,
83
+ previous_tokens,
84
+ **sampling_kwargs,
85
+ ):
86
+ probs = logits_to_probs(
87
+ logits=logits,
88
+ previous_tokens=previous_tokens,
89
+ **sampling_kwargs,
90
+ )
91
+ idx_next = multinomial_sample_one_no_sync(probs)
92
+ return idx_next, probs
93
+
94
+
95
+ class OnnxEncoder(nn.Module):
96
+ def __init__(self, ar_text_embedding, bert_proj, ar_text_position):
97
+ super().__init__()
98
+ self.ar_text_embedding = ar_text_embedding
99
+ self.bert_proj = bert_proj
100
+ self.ar_text_position = ar_text_position
101
+
102
+ def forward(self, x, bert_feature):
103
+ x = self.ar_text_embedding(x)
104
+ x = x + self.bert_proj(bert_feature.transpose(1, 2))
105
+ return self.ar_text_position(x)
106
+
107
+
108
+ class T2SFirstStageDecoder(nn.Module):
109
+ def __init__(
110
+ self,
111
+ ar_audio_embedding,
112
+ ar_audio_position,
113
+ h,
114
+ ar_predict_layer,
115
+ loss_fct,
116
+ ar_accuracy_metric,
117
+ top_k,
118
+ early_stop_num,
119
+ num_layers,
120
+ ):
121
+ super().__init__()
122
+ self.ar_audio_embedding = ar_audio_embedding
123
+ self.ar_audio_position = ar_audio_position
124
+ self.h = h
125
+ self.ar_predict_layer = ar_predict_layer
126
+ self.loss_fct = loss_fct
127
+ self.ar_accuracy_metric = ar_accuracy_metric
128
+ self.top_k = top_k
129
+ self.early_stop_num = early_stop_num
130
+ self.num_layers = num_layers
131
+
132
+ def forward(self, x, prompt):
133
+ y = prompt
134
+ x_example = x[:, :, 0] * 0.0
135
+ # N, 1, 512
136
+ cache = {
137
+ "all_stage": self.num_layers,
138
+ "k": None,
139
+ "v": None,
140
+ "y_emb": None,
141
+ "first_infer": 1,
142
+ "stage": 0,
143
+ }
144
+
145
+ y_emb = self.ar_audio_embedding(y)
146
+
147
+ cache["y_emb"] = y_emb
148
+ y_pos = self.ar_audio_position(y_emb)
149
+
150
+ xy_pos = torch.concat([x, y_pos], dim=1)
151
+
152
+ y_example = y_pos[:, :, 0] * 0.0
153
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
154
+ y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
155
+ y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
156
+ torch.ones_like(
157
+ y_example.transpose(0, 1),
158
+ dtype=torch.int64,
159
+ ),
160
+ dim=0,
161
+ )
162
+ y_attn_mask = y_attn_mask > 0
163
+
164
+ x_y_pad = torch.matmul(x_example.transpose(0, 1), y_example).bool()
165
+ y_x_pad = torch.matmul(y_example.transpose(0, 1), x_example).bool()
166
+ x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
167
+ y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
168
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
169
+ cache["k"] = (
170
+ torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
171
+ .unsqueeze(1)
172
+ .repeat(self.num_layers, 1, 1, 1)
173
+ )
174
+ cache["v"] = (
175
+ torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
176
+ .unsqueeze(1)
177
+ .repeat(self.num_layers, 1, 1, 1)
178
+ )
179
+
180
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
181
+ logits = self.ar_predict_layer(xy_dec[:, -1])
182
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
183
+
184
+ y = torch.concat([y, samples], dim=1)
185
+
186
+ return y, cache["k"], cache["v"], cache["y_emb"], x_example
187
+
188
+
189
+ class T2SStageDecoder(nn.Module):
190
+ def __init__(
191
+ self,
192
+ ar_audio_embedding,
193
+ ar_audio_position,
194
+ h,
195
+ ar_predict_layer,
196
+ loss_fct,
197
+ ar_accuracy_metric,
198
+ top_k,
199
+ early_stop_num,
200
+ num_layers,
201
+ ):
202
+ super().__init__()
203
+ self.ar_audio_embedding = ar_audio_embedding
204
+ self.ar_audio_position = ar_audio_position
205
+ self.h = h
206
+ self.ar_predict_layer = ar_predict_layer
207
+ self.loss_fct = loss_fct
208
+ self.ar_accuracy_metric = ar_accuracy_metric
209
+ self.top_k = top_k
210
+ self.early_stop_num = early_stop_num
211
+ self.num_layers = num_layers
212
+
213
+ def forward(self, y, k, v, y_emb, x_example):
214
+ cache = {
215
+ "all_stage": self.num_layers,
216
+ "k": torch.nn.functional.pad(k, (0, 0, 0, 0, 0, 1)),
217
+ "v": torch.nn.functional.pad(v, (0, 0, 0, 0, 0, 1)),
218
+ "y_emb": y_emb,
219
+ "first_infer": 0,
220
+ "stage": 0,
221
+ }
222
+
223
+ y_emb = torch.cat(
224
+ [
225
+ cache["y_emb"],
226
+ self.ar_audio_embedding(y[:, -1:]),
227
+ ],
228
+ 1,
229
+ )
230
+ cache["y_emb"] = y_emb
231
+ y_pos = self.ar_audio_position(y_emb)
232
+
233
+ xy_pos = y_pos[:, -1:]
234
+
235
+ y_example = y_pos[:, :, 0] * 0.0
236
+
237
+ xy_attn_mask = torch.cat([x_example, y_example], dim=1)
238
+ xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
239
+
240
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
241
+ logits = self.ar_predict_layer(xy_dec[:, -1])
242
+ samples = sample(logits[0], y, top_k=self.top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
243
+
244
+ y = torch.concat([y, samples], dim=1)
245
+
246
+ return y, cache["k"], cache["v"], cache["y_emb"], logits, samples
247
+
248
+
249
+ class Text2SemanticDecoder(nn.Module):
250
+ def __init__(self, config, norm_first=False, top_k=3):
251
+ super(Text2SemanticDecoder, self).__init__()
252
+ self.model_dim = config["model"]["hidden_dim"]
253
+ self.embedding_dim = config["model"]["embedding_dim"]
254
+ self.num_head = config["model"]["head"]
255
+ self.num_layers = config["model"]["n_layer"]
256
+ self.norm_first = norm_first
257
+ self.vocab_size = config["model"]["vocab_size"]
258
+ self.phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
259
+ self.p_dropout = float(config["model"]["dropout"])
260
+ self.EOS = config["model"]["EOS"]
261
+ self.norm_first = norm_first
262
+ assert self.EOS == self.vocab_size - 1
263
+ self.bert_proj = nn.Linear(1024, self.embedding_dim)
264
+ self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
265
+ self.ar_text_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
266
+ self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
267
+ self.ar_audio_position = SinePositionalEmbedding(self.embedding_dim, dropout=0.1, scale=False, alpha=True)
268
+ self.h = TransformerEncoder(
269
+ TransformerEncoderLayer(
270
+ d_model=self.model_dim,
271
+ nhead=self.num_head,
272
+ dim_feedforward=self.model_dim * 4,
273
+ dropout=0.1,
274
+ batch_first=True,
275
+ norm_first=norm_first,
276
+ ),
277
+ num_layers=self.num_layers,
278
+ norm=LayerNorm(self.model_dim) if norm_first else None,
279
+ )
280
+ self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
281
+ self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
282
+ self.ar_accuracy_metric = MulticlassAccuracy(
283
+ self.vocab_size,
284
+ top_k=top_k,
285
+ average="micro",
286
+ multidim_average="global",
287
+ ignore_index=self.EOS,
288
+ )
289
+ self.top_k = torch.LongTensor([1])
290
+ self.early_stop_num = torch.LongTensor([-1])
291
+
292
+ def init_onnx(self):
293
+ self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
294
+ self.first_stage_decoder = T2SFirstStageDecoder(
295
+ self.ar_audio_embedding,
296
+ self.ar_audio_position,
297
+ self.h,
298
+ self.ar_predict_layer,
299
+ self.loss_fct,
300
+ self.ar_accuracy_metric,
301
+ self.top_k,
302
+ self.early_stop_num,
303
+ self.num_layers,
304
+ )
305
+ self.stage_decoder = T2SStageDecoder(
306
+ self.ar_audio_embedding,
307
+ self.ar_audio_position,
308
+ self.h,
309
+ self.ar_predict_layer,
310
+ self.loss_fct,
311
+ self.ar_accuracy_metric,
312
+ self.top_k,
313
+ self.early_stop_num,
314
+ self.num_layers,
315
+ )
316
+
317
+ def forward(self, x, prompts, bert_feature):
318
+ early_stop_num = self.early_stop_num
319
+ prefix_len = prompts.shape[1]
320
+
321
+ x = self.onnx_encoder(x, bert_feature)
322
+ y, k, v, y_emb, stage, x_example = self.first_stage_decoder(x, prompts)
323
+
324
+ stop = False
325
+ for idx in range(1, 1500):
326
+ enco = self.stage_decoder(y, k, v, y_emb, stage, x_example)
327
+ y, k, v, y_emb, stage, logits, samples = enco
328
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
329
+ stop = True
330
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
331
+ stop = True
332
+ if stop:
333
+ break
334
+ y[0, -1] = 0
335
+ return y, idx
336
+
337
+ def infer(self, x, prompts, bert_feature):
338
+ top_k = self.top_k
339
+ early_stop_num = self.early_stop_num
340
+
341
+ x = self.onnx_encoder(x, bert_feature)
342
+
343
+ y = prompts
344
+ prefix_len = y.shape[1]
345
+ x_len = x.shape[1]
346
+ x_example = x[:, :, 0] * 0.0
347
+ x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
348
+ x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
349
+
350
+ stop = False
351
+ cache = {
352
+ "all_stage": self.num_layers,
353
+ "k": [None] * self.num_layers,
354
+ "v": [None] * self.num_layers,
355
+ "y_emb": None,
356
+ "first_infer": 1,
357
+ "stage": 0,
358
+ }
359
+ for idx in range(1500):
360
+ if cache["first_infer"] == 1:
361
+ y_emb = self.ar_audio_embedding(y)
362
+ else:
363
+ y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
364
+ cache["y_emb"] = y_emb
365
+ y_pos = self.ar_audio_position(y_emb)
366
+ if cache["first_infer"] == 1:
367
+ xy_pos = torch.concat([x, y_pos], dim=1)
368
+ else:
369
+ xy_pos = y_pos[:, -1:]
370
+ y_len = y_pos.shape[1]
371
+ if cache["first_infer"] == 1:
372
+ x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
373
+ y_attn_mask = F.pad(
374
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
375
+ (x_len, 0),
376
+ value=False,
377
+ )
378
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
379
+ else:
380
+ xy_attn_mask = torch.zeros((1, x_len + y_len), dtype=torch.bool)
381
+ xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
382
+ logits = self.ar_predict_layer(xy_dec[:, -1])
383
+ samples = sample(logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35)[0].unsqueeze(0)
384
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
385
+ stop = True
386
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
387
+ stop = True
388
+ if stop:
389
+ if prompts.shape[1] == y.shape[1]:
390
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
391
+ break
392
+ y = torch.concat([y, samples], dim=1)
393
+ cache["first_infer"] = 0
394
+ return y, idx
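The reason the ONNX variant is split into OnnxEncoder, T2SFirstStageDecoder and T2SStageDecoder is that each decoding step then becomes a pure function of explicit tensors (y, k, v, y_emb, x_example), which is what tracing-based export needs. The sketch below shows how the encoder part might be exported; it is a hedged illustration only, and the dummy shapes, input/output names and opset version are assumptions rather than the repository's actual export flow.

import torch

def export_encoder(onnx_encoder: torch.nn.Module, onnx_path: str = "t2s_encoder.onnx") -> None:
    # dummy inputs: 20 phoneme ids and their 1024-dim BERT features (shapes are illustrative)
    phoneme_ids = torch.randint(0, 100, (1, 20), dtype=torch.long)
    bert_feature = torch.randn(1, 1024, 20)
    torch.onnx.export(
        onnx_encoder,
        (phoneme_ids, bert_feature),
        onnx_path,
        input_names=["phoneme_ids", "bert_feature"],
        output_names=["x"],
        dynamic_axes={"phoneme_ids": {1: "text_len"}, "bert_feature": {2: "text_len"}},
        opset_version=16,
    )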
GPT_SoVITS/AR/models/utils.py ADDED
@@ -0,0 +1,282 @@
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def sequence_mask(length, max_length=None):
10
+ if max_length is None:
11
+ max_length = length.max()
12
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
13
+ return x.unsqueeze(0) < length.unsqueeze(1)
14
+
15
+
16
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
17
+ """
18
+ Args:
19
+ lengths:
20
+ A 1-D tensor containing sentence lengths.
21
+ max_len:
22
+ The length of masks.
23
+ Returns:
24
+ Return a 2-D bool tensor, where masked positions
25
+ are filled with `True` and non-masked positions are
26
+ filled with `False`.
27
+
28
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
29
+ #>>> make_pad_mask(lengths)
30
+ tensor([[False, True, True, True, True],
31
+ [False, False, False, True, True],
32
+ [False, False, True, True, True],
33
+ [False, False, False, False, False]])
34
+ """
35
+ assert lengths.ndim == 1, lengths.ndim
36
+ max_len = max(max_len, lengths.max())
37
+ n = lengths.size(0)
38
+ seq_range = torch.arange(0, max_len, device=lengths.device)
39
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+ 
+ return expanded_lengths >= lengths.unsqueeze(-1)
42
+
43
+
44
+ def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
45
+ """
46
+ Args:
47
+ lengths:
48
+ A 1-D tensor containing sentence lengths.
49
+ max_len:
50
+ The length of masks.
51
+ Returns:
52
+ Return a 2-D bool tensor, where masked positions
53
+ are filled with `True` and non-masked positions are
54
+ filled with `False`.
55
+
56
+ #>>> lengths = torch.tensor([1, 3, 2, 5])
57
+ #>>> make_pad_mask(lengths)
58
+ tensor(
59
+ [
60
+ [True, True, False],
61
+ [True, False, False],
62
+ [True, True, False],
63
+ ...
64
+ ]
65
+ )
66
+ """
67
+ assert lengths.ndim == 1, lengths.ndim
68
+ max_len = max(max_len, lengths.max())
69
+ n = lengths.size(0)
70
+ seq_range = torch.arange(0, max_len, device=lengths.device)
71
+ expanded_lengths = seq_range.unsqueeze(0).repeat(n, 1)
+ expanded_lengths -= (max_len - lengths).unsqueeze(-1)
+ 
+ return expanded_lengths < 0
75
+
76
+
77
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
78
+ def top_k_top_p_filtering(
79
+ logits,
80
+ top_k=0,
81
+ top_p=1.0,
82
+ filter_value=-float("Inf"),
83
+ min_tokens_to_keep=1,
84
+ ):
85
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
86
+ Args:
87
+ logits: logits distribution shape (batch size, vocabulary size)
88
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
89
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
90
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
91
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
92
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
93
+ """
94
+ if top_k > 0:
95
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
96
+ # Remove all tokens with a probability less than the last token of the top-k
97
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
98
+ logits[indices_to_remove] = filter_value
99
+
100
+ if top_p < 1.0:
101
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
102
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
103
+
104
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
105
+ sorted_indices_to_remove = cumulative_probs > top_p
106
+ if min_tokens_to_keep > 1:
107
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
108
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
109
+ # Shift the indices to the right to keep also the first token above the threshold
110
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
111
+ sorted_indices_to_remove[..., 0] = 0
112
+
113
+ # scatter sorted tensors to original indexing
114
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
115
+ logits[indices_to_remove] = filter_value
116
+ return logits
117
+
118
+
119
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
120
+ # temperature: (`optional`) float
121
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
122
+ # top_k: (`optional`) int
123
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
124
+ # top_p: (`optional`) float
125
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
126
+
127
+ # Temperature (higher temperature => more likely to sample low probability tokens)
128
+ if temperature != 1.0:
129
+ logits = logits / temperature
130
+ # Top-p/top-k filtering
131
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
132
+ # Sample
133
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
134
+ return token
135
+
136
+
137
+ from typing import Optional
138
+
139
+
140
+ def multinomial_sample_one_no_sync(
141
+ probs_sort,
142
+ ): # Does multinomial sampling without a cuda synchronization
143
+ q = torch.empty_like(probs_sort).exponential_(1)
144
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
145
+
146
+
147
+ def logits_to_probs(
148
+ logits,
149
+ previous_tokens: Optional[torch.Tensor] = None,
150
+ temperature: float = 1.0,
151
+ top_k: Optional[int] = None,
152
+ top_p: Optional[int] = None,
153
+ repetition_penalty: float = 1.0,
154
+ ):
155
+ # if previous_tokens is not None:
156
+ # previous_tokens = previous_tokens.squeeze()
157
+ # print(logits.shape,previous_tokens.shape)
158
+ # pdb.set_trace()
159
+ if previous_tokens is not None and repetition_penalty != 1.0:
160
+ previous_tokens = previous_tokens.long()
161
+ score = torch.gather(logits, dim=1, index=previous_tokens)
162
+ score = torch.where(
163
+ score < 0,
164
+ score * repetition_penalty,
165
+ score / repetition_penalty,
166
+ )
167
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
168
+
169
+ if top_p is not None and top_p < 1.0:
170
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
171
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
172
+ sorted_indices_to_remove = cum_probs > top_p
173
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
174
+ indices_to_remove = sorted_indices_to_remove.scatter(
175
+ dim=1,
176
+ index=sorted_indices,
177
+ src=sorted_indices_to_remove,
178
+ )
179
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
180
+
181
+ logits = logits / max(temperature, 1e-5)
182
+
183
+ if top_k is not None:
184
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
185
+ pivot = v[:, -1].unsqueeze(-1)
186
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
187
+
188
+ probs = torch.nn.functional.softmax(logits, dim=-1)
189
+ return probs
190
+
191
+
192
+ def sample(
193
+ logits,
194
+ previous_tokens: Optional[torch.Tensor] = None,
195
+ **sampling_kwargs,
196
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
197
+ probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
198
+ idx_next = multinomial_sample_one_no_sync(probs)
199
+ return idx_next, probs
200
+
201
+
202
+ def dpo_loss(
203
+ policy_chosen_logps: torch.FloatTensor,
204
+ policy_rejected_logps: torch.FloatTensor,
205
+ reference_chosen_logps: torch.FloatTensor,
206
+ reference_rejected_logps: torch.FloatTensor,
207
+ beta: float,
208
+ reference_free: bool = False,
209
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
210
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
211
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
212
+
213
+ if reference_free:
214
+ ref_logratios = 0
215
+
216
+ logits = pi_logratios - ref_logratios
217
+
218
+ losses = -F.logsigmoid(beta * logits)
219
+ chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
220
+ rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
221
+
222
+ return losses.mean(), chosen_rewards, rejected_rewards
223
+
224
+
225
+ def get_batch_logps(
226
+ logits_target: torch.FloatTensor,
227
+ logits_reject: torch.FloatTensor,
228
+ labels_target: torch.LongTensor,
229
+ labels_reject: torch.LongTensor,
230
+ average_log_prob: bool = False,
231
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
232
+ # dummy token; we'll ignore the losses on these tokens later
233
+
234
+ per_token_logps_target = torch.gather(
235
+ logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
236
+ ).squeeze(2)
237
+ per_token_logps_reject = torch.gather(
238
+ logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
239
+ ).squeeze(2)
240
+
241
+ return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
242
+
243
+
244
+ def make_reject_y(y_o, y_lens):
245
+ def repeat_P(y):
246
+ range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
247
+ pre = y[: range_idx[0]]
248
+ shf = y[range_idx[1] :]
249
+ range_text = y[range_idx[0] : range_idx[1]]
250
+ new_y = torch.cat([pre, range_text, range_text, shf])
251
+ return new_y
252
+
253
+ def lost_P(y):
254
+ range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
255
+ pre = y[: range_idx[0]]
256
+ shf = y[range_idx[1] :]
257
+ range_text = y[range_idx[0] : range_idx[1]]
258
+ new_y = torch.cat([pre, shf])
259
+ return new_y
260
+
261
+ bs = len(y_lens)
262
+ reject_y = []
263
+ reject_y_lens = []
264
+ for b in range(bs):
265
+ process_item_idx = torch.randint(0, 1, size=(1,))[0]
266
+ if process_item_idx == 0:
267
+ new_y = repeat_P(y_o[b])
268
+ reject_y.append(new_y)
269
+ reject_y_lens.append(len(new_y))
270
+ elif process_item_idx == 1:
271
+ new_y = lost_P(y_o[b])
272
+ reject_y.append(new_y)
273
+ reject_y_lens.append(len(new_y))
274
+ max_length = max(reject_y_lens)
275
+ for b in range(bs):
276
+ pad_length = max_length - reject_y_lens[b]
277
+ reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
278
+
279
+ reject_y = torch.stack(reject_y, dim=0)
280
+ reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
281
+
282
+ return reject_y, reject_y_lens
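For orientation, the sampling helpers above compose as follows during a single decoding step. This is a minimal sketch with random logits (vocabulary 1025 = 1024 semantic codes + EOS, matching the default config); the import path is an assumption based on the repository layout and is not part of the committed file.

import torch

from AR.models.utils import sample  # assumes GPT_SoVITS/ is on PYTHONPATH

torch.manual_seed(0)
vocab_size = 1025                                  # 1024 semantic codes + EOS
logits = torch.randn(1, vocab_size)                # one decoding step, batch size 1
previous_tokens = torch.randint(0, 1024, (1, 8))   # already generated semantic ids

next_token, probs = sample(
    logits,
    previous_tokens,
    top_k=15,
    top_p=1.0,
    repetition_penalty=1.35,
    temperature=1.0,
)
print(next_token.shape, probs.shape)               # torch.Size([1, 1]) torch.Size([1, 1025])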
GPT_SoVITS/AR/modules/__init__.py ADDED
File without changes
GPT_SoVITS/AR/modules/activation.py ADDED
@@ -0,0 +1,413 @@
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+
12
+ from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
13
+
14
+ F.multi_head_attention_forward = multi_head_attention_forward_patched
15
+
16
+
17
+ class MultiheadAttention(Module):
18
+ r"""Allows the model to jointly attend to information
19
+ from different representation subspaces as described in the paper:
20
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
21
+
22
+ Multi-Head Attention is defined as:
23
+
24
+ .. math::
25
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
26
+
27
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
28
+
29
+ ``forward()`` will use a special optimized implementation if all of the following
30
+ conditions are met:
31
+
32
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
33
+ restriction will be loosened in the future.)
34
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
35
+ - training is disabled (using ``.eval()``)
36
+ - dropout is 0
37
+ - ``add_bias_kv`` is ``False``
38
+ - ``add_zero_attn`` is ``False``
39
+ - ``batch_first`` is ``True`` and the input is batched
40
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
41
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
42
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
43
+ nor ``attn_mask`` is passed
44
+
45
+ If the optimized implementation is in use, a
46
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
47
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
48
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
49
+ will be returned, and an additional speedup proportional to the fraction of the input
50
+ that is padding can be expected.
51
+
52
+ Args:
53
+ embed_dim: Total dimension of the model.
54
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
55
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
56
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
57
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
58
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
59
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
60
+ Default: ``False``.
61
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
62
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
63
+ batch_first: If ``True``, then the input and output tensors are provided
64
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
65
+
66
+ Examples::
67
+
68
+ >>> # xdoctest: +SKIP
69
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
70
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
71
+
72
+ """
73
+
74
+ __constants__ = ["batch_first"]
75
+ bias_k: Optional[torch.Tensor]
76
+ bias_v: Optional[torch.Tensor]
77
+
78
+ def __init__(
79
+ self,
80
+ embed_dim,
81
+ num_heads,
82
+ dropout=0.0,
83
+ bias=True,
84
+ add_bias_kv=False,
85
+ add_zero_attn=False,
86
+ kdim=None,
87
+ vdim=None,
88
+ batch_first=False,
89
+ linear1_cls=Linear,
90
+ linear2_cls=Linear,
91
+ device=None,
92
+ dtype=None,
93
+ ) -> None:
94
+ factory_kwargs = {"device": device, "dtype": dtype}
95
+ super(MultiheadAttention, self).__init__()
96
+ self.embed_dim = embed_dim
97
+ self.kdim = kdim if kdim is not None else embed_dim
98
+ self.vdim = vdim if vdim is not None else embed_dim
99
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
100
+
101
+ self.num_heads = num_heads
102
+ self.dropout = dropout
103
+ self.batch_first = batch_first
104
+ self.head_dim = embed_dim // num_heads
105
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
106
+
107
+ if add_bias_kv:
108
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
109
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
110
+ else:
111
+ self.bias_k = self.bias_v = None
112
+
113
+ if linear1_cls == Linear:
114
+ if not self._qkv_same_embed_dim:
115
+ self.q_proj_weight = Parameter(
116
+ torch.empty((embed_dim, embed_dim), **factory_kwargs),
117
+ )
118
+ self.k_proj_weight = Parameter(
119
+ torch.empty((embed_dim, self.kdim), **factory_kwargs),
120
+ )
121
+ self.v_proj_weight = Parameter(
122
+ torch.empty((embed_dim, self.vdim), **factory_kwargs),
123
+ )
124
+ self.register_parameter("in_proj_weight", None)
125
+ else:
126
+ self.in_proj_weight = Parameter(
127
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
128
+ )
129
+ self.register_parameter("q_proj_weight", None)
130
+ self.register_parameter("k_proj_weight", None)
131
+ self.register_parameter("v_proj_weight", None)
132
+
133
+ if bias:
134
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
135
+ else:
136
+ self.register_parameter("in_proj_bias", None)
137
+ self.out_proj = NonDynamicallyQuantizableLinear(
138
+ embed_dim,
139
+ embed_dim,
140
+ bias=bias,
141
+ **factory_kwargs,
142
+ )
143
+
144
+ self._reset_parameters()
145
+ else:
146
+ if not self._qkv_same_embed_dim:
147
+ raise NotImplementedError
148
+ else:
149
+ self.in_proj_linear = linear1_cls(
150
+ embed_dim,
151
+ 3 * embed_dim,
152
+ bias=bias,
153
+ **factory_kwargs,
154
+ )
155
+ self.in_proj_weight = self.in_proj_linear.weight
156
+
157
+ self.register_parameter("q_proj_weight", None)
158
+ self.register_parameter("k_proj_weight", None)
159
+ self.register_parameter("v_proj_weight", None)
160
+
161
+ if bias:
162
+ self.in_proj_bias = self.in_proj_linear.bias
163
+ else:
164
+ self.register_parameter("in_proj_bias", None)
165
+
166
+ self.out_proj = linear2_cls(
167
+ embed_dim,
168
+ embed_dim,
169
+ bias=bias,
170
+ **factory_kwargs,
171
+ )
172
+
173
+ if self.bias_k is not None:
174
+ xavier_normal_(self.bias_k)
175
+ if self.bias_v is not None:
176
+ xavier_normal_(self.bias_v)
177
+
178
+ self.add_zero_attn = add_zero_attn
179
+
180
+ def _reset_parameters(self):
181
+ if self._qkv_same_embed_dim:
182
+ xavier_uniform_(self.in_proj_weight)
183
+ else:
184
+ xavier_uniform_(self.q_proj_weight)
185
+ xavier_uniform_(self.k_proj_weight)
186
+ xavier_uniform_(self.v_proj_weight)
187
+
188
+ if self.in_proj_bias is not None:
189
+ constant_(self.in_proj_bias, 0.0)
190
+ constant_(self.out_proj.bias, 0.0)
191
+
192
+ if self.bias_k is not None:
193
+ xavier_normal_(self.bias_k)
194
+ if self.bias_v is not None:
195
+ xavier_normal_(self.bias_v)
196
+
197
+ def __setstate__(self, state):
198
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
199
+ if "_qkv_same_embed_dim" not in state:
200
+ state["_qkv_same_embed_dim"] = True
201
+
202
+ super(MultiheadAttention, self).__setstate__(state)
203
+
204
+ def forward(
205
+ self,
206
+ query: Tensor,
207
+ key: Tensor,
208
+ value: Tensor,
209
+ key_padding_mask: Optional[Tensor] = None,
210
+ need_weights: bool = True,
211
+ attn_mask: Optional[Tensor] = None,
212
+ average_attn_weights: bool = True,
213
+ cache=None,
214
+ ) -> Tuple[Tensor, Optional[Tensor]]:
215
+ r"""
216
+ Args:
217
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
218
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
219
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
220
+ Queries are compared against key-value pairs to produce the output.
221
+ See "Attention Is All You Need" for more details.
222
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
223
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
224
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
225
+ See "Attention Is All You Need" for more details.
226
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
227
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
228
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
229
+ See "Attention Is All You Need" for more details.
230
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
231
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
232
+ Binary and byte masks are supported.
233
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
234
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
235
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
236
+ Default: ``True``.
237
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
238
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
239
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
240
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
241
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
242
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
243
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
244
+ the attention weight.
245
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
246
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
247
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
248
+
249
+ Outputs:
250
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
251
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
252
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
253
+ embedding dimension ``embed_dim``.
254
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
255
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
256
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
257
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
258
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
259
+
260
+ .. note::
261
+ `batch_first` argument is ignored for unbatched inputs.
262
+ """
263
+ is_batched = query.dim() == 3
264
+ if key_padding_mask is not None:
265
+ _kpm_dtype = key_padding_mask.dtype
266
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
267
+ key_padding_mask,
268
+ ):
269
+ raise AssertionError("only bool and floating types of key_padding_mask are supported")
270
+ why_not_fast_path = ""
271
+ if not is_batched:
272
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
273
+ elif query is not key or key is not value:
274
+ # When lifting this restriction, don't forget to either
275
+ # enforce that the dtypes all match or test cases where
276
+ # they don't!
277
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
278
+ elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
279
+ why_not_fast_path = (
280
+ f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
281
+ )
282
+ elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
283
+ # this case will fail anyway, but at least they'll get a useful error message.
284
+ why_not_fast_path = (
285
+ f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
286
+ )
287
+ elif self.training:
288
+ why_not_fast_path = "training is enabled"
289
+ elif not self.batch_first:
290
+ why_not_fast_path = "batch_first was not True"
291
+ elif self.bias_k is not None:
292
+ why_not_fast_path = "self.bias_k was not None"
293
+ elif self.bias_v is not None:
294
+ why_not_fast_path = "self.bias_v was not None"
295
+ elif self.dropout:
296
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
297
+ elif self.add_zero_attn:
298
+ why_not_fast_path = "add_zero_attn was enabled"
299
+ elif not self._qkv_same_embed_dim:
300
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
301
+ elif attn_mask is not None:
302
+ why_not_fast_path = "attn_mask was not None"
303
+ elif query.is_nested and key_padding_mask is not None:
304
+ why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
305
+ elif self.num_heads % 2 == 1:
306
+ why_not_fast_path = "num_heads is odd"
307
+ elif torch.is_autocast_enabled():
308
+ why_not_fast_path = "autocast is enabled"
309
+
310
+ if not why_not_fast_path:
311
+ tensor_args = (
312
+ query,
313
+ key,
314
+ value,
315
+ self.in_proj_weight,
316
+ self.in_proj_bias,
317
+ self.out_proj.weight,
318
+ self.out_proj.bias,
319
+ )
320
+ # We have to use list comprehensions below because TorchScript does not support
321
+ # generator expressions.
322
+ if torch.overrides.has_torch_function(tensor_args):
323
+ why_not_fast_path = "some Tensor argument has_torch_function"
324
+ elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
325
+ why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
326
+ elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
327
+ why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
328
+ if not why_not_fast_path:
329
+ return torch._native_multi_head_attention(
330
+ query,
331
+ key,
332
+ value,
333
+ self.embed_dim,
334
+ self.num_heads,
335
+ self.in_proj_weight,
336
+ self.in_proj_bias,
337
+ self.out_proj.weight,
338
+ self.out_proj.bias,
339
+ key_padding_mask if key_padding_mask is not None else attn_mask,
340
+ need_weights,
341
+ average_attn_weights,
342
+ 1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
343
+ )
344
+
345
+ any_nested = query.is_nested or key.is_nested or value.is_nested
346
+ assert not any_nested, (
347
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
348
+ + f"The fast path was not hit because {why_not_fast_path}"
349
+ )
350
+
351
+ if self.batch_first and is_batched:
352
+ # make sure that the transpose op does not affect the "is" property
353
+ if key is value:
354
+ if query is key:
355
+ query = key = value = query.transpose(1, 0)
356
+ else:
357
+ query, key = [x.transpose(1, 0) for x in (query, key)]
358
+ value = key
359
+ else:
360
+ query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
361
+
362
+ if not self._qkv_same_embed_dim:
363
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
364
+ query,
365
+ key,
366
+ value,
367
+ self.embed_dim,
368
+ self.num_heads,
369
+ self.in_proj_weight,
370
+ self.in_proj_bias,
371
+ self.bias_k,
372
+ self.bias_v,
373
+ self.add_zero_attn,
374
+ self.dropout,
375
+ self.out_proj.weight,
376
+ self.out_proj.bias,
377
+ training=self.training,
378
+ key_padding_mask=key_padding_mask,
379
+ need_weights=need_weights,
380
+ attn_mask=attn_mask,
381
+ use_separate_proj_weight=True,
382
+ q_proj_weight=self.q_proj_weight,
383
+ k_proj_weight=self.k_proj_weight,
384
+ v_proj_weight=self.v_proj_weight,
385
+ average_attn_weights=average_attn_weights,
386
+ cache=cache,
387
+ )
388
+ else:
389
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
390
+ query,
391
+ key,
392
+ value,
393
+ self.embed_dim,
394
+ self.num_heads,
395
+ self.in_proj_weight,
396
+ self.in_proj_bias,
397
+ self.bias_k,
398
+ self.bias_v,
399
+ self.add_zero_attn,
400
+ self.dropout,
401
+ self.out_proj.weight,
402
+ self.out_proj.bias,
403
+ training=self.training,
404
+ key_padding_mask=key_padding_mask,
405
+ need_weights=need_weights,
406
+ attn_mask=attn_mask,
407
+ average_attn_weights=average_attn_weights,
408
+ cache=cache,
409
+ )
410
+ if self.batch_first and is_batched:
411
+ return attn_output.transpose(1, 0), attn_output_weights
412
+ else:
413
+ return attn_output, attn_output_weights
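The `cache` argument threaded through both branches above is a per-layer key/value cache consumed by `multi_head_attention_forward_patched` (see GPT_SoVITS/AR/modules/patched_mha_with_cache.py later in this diff). Below is a minimal sketch of the layout that code indexes; the field names come from that file, while the layer count and the exact way the T2S decoder builds the dict are assumptions here:

    num_layers = 12                      # assumed; must equal the decoder's attention layer count
    cache = {
        "all_stage": num_layers,         # number of attention layers sharing this dict
        "k": [None] * num_layers,        # per-layer cached key projections
        "v": [None] * num_layers,        # per-layer cached value projections
        "stage": 0,                      # index of the layer currently reading/writing
        "first_infer": 1,                # 1 on the first decoding step, 0 afterwards
    }
    # each layer's forward(..., cache=cache) stores or appends its new k/v and advances "stage"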
GPT_SoVITS/AR/modules/activation_onnx.py ADDED
@@ -0,0 +1,188 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
8
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
9
+ from torch.nn.parameter import Parameter
10
+
11
+ from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
12
+
13
+
14
+ class MultiheadAttention(Module):
15
+ __constants__ = ["batch_first"]
16
+ bias_k: Optional[torch.Tensor]
17
+ bias_v: Optional[torch.Tensor]
18
+
19
+ def __init__(
20
+ self,
21
+ embed_dim,
22
+ num_heads,
23
+ dropout=0.0,
24
+ bias=True,
25
+ add_bias_kv=False,
26
+ add_zero_attn=False,
27
+ kdim=None,
28
+ vdim=None,
29
+ batch_first=False,
30
+ linear1_cls=Linear,
31
+ linear2_cls=Linear,
32
+ device=None,
33
+ dtype=None,
34
+ ) -> None:
35
+ factory_kwargs = {"device": device, "dtype": dtype}
36
+ super(MultiheadAttention, self).__init__()
37
+ self.embed_dim = embed_dim
38
+ self.kdim = kdim if kdim is not None else embed_dim
39
+ self.vdim = vdim if vdim is not None else embed_dim
40
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
41
+
42
+ self.num_heads = num_heads
43
+ self.dropout = dropout
44
+ self.batch_first = batch_first
45
+ self.head_dim = embed_dim // num_heads
46
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
47
+
48
+ if add_bias_kv:
49
+ self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
50
+ self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
51
+ else:
52
+ self.bias_k = self.bias_v = None
53
+
54
+ if linear1_cls == Linear:
55
+ if not self._qkv_same_embed_dim:
56
+ self.q_proj_weight = Parameter(
57
+ torch.empty(
58
+ (embed_dim, embed_dim),
59
+ **factory_kwargs,
60
+ )
61
+ )
62
+ self.k_proj_weight = Parameter(
63
+ torch.empty(
64
+ (embed_dim, self.kdim),
65
+ **factory_kwargs,
66
+ )
67
+ )
68
+ self.v_proj_weight = Parameter(
69
+ torch.empty(
70
+ (embed_dim, self.vdim),
71
+ **factory_kwargs,
72
+ )
73
+ )
74
+ self.register_parameter("in_proj_weight", None)
75
+ else:
76
+ self.in_proj_weight = Parameter(
77
+ torch.empty(
78
+ (3 * embed_dim, embed_dim),
79
+ **factory_kwargs,
80
+ )
81
+ )
82
+ self.register_parameter("q_proj_weight", None)
83
+ self.register_parameter("k_proj_weight", None)
84
+ self.register_parameter("v_proj_weight", None)
85
+
86
+ if bias:
87
+ self.in_proj_bias = Parameter(
88
+ torch.empty(3 * embed_dim, **factory_kwargs),
89
+ )
90
+ else:
91
+ self.register_parameter("in_proj_bias", None)
92
+ self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
93
+
94
+ self._reset_parameters()
95
+ else:
96
+ if not self._qkv_same_embed_dim:
97
+ raise NotImplementedError
98
+ else:
99
+ self.in_proj_linear = linear1_cls(
100
+ embed_dim,
101
+ 3 * embed_dim,
102
+ bias=bias,
103
+ **factory_kwargs,
104
+ )
105
+ self.in_proj_weight = self.in_proj_linear.weight
106
+
107
+ self.register_parameter("q_proj_weight", None)
108
+ self.register_parameter("k_proj_weight", None)
109
+ self.register_parameter("v_proj_weight", None)
110
+
111
+ if bias:
112
+ self.in_proj_bias = self.in_proj_linear.bias
113
+ else:
114
+ self.register_parameter("in_proj_bias", None)
115
+
116
+ self.out_proj = linear2_cls(
117
+ embed_dim,
118
+ embed_dim,
119
+ bias=bias,
120
+ **factory_kwargs,
121
+ )
122
+
123
+ if self.bias_k is not None:
124
+ xavier_normal_(self.bias_k)
125
+ if self.bias_v is not None:
126
+ xavier_normal_(self.bias_v)
127
+
128
+ self.add_zero_attn = add_zero_attn
129
+
130
+ def _reset_parameters(self):
131
+ if self._qkv_same_embed_dim:
132
+ xavier_uniform_(self.in_proj_weight)
133
+ else:
134
+ xavier_uniform_(self.q_proj_weight)
135
+ xavier_uniform_(self.k_proj_weight)
136
+ xavier_uniform_(self.v_proj_weight)
137
+
138
+ if self.in_proj_bias is not None:
139
+ constant_(self.in_proj_bias, 0.0)
140
+ constant_(self.out_proj.bias, 0.0)
141
+
142
+ if self.bias_k is not None:
143
+ xavier_normal_(self.bias_k)
144
+ if self.bias_v is not None:
145
+ xavier_normal_(self.bias_v)
146
+
147
+ def __setstate__(self, state):
148
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
149
+ if "_qkv_same_embed_dim" not in state:
150
+ state["_qkv_same_embed_dim"] = True
151
+
152
+ super(MultiheadAttention, self).__setstate__(state)
153
+
154
+ def forward(
155
+ self,
156
+ query: Tensor,
157
+ key: Tensor,
158
+ value: Tensor,
159
+ key_padding_mask: Optional[Tensor] = None,
160
+ need_weights: bool = True,
161
+ attn_mask: Optional[Tensor] = None,
162
+ average_attn_weights: bool = True,
163
+ cache=None,
164
+ ) -> Tuple[Tensor, Optional[Tensor]]:
165
+ any_nested = query.is_nested or key.is_nested or value.is_nested
166
+ query = key = value = query.transpose(1, 0)
167
+ attn_output = multi_head_attention_forward_patched(
168
+ query,
169
+ key,
170
+ value,
171
+ self.embed_dim,
172
+ self.num_heads,
173
+ self.in_proj_weight,
174
+ self.in_proj_bias,
175
+ self.bias_k,
176
+ self.bias_v,
177
+ self.add_zero_attn,
178
+ self.dropout,
179
+ self.out_proj.weight,
180
+ self.out_proj.bias,
181
+ training=self.training,
182
+ key_padding_mask=key_padding_mask,
183
+ need_weights=need_weights,
184
+ attn_mask=attn_mask,
185
+ average_attn_weights=average_attn_weights,
186
+ cache=cache,
187
+ )
188
+ return attn_output.transpose(1, 0)
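This ONNX-oriented variant drops the fast-path gating, always treats the input as batch-first, and returns only the attention output (no weights), which is what tracing needs. A minimal construction sketch; the dimensions are illustrative, and at inference the exporter is expected to pass the per-layer cache dict sketched earlier:

    from AR.modules.activation_onnx import MultiheadAttention

    attn = MultiheadAttention(embed_dim=512, num_heads=16, dropout=0.0, batch_first=True)
    # forward(query, key, value, attn_mask=..., cache=...) takes (batch, seq, embed) tensors
    # and returns a single (batch, seq, embed) tensor, with no attention-weight output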
GPT_SoVITS/AR/modules/embedding.py ADDED
@@ -0,0 +1,78 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float = 0.0,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.vocab_size = vocab_size
18
+ self.embedding_dim = embedding_dim
19
+
20
+ self.dropout = torch.nn.Dropout(p=dropout)
21
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22
+
23
+ @property
24
+ def weight(self) -> torch.Tensor:
25
+ return self.word_embeddings.weight
26
+
27
+ def embedding(self, index: int) -> torch.Tensor:
28
+ return self.word_embeddings.weight[index : index + 1]
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = self.word_embeddings(x)
32
+ x = self.dropout(x)
33
+ return x
34
+
35
+
36
+ class SinePositionalEmbedding(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int,
40
+ dropout: float = 0.0,
41
+ scale: bool = False,
42
+ alpha: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.embedding_dim = embedding_dim
46
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48
+ self.dropout = torch.nn.Dropout(p=dropout)
49
+
50
+ self.reverse = False
51
+ self.pe = None
52
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
53
+
54
+ def extend_pe(self, x):
55
+ """Reset the positional encodings."""
56
+ if self.pe is not None:
57
+ if self.pe.size(1) >= x.size(1):
58
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
59
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
60
+ return
61
+ pe = torch.zeros(x.size(1), self.embedding_dim)
62
+ if self.reverse:
63
+ position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
64
+ else:
65
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
66
+ div_term = torch.exp(
67
+ torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
68
+ )
69
+ pe[:, 0::2] = torch.sin(position * div_term)
70
+ pe[:, 1::2] = torch.cos(position * div_term)
71
+ pe = pe.unsqueeze(0)
72
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
73
+
74
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
75
+ self.extend_pe(x)
76
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
77
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
78
+ return self.dropout(output)
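A minimal usage sketch for the two modules above; the vocabulary size, embedding width, and sequence shape are arbitrary assumptions:

    import torch
    from AR.modules.embedding import TokenEmbedding, SinePositionalEmbedding

    emb = TokenEmbedding(embedding_dim=512, vocab_size=1025, dropout=0.0)
    pos = SinePositionalEmbedding(embedding_dim=512, dropout=0.0, scale=False, alpha=True)

    tokens = torch.randint(0, 1025, (2, 100))   # (batch, seq) integer token ids
    x = emb(tokens)                             # (batch, seq, 512)
    x = pos(x)                                  # same shape, sinusoidal positions added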
GPT_SoVITS/AR/modules/embedding_onnx.py ADDED
@@ -0,0 +1,63 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+
8
+ class TokenEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embedding_dim: int,
12
+ vocab_size: int,
13
+ dropout: float = 0.0,
14
+ ):
15
+ super().__init__()
16
+
17
+ self.vocab_size = vocab_size
18
+ self.embedding_dim = embedding_dim
19
+
20
+ self.dropout = torch.nn.Dropout(p=dropout)
21
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
22
+
23
+ @property
24
+ def weight(self) -> torch.Tensor:
25
+ return self.word_embeddings.weight
26
+
27
+ def embedding(self, index: int) -> torch.Tensor:
28
+ return self.word_embeddings.weight[index : index + 1]
29
+
30
+ def forward(self, x: torch.Tensor):
31
+ x = self.word_embeddings(x)
32
+ x = self.dropout(x)
33
+ return x
34
+
35
+
36
+ class SinePositionalEmbedding(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embedding_dim: int,
40
+ dropout: float = 0.0,
41
+ scale: bool = False,
42
+ alpha: bool = False,
43
+ ):
44
+ super().__init__()
45
+ self.embedding_dim = embedding_dim
46
+ self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
47
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
48
+ self.dropout = torch.nn.Dropout(p=dropout)
49
+ self.reverse = False
50
+ self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
51
+
52
+ def extend_pe(self, x):
53
+ position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
54
+ scpe = (position * self.div_term).unsqueeze(0)
55
+ pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
56
+ pe = pe.contiguous().view(1, -1, self.embedding_dim)
57
+ return pe
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ pe = self.extend_pe(x)
61
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
62
+ output = output * self.x_scale + self.alpha * pe
63
+ return self.dropout(output)
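Unlike the version above, this variant keeps no fixed positional buffer: `extend_pe` rebuilds the table from a cumulative sum of ones on every call, which keeps the sequence length dynamic in the exported ONNX graph. A minimal sketch, assuming batch size 1 as used at export time:

    import torch
    from AR.modules.embedding_onnx import SinePositionalEmbedding

    pos = SinePositionalEmbedding(embedding_dim=512, dropout=0.0, scale=False, alpha=True)
    x = torch.randn(1, 100, 512)    # (batch=1, seq, embed); batch 1 matches the export setup
    y = pos(x)                      # positional table generated on the fly; y has x's shape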
GPT_SoVITS/AR/modules/lr_schedulers.py ADDED
@@ -0,0 +1,85 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import math
4
+
5
+ import torch
6
+ from matplotlib import pyplot as plt
7
+ from torch import nn
8
+ from torch.optim import Adam
9
+
10
+
11
+ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
12
+ """
13
+ Implements a warmup learning-rate schedule: the LR ramps linearly from 'init_lr' to 'peak_lr' over 'warmup_steps', then follows a cosine decay down to 'end_lr' by 'total_steps'.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ optimizer,
19
+ init_lr,
20
+ peak_lr,
21
+ end_lr,
22
+ warmup_steps=10000,
23
+ total_steps=400000,
24
+ current_step=0,
25
+ ):
26
+ self.init_lr = init_lr
27
+ self.peak_lr = peak_lr
28
+ self.end_lr = end_lr
29
+ self.optimizer = optimizer
30
+ self._warmup_rate = (peak_lr - init_lr) / warmup_steps
31
+ self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps)
32
+ self._current_step = current_step
33
+ self.lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.total_steps = total_steps
36
+ self._last_lr = [self.lr]
37
+
38
+ def set_lr(self, lr):
39
+ self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
40
+ for g in self.optimizer.param_groups:
41
+ # g['lr'] = lr
42
+ g["lr"] = self.end_lr ### lock: use the fixed (linear) LR
43
+
44
+ def step(self):
45
+ if self._current_step < self.warmup_steps:
46
+ lr = self.init_lr + self._warmup_rate * self._current_step
47
+
48
+ elif self._current_step > self.total_steps:
49
+ lr = self.end_lr
50
+
51
+ else:
52
+ decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
53
+ if decay_ratio < 0.0 or decay_ratio > 1.0:
54
+ raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
55
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
56
+ lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
57
+
58
+ self.lr = lr = self.end_lr = 0.002 ### lock: use the fixed (linear) LR ### it would not behave, so lock it outright!
59
+ self.set_lr(lr)
60
+ self.lr = lr
61
+ self._current_step += 1
62
+ return self.lr
63
+
64
+
65
+ if __name__ == "__main__":
66
+ m = nn.Linear(10, 10)
67
+ opt = Adam(m.parameters(), lr=1e-4)
68
+ s = WarmupCosineLRSchedule(
69
+ opt,
70
+ 1e-6,
71
+ 2e-4,
72
+ 1e-6,
73
+ warmup_steps=2000,
74
+ total_steps=20000,
75
+ current_step=0,
76
+ )
77
+ lrs = []
78
+ for i in range(25000):
79
+ s.step()
80
+ lrs.append(s.lr)
81
+ print(s.lr)
82
+
83
+ plt.plot(lrs)
84
+ plt.plot(range(0, 25000), lrs)
85
+ plt.show()
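Note that as committed, `step()` overrides the computed value with a constant 0.002 (the "lock" comments above), so the warmup and cosine branches have no effect on the optimizer. For reference, this standalone sketch reproduces the schedule those branches compute:

    import math

    def warmup_cosine_lr(step, init_lr, peak_lr, end_lr, warmup_steps, total_steps):
        # linear warmup from init_lr to peak_lr, then cosine decay from peak_lr to end_lr
        if step < warmup_steps:
            return init_lr + (peak_lr - init_lr) / warmup_steps * step
        if step > total_steps:
            return end_lr
        decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return end_lr + coeff * (peak_lr - end_lr)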
GPT_SoVITS/AR/modules/optim.py ADDED
@@ -0,0 +1,593 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import contextlib
17
+ import logging
18
+ from collections import defaultdict
19
+ from typing import List, Tuple
20
+
21
+ import torch
22
+ from torch import Tensor
23
+ from torch.optim import Optimizer
24
+
25
+
26
+ class BatchedOptimizer(Optimizer):
27
+ """
28
+ This class adds to class Optimizer the capability to optimize parameters in batches:
29
+ it will stack the parameters and their grads for you so the optimizer can work
30
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
31
+ as it reduces the number of kernels launched in the optimizer.
32
+
33
+ Args:
34
+ params:
35
+ """
36
+
37
+ def __init__(self, params, defaults):
38
+ super(BatchedOptimizer, self).__init__(params, defaults)
39
+
40
+ @contextlib.contextmanager
41
+ def batched_params(self, param_group, group_params_names):
42
+ """
43
+ This function returns (technically, yields) a list of
44
+ tuples (p, state, p_names), where
45
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
46
+ that share the same shape, and its gradient is also stacked;
47
+ `state` is the state corresponding to this batch of parameters
48
+ (it will be physically located in the "state" for one of the real
49
+ parameters, the first one that has any particular shape and dtype).
50
+
51
+ This function is decorated as a context manager so that it can
52
+ write parameters back to their "real" locations.
53
+
54
+ The idea is, instead of doing:
55
+ <code>
56
+ for p in group["params"]:
57
+ state = self.state[p]
58
+ ...
59
+ </code>
60
+ you can do:
61
+ <code>
62
+ with self.batched_params(group["params"], group_params_names) as batches:
63
+ for p, state, p_names in batches:
64
+ ...
65
+ </code>
66
+
67
+ Args:
68
+ param_group: a list of parameters to batch; should be the "params" list of
69
+ one of self.param_groups.
70
+ group_params_names: name for each parameter in group,
71
+ which is List[str].
72
+ """
73
+ batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
74
+ batches_names = defaultdict(list) # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
75
+
76
+ assert len(param_group) == len(group_params_names)
77
+ for p, named_p in zip(param_group, group_params_names):
78
+ key = (str(p.dtype), *p.shape)
79
+ batches[key].append(p)
80
+ batches_names[key].append(named_p)
81
+
82
+ batches_names_keys = list(batches_names.keys())
83
+ sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
84
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
85
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
86
+
87
+ stacked_params_dict = dict()
88
+
89
+ # turn batches into a list, in deterministic order.
90
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
91
+ # one for each batch in `batches`.
92
+ tuples = []
93
+
94
+ for batch, batch_names in zip(batches, batches_names):
95
+ p = batch[0]
96
+ # we arbitrarily store the state in the
97
+ # state corresponding to the 1st parameter in the
98
+ # group. class Optimizer will take care of saving/loading state.
99
+ state = self.state[p]
100
+ p_stacked = torch.stack(batch)
101
+ grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
102
+ p_stacked.grad = grad
103
+ stacked_params_dict[key] = p_stacked
104
+ tuples.append((p_stacked, state, batch_names))
105
+
106
+ yield tuples # <-- calling code will do the actual optimization here!
107
+
108
+ for (stacked_params, _state, _names), batch in zip(tuples, batches):
109
+ for i, p in enumerate(batch): # batch is list of Parameter
110
+ p.copy_(stacked_params[i])
111
+
112
+
113
+ class ScaledAdam(BatchedOptimizer):
114
+ """
115
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
116
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
117
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
118
+ param = underlying_param * log_scale.exp())
119
+
120
+
121
+ Args:
122
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
123
+ lr: The learning rate. We will typically use a learning rate schedule that starts
124
+ at 0.03 and decreases over time, i.e. much higher than other common
125
+ optimizers.
126
+ clipping_scale: (e.g. 2.0)
127
+ A scale for gradient-clipping: if specified, the normalized gradients
128
+ over the whole model will be clipped to have 2-norm equal to
129
+ `clipping_scale` times the median 2-norm over the most recent period
130
+ of `clipping_update_period` minibatches. By "normalized gradients",
131
+ we mean after multiplying by the rms parameter value for this tensor
132
+ [for non-scalars]; this is appropriate because our update is scaled
133
+ by this quantity.
134
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
135
+ Must satisfy 0 < beta1 <= beta2 < 1.
136
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
137
+ scale of each parameter tensor and scalar parameters of the model.
138
+ If each parameter were decomposed
139
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
140
+ would be the scaling factor on the learning rate of p_scale.
141
+ eps: A general-purpose epsilon to prevent division by zero
142
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
143
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
144
+ parameter tensor to be >= this value)
145
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
146
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
147
+ parameter tensor to be <= this value)
148
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
149
+ model has any parameters with numel() == 1).
150
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
151
+ of the parameter tensor. This is provided to save a little time
152
+ in the update.
153
+ clipping_update_period: if clipping_scale is specified, this is the period (in steps) with which the median gradient norm used for clipping is re-estimated.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ params,
159
+ lr=3e-02,
160
+ clipping_scale=None,
161
+ betas=(0.9, 0.98),
162
+ scalar_lr_scale=0.1,
163
+ eps=1.0e-08,
164
+ param_min_rms=1.0e-05,
165
+ param_max_rms=3.0,
166
+ scalar_max=10.0,
167
+ size_update_period=4,
168
+ clipping_update_period=100,
169
+ parameters_names=None,
170
+ show_dominant_parameters=True,
171
+ ):
172
+ assert parameters_names is not None, (
173
+ "Please prepare parameters_names, which is a List[List[str]]. Each List[str] is for a group, and each str is for a parameter"
174
+ )
175
+ defaults = dict(
176
+ lr=lr,
177
+ clipping_scale=clipping_scale,
178
+ betas=betas,
179
+ scalar_lr_scale=scalar_lr_scale,
180
+ eps=eps,
181
+ param_min_rms=param_min_rms,
182
+ param_max_rms=param_max_rms,
183
+ scalar_max=scalar_max,
184
+ size_update_period=size_update_period,
185
+ clipping_update_period=clipping_update_period,
186
+ )
187
+
188
+ super(ScaledAdam, self).__init__(params, defaults)
189
+ assert len(self.param_groups) == len(parameters_names)
190
+ self.parameters_names = parameters_names
191
+ self.show_dominant_parameters = show_dominant_parameters
192
+
193
+ def __setstate__(self, state):
194
+ super(ScaledAdam, self).__setstate__(state)
195
+
196
+ @torch.no_grad()
197
+ def step(self, closure=None):
198
+ """Performs a single optimization step.
199
+
200
+ Arguments:
201
+ closure (callable, optional): A closure that reevaluates the model
202
+ and returns the loss.
203
+ """
204
+ loss = None
205
+ if closure is not None:
206
+ with torch.enable_grad():
207
+ loss = closure()
208
+
209
+ batch = True
210
+
211
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
212
+ with self.batched_params(group["params"], group_params_names) as batches:
213
+ # batches is list of pairs (stacked_param, state). stacked_param is like
214
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
215
+ # a stacking dim, it is not a real dim.
216
+
217
+ if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
218
+ clipping_scale = 1
219
+ else:
220
+ clipping_scale = self._get_clipping_scale(group, batches)
221
+
222
+ for p, state, _ in batches:
223
+ # Perform optimization step.
224
+ # grad is not going to be None, we handled that when creating the batches.
225
+ grad = p.grad
226
+ if grad.is_sparse:
227
+ raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
228
+ # State initialization
229
+ if len(state) == 0:
230
+ self._init_state(group, p, state)
231
+
232
+ self._step_one_batch(group, p, state, clipping_scale)
233
+
234
+ return loss
235
+
236
+ def _init_state(self, group: dict, p: Tensor, state: dict):
237
+ """
238
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
239
+ is actually the batch dimension, corresponding to batched-together
240
+ parameters of a given shape.
241
+
242
+
243
+ Args:
244
+ group: Dict to look up configuration values.
245
+ p: The parameter that we are initializing the state for
246
+ state: Dict from string to whatever state we are initializing
247
+ """
248
+ size_update_period = group["size_update_period"]
249
+
250
+ state["step"] = 0
251
+
252
+ kwargs = {"device": p.device, "dtype": p.dtype}
253
+
254
+ # 'delta' implements conventional momentum. There are
255
+ # several different kinds of update going on, so rather than
256
+ # compute "exp_avg" like in Adam, we store and decay a
257
+ # parameter-change "delta", which combines all forms of
258
+ # update. this is equivalent to how it's done in Adam,
259
+ # except for the first few steps.
260
+ state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
261
+
262
+ batch_size = p.shape[0]
263
+ numel = p.numel() // batch_size
264
+ numel = p.numel()
265
+
266
+ if numel > 1:
267
+ # "param_rms" just periodically records the scalar root-mean-square value of
268
+ # the parameter tensor.
269
+ # it has a shape like (batch_size, 1, 1, 1, 1)
270
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
271
+ state["param_rms"] = param_rms
272
+
273
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
274
+ state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
275
+
276
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
277
+ state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
278
+
279
+ def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
280
+ """
281
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
282
+ by this amount before applying the rest of the update.
283
+
284
+ Args:
285
+ group: the parameter group, an item in self.param_groups
286
+ tuples: a list of tuples of (param, state, param_names)
287
+ where param is a batched set of parameters,
288
+ with a .grad (1st dim is batch dim)
289
+ and state is the state-dict where optimization parameters are kept.
290
+ param_names is a List[str] while each str is name for a parameter
291
+ in batched set of parameters "param".
292
+ """
293
+ assert len(tuples) >= 1
294
+ clipping_scale = group["clipping_scale"]
295
+ (first_p, first_state, _) = tuples[0]
296
+ step = first_state["step"]
297
+ if clipping_scale is None or step == 0:
298
+ # no clipping. return early on step == 0 because the other
299
+ # parameters' state won't have been initialized yet.
300
+ return 1.0
301
+ clipping_update_period = group["clipping_update_period"]
302
+
303
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
304
+ for p, state, param_names in tuples:
305
+ grad = p.grad
306
+ if grad.is_sparse:
307
+ raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
308
+ if p.numel() == p.shape[0]: # a batch of scalars
309
+ tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
310
+ else:
311
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
312
+
313
+ tot_norm = tot_sumsq.sqrt()
314
+ if "model_norms" not in first_state:
315
+ first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
316
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
317
+
318
+ if step % clipping_update_period == 0:
319
+ # Print some stats.
320
+ # We don't reach here if step == 0 because we would have returned
321
+ # above.
322
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
323
+ quartiles = []
324
+ for n in range(0, 5):
325
+ index = min(
326
+ clipping_update_period - 1,
327
+ (clipping_update_period // 4) * n,
328
+ )
329
+ quartiles.append(sorted_norms[index].item())
330
+
331
+ median = quartiles[2]
332
+ threshold = clipping_scale * median
333
+ first_state["model_norm_threshold"] = threshold
334
+ percent_clipped = (
335
+ first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
336
+ )
337
+ first_state["num_clipped"] = 0
338
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
339
+ logging.info(
340
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
341
+ )
342
+
343
+ if step < clipping_update_period:
344
+ return 1.0 # We have not yet estimated a norm to clip to.
345
+ else:
346
+ try:
347
+ model_norm_threshold = first_state["model_norm_threshold"]
348
+ except KeyError:
349
+ logging.info(
350
+ "Warning: model_norm_threshold not in state: possibly you changed the config when restarting and added the clipping_scale option?"
351
+ )
352
+ return 1.0
353
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
354
+ if ans < 1.0:
355
+ first_state["num_clipped"] += 1
356
+ if ans < 0.1:
357
+ logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
358
+ if self.show_dominant_parameters:
359
+ assert p.shape[0] == len(param_names)
360
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
361
+ return ans
362
+
363
+ def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
364
+ """
365
+ Show information about the parameter which dominates tot_sumsq.
366
+
367
+ Args:
368
+ tuples: a list of tuples of (param, state, param_names)
369
+ where param is a batched set of parameters,
370
+ with a .grad (1st dim is batch dim)
371
+ and state is the state-dict where optimization parameters are kept.
372
+ param_names is a List[str] while each str is name for a parameter
373
+ in batched set of parameters "param".
374
+ tot_sumsq: sumsq of all parameters. Though it could be calculated
375
+ from tuples, we still pass it to save some time.
376
+ """
377
+ all_sumsq_orig = {}
378
+ for p, state, batch_param_names in tuples:
379
+ # p is a stacked batch parameters.
380
+ batch_grad = p.grad
381
+ if p.numel() == p.shape[0]: # a batch of scalars
382
+ batch_sumsq_orig = batch_grad**2
383
+ # Dummy values used by the following `zip` statement.
384
+ batch_rms_orig = torch.ones(p.shape[0])
385
+ else:
386
+ batch_rms_orig = state["param_rms"]
387
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
388
+
389
+ for name, sumsq_orig, rms, grad in zip(
390
+ batch_param_names,
391
+ batch_sumsq_orig,
392
+ batch_rms_orig,
393
+ batch_grad,
394
+ ):
395
+ proportion_orig = sumsq_orig / tot_sumsq
396
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
397
+
398
+ assert torch.isclose(
399
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
400
+ torch.tensor(1.0),
401
+ )
402
+ sorted_by_proportion = {
403
+ k: v
404
+ for k, v in sorted(
405
+ all_sumsq_orig.items(),
406
+ key=lambda item: item[1][0],
407
+ reverse=True,
408
+ )
409
+ }
410
+ dominant_param_name = next(iter(sorted_by_proportion))
411
+ (
412
+ dominant_proportion,
413
+ dominant_sumsq,
414
+ dominant_rms,
415
+ dominant_grad,
416
+ ) = sorted_by_proportion[dominant_param_name]
417
+ logging.info(
418
+ f"Parameter dominating tot_sumsq {dominant_param_name}"
419
+ f" with proportion {dominant_proportion:.2f},"
420
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
421
+ f"={dominant_sumsq:.3e},"
422
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
423
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
424
+ )
425
+
426
+ def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
427
+ """
428
+ Do the step for one parameter, which is actually going to be a batch of
429
+ `real` parameters, with dim 0 as the batch dim.
430
+ Args:
431
+ group: dict to look up configuration values
432
+ p: parameter to update (actually multiple parameters stacked together
433
+ as a batch)
434
+ state: state-dict for p, to look up the optimizer state
435
+ """
436
+ lr = group["lr"]
437
+ size_update_period = group["size_update_period"]
438
+ beta1 = group["betas"][0]
439
+
440
+ grad = p.grad
441
+ if clipping_scale != 1.0:
442
+ grad = grad * clipping_scale
443
+ step = state["step"]
444
+ delta = state["delta"]
445
+
446
+ delta.mul_(beta1)
447
+ batch_size = p.shape[0]
448
+ numel = p.numel() // batch_size
449
+ if numel > 1:
450
+ # Update the size/scale of p, and set param_rms
451
+ scale_grads = state["scale_grads"]
452
+ scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
453
+ if step % size_update_period == size_update_period - 1:
454
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
455
+ param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
456
+ if step > 0:
457
+ # self._size_update() learns the overall scale on the
458
+ # parameter, by shrinking or expanding it.
459
+ self._size_update(group, scale_grads, p, state)
460
+
461
+ if numel == 1:
462
+ # For parameters with 1 element we just use regular Adam.
463
+ # Updates delta.
464
+ self._step_scalar(group, p, state)
465
+ else:
466
+ self._step(group, p, state)
467
+
468
+ state["step"] = step + 1
469
+
470
+ def _size_update(
471
+ self,
472
+ group: dict,
473
+ scale_grads: Tensor,
474
+ p: Tensor,
475
+ state: dict,
476
+ ) -> None:
477
+ """
478
+ Called only where p.numel() > 1, this updates the scale of the parameter.
479
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
480
+ gradient descent on underlying param and on scale, this function does the update
481
+ on `scale`.
482
+
483
+ Args:
484
+ group: dict to look up configuration values
485
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
486
+ grads w.r.t. the scales.
487
+ p: The parameter to update
488
+ state: The state-dict of p
489
+ """
490
+
491
+ param_rms = state["param_rms"]
492
+ beta1, beta2 = group["betas"]
493
+ size_lr = group["lr"] * group["scalar_lr_scale"]
494
+ param_min_rms = group["param_min_rms"]
495
+ param_max_rms = group["param_max_rms"]
496
+ eps = group["eps"]
497
+ step = state["step"]
498
+ batch_size = p.shape[0]
499
+
500
+ size_update_period = scale_grads.shape[0]
501
+ # correct beta2 for the size update period: we will have
502
+ # faster decay at this level.
503
+ beta2_corr = beta2**size_update_period
504
+
505
+ scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
506
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
507
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
508
+ alpha=1 - beta2_corr,
509
+ ) # shape is (batch_size, 1, 1, ...)
510
+
511
+ # The 1st time we reach here is when size_step == 1.
512
+ size_step = (step + 1) // size_update_period
513
+ bias_correction2 = 1 - beta2_corr**size_step
514
+ # we don't bother with bias_correction1; this will help prevent divergence
515
+ # at the start of training.
516
+
517
+ denom = scale_exp_avg_sq.sqrt() + eps
518
+
519
+ scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
520
+
521
+ is_too_small = param_rms < param_min_rms
522
+ is_too_large = param_rms > param_max_rms
523
+
524
+ # when the param gets too small, just don't shrink it any further.
525
+ scale_step.masked_fill_(is_too_small, 0.0)
526
+ # when it gets too large, stop it from getting any larger.
527
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
528
+ delta = state["delta"]
529
+ # the factor of (1-beta1) relates to momentum.
530
+ delta.add_(p * scale_step, alpha=(1 - beta1))
531
+
532
+ def _step(self, group: dict, p: Tensor, state: dict):
533
+ """
534
+ This function does the core update of self.step(), in the case where the members of
535
+ the batch have more than 1 element.
536
+
537
+ Args:
538
+ group: A dict which will be used to look up configuration values
539
+ p: The parameter to be updated
540
+ grad: The grad of p
541
+ state: The state-dict corresponding to parameter p
542
+
543
+ This function modifies p.
544
+ """
545
+ grad = p.grad
546
+ lr = group["lr"]
547
+ beta1, beta2 = group["betas"]
548
+ eps = group["eps"]
549
+ param_min_rms = group["param_min_rms"]
550
+ step = state["step"]
551
+
552
+ exp_avg_sq = state["exp_avg_sq"]
553
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
554
+
555
+ this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
556
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
557
+ if bias_correction2 < 0.99:
558
+ # note: not in-place.
559
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
560
+
561
+ denom = exp_avg_sq.sqrt()
562
+ denom += eps
563
+ grad = grad / denom
564
+
565
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
566
+
567
+ delta = state["delta"]
568
+ delta.add_(grad * alpha)
569
+ p.add_(delta)
570
+
571
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
572
+ """
573
+ A simplified form of the core update for scalar tensors, where we cannot get a good
574
+ estimate of the parameter rms.
575
+ """
576
+ beta1, beta2 = group["betas"]
577
+ scalar_max = group["scalar_max"]
578
+ eps = group["eps"]
579
+ lr = group["lr"] * group["scalar_lr_scale"]
580
+ grad = p.grad
581
+
582
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
583
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
584
+
585
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
586
+ # slower update at the start will help stability anyway.
587
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
588
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
589
+
590
+ delta = state["delta"]
591
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
592
+ p.clamp_(min=-scalar_max, max=scalar_max)
593
+ p.add_(delta)
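ScaledAdam requires `parameters_names`, a List[List[str]] with one inner list per parameter group, as enforced by the assertion in `__init__`. A minimal usage sketch with a toy model; the hyperparameter values are illustrative only:

    import torch
    from AR.modules.optim import ScaledAdam

    model = torch.nn.Linear(16, 16)
    names = [[name for name, _ in model.named_parameters()]]   # one List[str] per param group
    optimizer = ScaledAdam(
        model.parameters(),
        lr=0.01,
        betas=(0.9, 0.95),
        clipping_scale=2.0,
        parameters_names=names,
    )

    model(torch.randn(4, 16)).sum().backward()
    optimizer.step()   # parameters with the same shape/dtype are batched internally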
GPT_SoVITS/AR/modules/patched_mha_with_cache.py ADDED
@@ -0,0 +1,428 @@
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _mha_shape_check,
4
+ _canonical_mask,
5
+ _none_or_dtype,
6
+ _in_projection_packed,
7
+ )
8
+ import torch
9
+ # Tensor = torch.Tensor
10
+ # from typing import Callable, List, Optional, Tuple, Union
11
+
12
+
13
+ def multi_head_attention_forward_patched(
14
+ query,
15
+ key,
16
+ value,
17
+ embed_dim_to_check,
18
+ num_heads,
19
+ in_proj_weight,
20
+ in_proj_bias,
21
+ bias_k,
22
+ bias_v,
23
+ add_zero_attn,
24
+ dropout_p: float,
25
+ out_proj_weight,
26
+ out_proj_bias,
27
+ training=True,
28
+ key_padding_mask=None,
29
+ need_weights=True,
30
+ attn_mask=None,
31
+ use_separate_proj_weight=False,
32
+ q_proj_weight=None,
33
+ k_proj_weight=None,
34
+ v_proj_weight=None,
35
+ static_k=None,
36
+ static_v=None,
37
+ average_attn_weights=True,
38
+ is_causal=False,
39
+ cache=None,
40
+ ):
41
+ r"""
42
+ Args:
43
+ query, key, value: map a query and a set of key-value pairs to an output.
44
+ See "Attention Is All You Need" for more details.
45
+ embed_dim_to_check: total dimension of the model.
46
+ num_heads: parallel attention heads.
47
+ in_proj_weight, in_proj_bias: input projection weight and bias.
48
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
49
+ add_zero_attn: add a new batch of zeros to the key and
50
+ value sequences at dim=1.
51
+ dropout_p: probability of an element to be zeroed.
52
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
53
+ training: apply dropout if is ``True``.
54
+ key_padding_mask: if provided, specified padding elements in the key will
55
+ be ignored by the attention. This is an binary mask. When the value is True,
56
+ the corresponding value on the attention layer will be filled with -inf.
57
+ need_weights: output attn_output_weights.
58
+ Default: `True`
59
+ Note: `needs_weight` defaults to `True`, but should be set to `False`
60
+ For best performance when attention weights are not nedeeded.
61
+ *Setting needs_weights to `True`
62
+ leads to a significant performance degradation.*
63
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
64
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
65
+ is_causal: If specified, applies a causal mask as attention mask, and ignores
66
+ attn_mask for computing scaled dot product attention.
67
+ Default: ``False``.
68
+ .. warning::
69
+ is_causal is provides a hint that the attn_mask is the
70
+ causal mask.Providing incorrect hints can result in
71
+ incorrect execution, including forward and backward
72
+ compatibility.
73
+ use_separate_proj_weight: the function accept the proj. weights for query, key,
74
+ and value in different forms. If false, in_proj_weight will be used, which is
75
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
76
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
77
+ static_k, static_v: static key and value used for attention operators.
78
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across heads.
79
+ Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an effect
80
+ when ``need_weights=True.``. Default: True
81
+
82
+
83
+ Shape:
84
+ Inputs:
85
+ - query: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
86
+ the embedding dimension.
87
+ - key: :math:`(S, E)` or :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
88
+ the embedding dimension.
89
+ - value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
90
+ the embedding dimension.
91
+ - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
92
+ If a FloatTensor is provided, it will be directly added to the value.
93
+ If a BoolTensor is provided, the positions with the
94
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
95
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
96
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
97
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
98
+ positions. If a BoolTensor is provided, positions with ``True``
99
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
100
+ is provided, it will be added to the attention weight.
101
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
102
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
103
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
104
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
105
+
106
+ Outputs:
107
+ - attn_output: :math:`(L, E)` or :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
108
+ E is the embedding dimension.
109
+ - attn_output_weights: Only returned when ``need_weights=True``. If ``average_attn_weights=True``, returns
110
+ attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
111
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
112
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
113
+ head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
114
+ """
115
+ tens_ops = (
116
+ query,
117
+ key,
118
+ value,
119
+ in_proj_weight,
120
+ in_proj_bias,
121
+ bias_k,
122
+ bias_v,
123
+ out_proj_weight,
124
+ out_proj_bias,
125
+ )
126
+ if has_torch_function(tens_ops):
127
+ return handle_torch_function(
128
+ multi_head_attention_forward,
129
+ tens_ops,
130
+ query,
131
+ key,
132
+ value,
133
+ embed_dim_to_check,
134
+ num_heads,
135
+ in_proj_weight,
136
+ in_proj_bias,
137
+ bias_k,
138
+ bias_v,
139
+ add_zero_attn,
140
+ dropout_p,
141
+ out_proj_weight,
142
+ out_proj_bias,
143
+ training=training,
144
+ key_padding_mask=key_padding_mask,
145
+ need_weights=need_weights,
146
+ attn_mask=attn_mask,
147
+ is_causal=is_causal,
148
+ use_separate_proj_weight=use_separate_proj_weight,
149
+ q_proj_weight=q_proj_weight,
150
+ k_proj_weight=k_proj_weight,
151
+ v_proj_weight=v_proj_weight,
152
+ static_k=static_k,
153
+ static_v=static_v,
154
+ average_attn_weights=average_attn_weights,
155
+ cache=cache,
156
+ )
157
+
158
+ is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
159
+
160
+ # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
161
+ # is batched, run the computation and before returning squeeze the
162
+ # batch dimension so that the output doesn't carry this temporary batch dimension.
163
+ if not is_batched:
164
+ # unsqueeze if the input is unbatched
165
+ query = query.unsqueeze(1)
166
+ key = key.unsqueeze(1)
167
+ value = value.unsqueeze(1)
168
+ if key_padding_mask is not None:
169
+ key_padding_mask = key_padding_mask.unsqueeze(0)
170
+
171
+ # set up shape vars
172
+ tgt_len, bsz, embed_dim = query.shape
173
+ src_len, _, _ = key.shape
174
+
175
+ key_padding_mask = _canonical_mask(
176
+ mask=key_padding_mask,
177
+ mask_name="key_padding_mask",
178
+ other_type=_none_or_dtype(attn_mask),
179
+ other_name="attn_mask",
180
+ target_type=query.dtype,
181
+ )
182
+
183
+ if is_causal and attn_mask is None:
184
+ raise RuntimeError(
185
+ "Need attn_mask if specifying the is_causal hint. "
186
+ "You may use the Transformer module method "
187
+ "`generate_square_subsequent_mask` to create this mask."
188
+ )
189
+
190
+ if is_causal and key_padding_mask is None and not need_weights:
191
+ # when we have a kpm or need weights, we need attn_mask
192
+ # Otherwise, we use the is_causal hint go as is_causal
193
+ # indicator to SDPA.
194
+ attn_mask = None
195
+ else:
196
+ attn_mask = _canonical_mask(
197
+ mask=attn_mask,
198
+ mask_name="attn_mask",
199
+ other_type=None,
200
+ other_name="",
201
+ target_type=query.dtype,
202
+ check_other=False,
203
+ )
204
+
205
+ if key_padding_mask is not None:
206
+ # We have the attn_mask, and use that to merge kpm into it.
207
+ # Turn off use of is_causal hint, as the merged mask is no
208
+ # longer causal.
209
+ is_causal = False
210
+
211
+ assert embed_dim == embed_dim_to_check, (
212
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
213
+ )
214
+ if isinstance(embed_dim, torch.Tensor):
215
+ # embed_dim can be a tensor when JIT tracing
216
+ head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
217
+ else:
218
+ head_dim = embed_dim // num_heads
219
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
220
+ if use_separate_proj_weight:
221
+ # allow MHA to have different embedding dimensions when separate projection weights are used
222
+ assert key.shape[:2] == value.shape[:2], (
223
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
224
+ )
225
+ else:
226
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
227
+
228
+ #
229
+ # compute in-projection
230
+ #
231
+ if not use_separate_proj_weight:
232
+ assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
233
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
234
+ else:
235
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
236
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
237
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
238
+ if in_proj_bias is None:
239
+ b_q = b_k = b_v = None
240
+ else:
241
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
242
+ q, k, v = _in_projection(
243
+ query,
244
+ key,
245
+ value,
246
+ q_proj_weight,
247
+ k_proj_weight,
248
+ v_proj_weight,
249
+ b_q,
250
+ b_k,
251
+ b_v,
252
+ )
253
+ if cache != None:
254
+ if cache["first_infer"] == 1:
255
+ cache["k"][cache["stage"]] = k
256
+ # print(0,cache["k"].shape)
257
+ cache["v"][cache["stage"]] = v
258
+ else: ### each of the 12 layers keeps its own cached k/v
259
+ # print(1,cache["k"].shape)
260
+ cache["k"][cache["stage"]] = torch.cat(
261
+ [cache["k"][cache["stage"]], k], 0
262
+ ) ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了
263
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
264
+ # print(2, cache["k"].shape)
265
+ src_len = cache["k"][cache["stage"]].shape[0]
266
+ k = cache["k"][cache["stage"]]
267
+ v = cache["v"][cache["stage"]]
268
+ # if attn_mask is not None:
269
+ # attn_mask=attn_mask[-1:,]
270
+ # print(attn_mask.shape,attn_mask)
271
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
272
+ # print(2333,cache)
273
+ # prep attention mask
274
+
275
+ attn_mask = _canonical_mask(
276
+ mask=attn_mask,
277
+ mask_name="attn_mask",
278
+ other_type=None,
279
+ other_name="",
280
+ target_type=q.dtype,
281
+ check_other=False,
282
+ )
283
+
284
+ if attn_mask is not None:
285
+ # ensure attn_mask's dim is 3
286
+ if attn_mask.dim() == 2:
287
+ correct_2d_size = (tgt_len, src_len)
288
+ if attn_mask.shape != correct_2d_size:
289
+ raise RuntimeError(
290
+ f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
291
+ )
292
+ attn_mask = attn_mask.unsqueeze(0)
293
+ elif attn_mask.dim() == 3:
294
+ correct_3d_size = (bsz * num_heads, tgt_len, src_len)
295
+ if attn_mask.shape != correct_3d_size:
296
+ raise RuntimeError(
297
+ f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
298
+ )
299
+ else:
300
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
301
+
302
+ # add bias along batch dimension (currently second)
303
+ if bias_k is not None and bias_v is not None:
304
+ assert static_k is None, "bias cannot be added to static key."
305
+ assert static_v is None, "bias cannot be added to static value."
306
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
307
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
308
+ if attn_mask is not None:
309
+ attn_mask = pad(attn_mask, (0, 1))
310
+ if key_padding_mask is not None:
311
+ key_padding_mask = pad(key_padding_mask, (0, 1))
312
+ else:
313
+ assert bias_k is None
314
+ assert bias_v is None
315
+
316
+ #
317
+ # reshape q, k, v for multihead attention and make em batch first
318
+ #
319
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
320
+ if static_k is None:
321
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
322
+ else:
323
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
324
+ assert static_k.size(0) == bsz * num_heads, (
325
+ f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
326
+ )
327
+ assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
328
+ k = static_k
329
+ if static_v is None:
330
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
331
+ else:
332
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
333
+ assert static_v.size(0) == bsz * num_heads, (
334
+ f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
335
+ )
336
+ assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
337
+ v = static_v
338
+
339
+ # add zero attention along batch dimension (now first)
340
+ if add_zero_attn:
341
+ zero_attn_shape = (bsz * num_heads, 1, head_dim)
342
+ k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
343
+ v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
344
+ if attn_mask is not None:
345
+ attn_mask = pad(attn_mask, (0, 1))
346
+ if key_padding_mask is not None:
347
+ key_padding_mask = pad(key_padding_mask, (0, 1))
348
+
349
+ # update source sequence length after adjustments
350
+ src_len = k.size(1)
351
+
352
+ # merge key padding and attention masks
353
+ if key_padding_mask is not None:
354
+ assert key_padding_mask.shape == (
355
+ bsz,
356
+ src_len,
357
+ ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
358
+ key_padding_mask = (
359
+ key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
360
+ )
361
+ if attn_mask is None:
362
+ attn_mask = key_padding_mask
363
+ else:
364
+ attn_mask = attn_mask + key_padding_mask
365
+
366
+ # adjust dropout probability
367
+ if not training:
368
+ dropout_p = 0.0
369
+
370
+ #
371
+ # (deep breath) calculate attention and out projection
372
+ #
373
+
374
+ if need_weights:
375
+ B, Nt, E = q.shape
376
+ q_scaled = q / math.sqrt(E)
377
+
378
+ assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
379
+
380
+ if attn_mask is not None:
381
+ attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
382
+ else:
383
+ attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
384
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
385
+ if dropout_p > 0.0:
386
+ attn_output_weights = dropout(attn_output_weights, p=dropout_p)
387
+
388
+ attn_output = torch.bmm(attn_output_weights, v)
389
+
390
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
391
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
392
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
393
+
394
+ # optionally average attention weights over heads
395
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
396
+ if average_attn_weights:
397
+ attn_output_weights = attn_output_weights.mean(dim=1)
398
+
399
+ if not is_batched:
400
+ # squeeze the output if input was unbatched
401
+ attn_output = attn_output.squeeze(1)
402
+ attn_output_weights = attn_output_weights.squeeze(0)
403
+ return attn_output, attn_output_weights
404
+ else:
405
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
406
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
407
+ # in order to match the input for SDPA of (N, num_heads, L, S)
408
+ if attn_mask is not None:
409
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
410
+ attn_mask = attn_mask.unsqueeze(0)
411
+ else:
412
+ attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
413
+
414
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
415
+ k = k.view(bsz, num_heads, src_len, head_dim)
416
+ v = v.view(bsz, num_heads, src_len, head_dim)
417
+
418
+ # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
419
+ attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
420
+
421
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
422
+
423
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
424
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
425
+ if not is_batched:
426
+ # squeeze the output if input was unbatched
427
+ attn_output = attn_output.squeeze(1)
428
+ return attn_output, None
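Note: the patched forward above threads a per-layer key/value cache through the attention call (the `transformer.py` file further below passes `cache=` down into `self_attn` the same way). A minimal sketch of the cache layout the code expects; the field names come from the code above, while the layer count and the idea of pre-sized lists are illustrative assumptions:

    num_layers = 12                      # "all_stage"; 12 matches the comment above, purely illustrative here
    cache = {
        "all_stage": num_layers,         # how many attention layers share this dict
        "stage": 0,                      # index of the layer currently being evaluated
        "first_infer": 1,                # 1 on the first decoding step, 0 afterwards
        "k": [None] * num_layers,        # per-layer keys, concatenated along dim 0 as decoding proceeds
        "v": [None] * num_layers,        # per-layer values, concatenated along dim 0
    }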
GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py ADDED
@@ -0,0 +1,85 @@
1
+ from torch.nn.functional import *
2
+ from torch.nn.functional import (
3
+ _canonical_mask,
4
+ )
5
+
6
+
7
+ def multi_head_attention_forward_patched(
8
+ query,
9
+ key,
10
+ value,
11
+ embed_dim_to_check: int,
12
+ num_heads: int,
13
+ in_proj_weight,
14
+ in_proj_bias: Optional[Tensor],
15
+ bias_k: Optional[Tensor],
16
+ bias_v: Optional[Tensor],
17
+ add_zero_attn: bool,
18
+ dropout_p: float,
19
+ out_proj_weight: Tensor,
20
+ out_proj_bias: Optional[Tensor],
21
+ training: bool = True,
22
+ key_padding_mask: Optional[Tensor] = None,
23
+ need_weights: bool = True,
24
+ attn_mask: Optional[Tensor] = None,
25
+ use_separate_proj_weight: bool = False,
26
+ q_proj_weight: Optional[Tensor] = None,
27
+ k_proj_weight: Optional[Tensor] = None,
28
+ v_proj_weight: Optional[Tensor] = None,
29
+ static_k: Optional[Tensor] = None,
30
+ static_v: Optional[Tensor] = None,
31
+ average_attn_weights: bool = True,
32
+ is_causal: bool = False,
33
+ cache=None,
34
+ ) -> Tuple[Tensor, Optional[Tensor]]:
35
+ # set up shape vars
36
+ _, _, embed_dim = query.shape
37
+ attn_mask = _canonical_mask(
38
+ mask=attn_mask,
39
+ mask_name="attn_mask",
40
+ other_type=None,
41
+ other_name="",
42
+ target_type=query.dtype,
43
+ check_other=False,
44
+ )
45
+ head_dim = embed_dim // num_heads
46
+
47
+ proj_qkv = linear(query, in_proj_weight, in_proj_bias)
48
+ proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
49
+ q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
50
+
51
+ if cache["first_infer"] == 1:
52
+ cache["k"][cache["stage"]] = k
53
+ cache["v"][cache["stage"]] = v
54
+ else:
55
+ cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0)
56
+ cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0)
57
+ k = cache["k"][cache["stage"]]
58
+ v = cache["v"][cache["stage"]]
59
+ cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
60
+
61
+ attn_mask = _canonical_mask(
62
+ mask=attn_mask,
63
+ mask_name="attn_mask",
64
+ other_type=None,
65
+ other_name="",
66
+ target_type=q.dtype,
67
+ check_other=False,
68
+ )
69
+ attn_mask = attn_mask.unsqueeze(0)
70
+
71
+ q = q.view(-1, num_heads, head_dim).transpose(0, 1)
72
+ k = k.view(-1, num_heads, head_dim).transpose(0, 1)
73
+ v = v.view(-1, num_heads, head_dim).transpose(0, 1)
74
+
75
+ dropout_p = 0.0
76
+ attn_mask = attn_mask.unsqueeze(0)
77
+ q = q.view(num_heads, -1, head_dim).unsqueeze(0)
78
+ k = k.view(num_heads, -1, head_dim).unsqueeze(0)
79
+ v = v.view(num_heads, -1, head_dim).unsqueeze(0)
80
+ attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
81
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
82
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
83
+ attn_output = attn_output.view(-1, 1, attn_output.size(1))
84
+
85
+ return attn_output
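For comparison with the eager-mode file above, the ONNX-oriented variant keeps the same cache dict but changes the update rule: instead of growing the cached keys/values unconditionally, it drops the last cached position before appending the freshly projected `k`/`v`, which keeps the cached length fixed whenever the new step contributes a single position. A condensed, runnable sketch of the two rules (shapes are illustrative):

    import torch

    stage = 0
    cache = {"k": [torch.zeros(4, 1, 512)]}            # 4 cached positions (illustrative)
    k = torch.zeros(1, 1, 512)                         # projection of the new step

    # patched_mha_with_cache.py: the cache simply grows each step
    grown = torch.cat([cache["k"][stage], k], 0)       # 5 positions

    # patched_mha_with_cache_onnx.py: the last cached position is dropped first,
    # so the length stays at 4 when k covers a single step
    fixed = torch.cat([cache["k"][stage][:-1], k], 0)  # still 4 positions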
GPT_SoVITS/AR/modules/scaling.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import random
17
+ from typing import Optional
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch import Tensor
23
+
24
+
25
+ class DoubleSwishFunction(torch.autograd.Function):
26
+ """
27
+ double_swish(x) = x * torch.sigmoid(x-1)
28
+ This is a definition, originally motivated by its close numerical
29
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
30
+
31
+ Memory-efficient derivative computation:
32
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
33
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
34
+ Now, s'(x) = s(x) * (1-s(x)).
35
+ double_swish'(x) = x * s'(x) + s(x).
36
+ = x * s(x) * (1-s(x)) + s(x).
37
+ = double_swish(x) * (1-s(x)) + s(x)
38
+ ... so we just need to remember s(x) but not x itself.
39
+ """
40
+
41
+ @staticmethod
42
+ def forward(ctx, x: Tensor) -> Tensor:
43
+ requires_grad = x.requires_grad
44
+ x_dtype = x.dtype
45
+ if x.dtype == torch.float16:
46
+ x = x.to(torch.float32)
47
+
48
+ s = torch.sigmoid(x - 1.0)
49
+ y = x * s
50
+
51
+ if requires_grad:
52
+ deriv = y * (1 - s) + s
53
+ # notes on derivative of x * sigmoid(x - 1):
54
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
55
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound
56
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
57
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
58
+ # floors), should be expectation-preserving.
59
+ floor = -0.043637
60
+ ceil = 1.2
61
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
62
+ if __name__ == "__main__":
63
+ # for self-testing only.
64
+ assert d_scaled.min() >= 0.0
65
+ assert d_scaled.max() < 256.0
66
+ d_int = d_scaled.to(torch.uint8)
67
+ ctx.save_for_backward(d_int)
68
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
69
+ y = y.to(torch.float16)
70
+ return y
71
+
72
+ @staticmethod
73
+ def backward(ctx, y_grad: Tensor) -> Tensor:
74
+ (d,) = ctx.saved_tensors
75
+ # the same constants as used in forward pass.
76
+ floor = -0.043637
77
+ ceil = 1.2
78
+ d = d * ((ceil - floor) / 255.0) + floor
79
+ return y_grad * d
80
+
81
+
82
+ class DoubleSwish(torch.nn.Module):
83
+ def forward(self, x: Tensor) -> Tensor:
84
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
85
+ that we approximate closely with x * sigmoid(x-1).
86
+ """
87
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
88
+ return x * torch.sigmoid(x - 1.0)
89
+ return DoubleSwishFunction.apply(x)
90
+
91
+
92
+ class ActivationBalancerFunction(torch.autograd.Function):
93
+ @staticmethod
94
+ def forward(
95
+ ctx,
96
+ x: Tensor,
97
+ scale_factor: Tensor,
98
+ sign_factor: Optional[Tensor],
99
+ channel_dim: int,
100
+ ) -> Tensor:
101
+ if channel_dim < 0:
102
+ channel_dim += x.ndim
103
+ ctx.channel_dim = channel_dim
104
+ xgt0 = x > 0
105
+ if sign_factor is None:
106
+ ctx.save_for_backward(xgt0, scale_factor)
107
+ else:
108
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
109
+ return x
110
+
111
+ @staticmethod
112
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
113
+ if len(ctx.saved_tensors) == 3:
114
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
115
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
116
+ scale_factor = scale_factor.unsqueeze(-1)
117
+ sign_factor = sign_factor.unsqueeze(-1)
118
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
119
+ else:
120
+ xgt0, scale_factor = ctx.saved_tensors
121
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
122
+ scale_factor = scale_factor.unsqueeze(-1)
123
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
124
+ neg_delta_grad = x_grad.abs() * factor
125
+ return (
126
+ x_grad - neg_delta_grad,
127
+ None,
128
+ None,
129
+ None,
130
+ )
131
+
132
+
133
+ def _compute_scale_factor(
134
+ x: Tensor,
135
+ channel_dim: int,
136
+ min_abs: float,
137
+ max_abs: float,
138
+ gain_factor: float,
139
+ max_factor: float,
140
+ ) -> Tensor:
141
+ if channel_dim < 0:
142
+ channel_dim += x.ndim
143
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
144
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
145
+
146
+ if min_abs == 0.0:
147
+ below_threshold = 0.0
148
+ else:
149
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
150
+ # x_abs_mean < min_abs.
151
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
152
+
153
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
154
+
155
+ return below_threshold - above_threshold
156
+
157
+
158
+ def _compute_sign_factor(
159
+ x: Tensor,
160
+ channel_dim: int,
161
+ min_positive: float,
162
+ max_positive: float,
163
+ gain_factor: float,
164
+ max_factor: float,
165
+ ) -> Tensor:
166
+ if channel_dim < 0:
167
+ channel_dim += x.ndim
168
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
169
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
170
+ if min_positive == 0.0:
171
+ factor1 = 0.0
172
+ else:
173
+ # 0 if proportion_positive >= min_positive, else can be
174
+ # as large as max_factor.
175
+ factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
176
+
177
+ if max_positive == 1.0:
178
+ factor2 = 0.0
179
+ else:
180
+ # 0 if self.proportion_positive <= max_positive, else can be
181
+ # as large as -max_factor.
182
+ factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
183
+ min=0, max=max_factor
184
+ )
185
+ sign_factor = factor1 - factor2
186
+ # require min_positive != 0 or max_positive != 1:
187
+ assert not isinstance(sign_factor, float)
188
+ return sign_factor
189
+
190
+
191
+ class ActivationBalancer(torch.nn.Module):
192
+ """
193
+ Modifies the backpropped derivatives of a function to try to encourage, for
194
+ each channel, that it is positive at least a proportion `threshold` of the
195
+ time. It does this by multiplying negative derivative values by up to
196
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
197
+ interpolated from 1 at the threshold to those extremal values when none
198
+ of the inputs are positive.
199
+
200
+ Args:
201
+ num_channels: the number of channels
202
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
203
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
204
+ min_positive: the minimum, per channel, of the proportion of the time
205
+ that (x > 0), below which we start to modify the derivatives.
206
+ max_positive: the maximum, per channel, of the proportion of the time
207
+ that (x > 0), above which we start to modify the derivatives.
208
+ max_factor: the maximum factor by which we modify the derivatives for
209
+ either the sign constraint or the magnitude constraint;
210
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
211
+ values in the range [0.98..1.02].
212
+ sign_gain_factor: determines the 'gain' with which we increase the
213
+ change in gradient once the constraints on min_positive and max_positive
214
+ are violated.
215
+ scale_gain_factor: determines the 'gain' with which we increase the
216
+ change in gradient once the constraints on min_abs and max_abs
217
+ are violated.
218
+ min_abs: the minimum average-absolute-value difference from the mean
219
+ value per channel, which we allow, before we start to modify
220
+ the derivatives to prevent this.
221
+ max_abs: the maximum average-absolute-value difference from the mean
222
+ value per channel, which we allow, before we start to modify
223
+ the derivatives to prevent this.
224
+ min_prob: determines the minimum probability with which we modify the
225
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
226
+ on each forward(). This is done randomly to prevent all layers
227
+ from doing it at the same time. Early in training we may use
228
+ higher probabilities than this; it will decay to this value.
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ num_channels: int,
234
+ channel_dim: int,
235
+ min_positive: float = 0.05,
236
+ max_positive: float = 0.95,
237
+ max_factor: float = 0.04,
238
+ sign_gain_factor: float = 0.01,
239
+ scale_gain_factor: float = 0.02,
240
+ min_abs: float = 0.2,
241
+ max_abs: float = 100.0,
242
+ min_prob: float = 0.1,
243
+ ):
244
+ super(ActivationBalancer, self).__init__()
245
+ self.num_channels = num_channels
246
+ self.channel_dim = channel_dim
247
+ self.min_positive = min_positive
248
+ self.max_positive = max_positive
249
+ self.max_factor = max_factor
250
+ self.min_abs = min_abs
251
+ self.max_abs = max_abs
252
+ self.min_prob = min_prob
253
+ self.sign_gain_factor = sign_gain_factor
254
+ self.scale_gain_factor = scale_gain_factor
255
+
256
+ # count measures how many times the forward() function has been called.
257
+ # We occasionally sync this to a tensor called `count`, that exists to
258
+ # make sure it is synced to disk when we load and save the model.
259
+ self.cpu_count = 0
260
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
261
+
262
+ def forward(self, x: Tensor) -> Tensor:
263
+ if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
264
+ return _no_op(x)
265
+
266
+ count = self.cpu_count
267
+ self.cpu_count += 1
268
+
269
+ if random.random() < 0.01:
270
+ # Occasionally sync self.cpu_count with self.count.
271
+ # count affects the decay of 'prob'. don't do this on every iter,
272
+ # because syncing with the GPU is slow.
273
+ self.cpu_count = max(self.cpu_count, self.count.item())
274
+ self.count.fill_(self.cpu_count)
275
+
276
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
277
+ # a floor at min_prob (==0.1, by default)
278
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
279
+
280
+ if random.random() < prob:
281
+ sign_gain_factor = 0.5
282
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
283
+ sign_factor = _compute_sign_factor(
284
+ x,
285
+ self.channel_dim,
286
+ self.min_positive,
287
+ self.max_positive,
288
+ gain_factor=self.sign_gain_factor / prob,
289
+ max_factor=self.max_factor,
290
+ )
291
+ else:
292
+ sign_factor = None
293
+
294
+ scale_factor = _compute_scale_factor(
295
+ x.detach(),
296
+ self.channel_dim,
297
+ min_abs=self.min_abs,
298
+ max_abs=self.max_abs,
299
+ gain_factor=self.scale_gain_factor / prob,
300
+ max_factor=self.max_factor,
301
+ )
302
+ return ActivationBalancerFunction.apply(
303
+ x,
304
+ scale_factor,
305
+ sign_factor,
306
+ self.channel_dim,
307
+ )
308
+ else:
309
+ return _no_op(x)
310
+
311
+
312
+ def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
313
+ """
314
+ ActivationBalancer -> DoubleSwish
315
+ """
316
+ balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
317
+ return nn.Sequential(
318
+ balancer,
319
+ DoubleSwish(),
320
+ )
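Note: `ActivationBalancer.forward` above returns `_no_op(x)` on its fast paths, but `_no_op` is not defined or imported anywhere in the file as added, so those branches would raise a `NameError` if ever reached. If the intent matches the upstream k2/icefall `scaling.py` this file is derived from, a minimal stand-in (an assumption, not part of this commit) would be:

    import torch
    from torch import Tensor

    def _no_op(x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x
        # a no-op that still inserts a node into the autograd graph; the upstream
        # code does this to sidestep problems with backward hooks
        return x.chunk(1, dim=-1)[0]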
GPT_SoVITS/AR/modules/transformer.py ADDED
@@ -0,0 +1,362 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
+ class LayerNorm(nn.Module):
23
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
24
+ normalized_shape: Tuple[int, ...]
25
+ eps: float
26
+ elementwise_affine: bool
27
+
28
+ def __init__(
29
+ self,
30
+ normalized_shape: _shape_t,
31
+ eps: float = 1e-5,
32
+ elementwise_affine: bool = True,
33
+ device=None,
34
+ dtype=None,
35
+ ) -> None:
36
+ factory_kwargs = {"device": device, "dtype": dtype}
37
+ super(LayerNorm, self).__init__()
38
+ if isinstance(normalized_shape, numbers.Integral):
39
+ # mypy error: incompatible types in assignment
40
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
41
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
42
+ self.eps = eps
43
+ self.elementwise_affine = elementwise_affine
44
+ if self.elementwise_affine:
45
+ self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
46
+ self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
47
+ else:
48
+ self.register_parameter("weight", None)
49
+ self.register_parameter("bias", None)
50
+
51
+ self.reset_parameters()
52
+
53
+ def reset_parameters(self) -> None:
54
+ if self.elementwise_affine:
55
+ nn.init.ones_(self.weight)
56
+ nn.init.zeros_(self.bias)
57
+
58
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
59
+ if isinstance(input, tuple):
60
+ input, embedding = input
61
+ return (
62
+ F.layer_norm(
63
+ input,
64
+ self.normalized_shape,
65
+ self.weight,
66
+ self.bias,
67
+ self.eps,
68
+ ),
69
+ embedding,
70
+ )
71
+
72
+ assert embedding is None
73
+ return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
74
+
75
+ def extra_repr(self) -> str:
76
+ return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
77
+
78
+
79
+ class IdentityNorm(nn.Module):
80
+ def __init__(
81
+ self,
82
+ d_model: int,
83
+ eps: float = 1e-5,
84
+ device=None,
85
+ dtype=None,
86
+ ) -> None:
87
+ super(IdentityNorm, self).__init__()
88
+
89
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
90
+ if isinstance(input, tuple):
91
+ return input
92
+
93
+ assert embedding is None
94
+ return input
95
+
96
+
97
+ class TransformerEncoder(nn.Module):
98
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
99
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
100
+
101
+ Args:
102
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
103
+ num_layers: the number of sub-encoder-layers in the encoder (required).
104
+ norm: the layer normalization component (optional).
105
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
106
+ (and convert back on output). This will improve the overall performance of
107
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
108
+
109
+ Examples::
110
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
111
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
112
+ >>> src = torch.rand(10, 32, 512)
113
+ >>> out = transformer_encoder(src)
114
+ """
115
+
116
+ __constants__ = ["norm"]
117
+
118
+ def __init__(self, encoder_layer, num_layers, norm=None):
119
+ super(TransformerEncoder, self).__init__()
120
+ self.layers = _get_clones(encoder_layer, num_layers)
121
+ self.num_layers = num_layers
122
+ self.norm = norm
123
+
124
+ def forward(
125
+ self,
126
+ src: Tensor,
127
+ mask: Optional[Tensor] = None,
128
+ src_key_padding_mask: Optional[Tensor] = None,
129
+ return_layer_states: bool = False,
130
+ cache=None,
131
+ ) -> Tensor:
132
+ r"""Pass the input through the encoder layers in turn.
133
+
134
+ Args:
135
+ src: the sequence to the encoder (required).
136
+ mask: the mask for the src sequence (optional).
137
+ src_key_padding_mask: the mask for the src keys per batch (optional).
138
+ return_layer_states: return layers' state (optional).
139
+
140
+ Shape:
141
+ see the docs in Transformer class.
142
+ """
143
+ if return_layer_states:
144
+ layer_states = [] # layers' output
145
+ output = src
146
+ for mod in self.layers:
147
+ output = mod(
148
+ output,
149
+ src_mask=mask,
150
+ src_key_padding_mask=src_key_padding_mask,
151
+ cache=cache,
152
+ )
153
+ layer_states.append(output[0])
154
+
155
+ if self.norm is not None:
156
+ output = self.norm(output)
157
+
158
+ return layer_states, output
159
+
160
+ output = src
161
+ for mod in self.layers:
162
+ output = mod(
163
+ output,
164
+ src_mask=mask,
165
+ src_key_padding_mask=src_key_padding_mask,
166
+ cache=cache,
167
+ )
168
+
169
+ if self.norm is not None:
170
+ output = self.norm(output)
171
+
172
+ return output
173
+
174
+
175
+ class TransformerEncoderLayer(nn.Module):
176
+ __constants__ = ["batch_first", "norm_first"]
177
+
178
+ def __init__(
179
+ self,
180
+ d_model: int,
181
+ nhead: int,
182
+ dim_feedforward: int = 2048,
183
+ dropout: float = 0.1,
184
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
185
+ batch_first: bool = False,
186
+ norm_first: bool = False,
187
+ device=None,
188
+ dtype=None,
189
+ linear1_self_attention_cls: nn.Module = nn.Linear,
190
+ linear2_self_attention_cls: nn.Module = nn.Linear,
191
+ linear1_feedforward_cls: nn.Module = nn.Linear,
192
+ linear2_feedforward_cls: nn.Module = nn.Linear,
193
+ layer_norm_cls: nn.Module = LayerNorm,
194
+ layer_norm_eps: float = 1e-5,
195
+ adaptive_layer_norm=False,
196
+ ) -> None:
197
+ factory_kwargs = {"device": device, "dtype": dtype}
198
+ super(TransformerEncoderLayer, self).__init__()
199
+ # print(233333333333,d_model,nhead)
200
+ # import os
201
+ # os._exit(2333333)
202
+ self.self_attn = MultiheadAttention(
203
+ d_model, # 512 16
204
+ nhead,
205
+ dropout=dropout,
206
+ batch_first=batch_first,
207
+ linear1_cls=linear1_self_attention_cls,
208
+ linear2_cls=linear2_self_attention_cls,
209
+ **factory_kwargs,
210
+ )
211
+
212
+ # Implementation of Feedforward model
213
+ self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
214
+ self.dropout = nn.Dropout(dropout)
215
+ self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
216
+
217
+ self.norm_first = norm_first
218
+ self.dropout1 = nn.Dropout(dropout)
219
+ self.dropout2 = nn.Dropout(dropout)
220
+
221
+ # Legacy string support for activation function.
222
+ if isinstance(activation, str):
223
+ activation = _get_activation_fn(activation)
224
+ elif isinstance(activation, partial):
225
+ activation = activation(d_model)
226
+ elif activation == BalancedDoubleSwish:
227
+ activation = BalancedDoubleSwish(d_model)
228
+
229
+ # # We can't test self.activation in forward() in TorchScript,
230
+ # # so stash some information about it instead.
231
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
232
+ # self.activation_relu_or_gelu = 1
233
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
234
+ # self.activation_relu_or_gelu = 2
235
+ # else:
236
+ # self.activation_relu_or_gelu = 0
237
+ self.activation = activation
238
+
239
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
240
+ if layer_norm_cls == IdentityNorm:
241
+ norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
242
+ else:
243
+ norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
244
+
245
+ if adaptive_layer_norm:
246
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
247
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
248
+ else:
249
+ self.norm1 = norm1
250
+ self.norm2 = norm2
251
+
252
+ def __setstate__(self, state):
253
+ super(TransformerEncoderLayer, self).__setstate__(state)
254
+ if not hasattr(self, "activation"):
255
+ self.activation = F.relu
256
+
257
+ def forward(
258
+ self,
259
+ src: Tensor,
260
+ src_mask: Optional[Tensor] = None,
261
+ src_key_padding_mask: Optional[Tensor] = None,
262
+ cache=None,
263
+ ) -> Tensor:
264
+ r"""Pass the input through the encoder layer.
265
+
266
+ Args:
267
+ src: the sequence to the encoder layer (required).
268
+ src_mask: the mask for the src sequence (optional).
269
+ src_key_padding_mask: the mask for the src keys per batch (optional).
270
+
271
+ Shape:
272
+ see the docs in Transformer class.
273
+ """
274
+ x, stage_embedding = src, None
275
+ is_src_tuple = False
276
+ if isinstance(src, tuple):
277
+ x, stage_embedding = src
278
+ is_src_tuple = True
279
+
280
+ if src_key_padding_mask is not None:
281
+ _skpm_dtype = src_key_padding_mask.dtype
282
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
283
+ raise AssertionError("only bool and floating types of key_padding_mask are supported")
284
+
285
+ if self.norm_first:
286
+ x = x + self._sa_block(
287
+ self.norm1(x, stage_embedding),
288
+ src_mask,
289
+ src_key_padding_mask,
290
+ cache=cache,
291
+ )
292
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
293
+ else:
294
+ x = self.norm1(
295
+ x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
296
+ stage_embedding,
297
+ )
298
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
299
+
300
+ if is_src_tuple:
301
+ return (x, stage_embedding)
302
+ return x
303
+
304
+ # self-attention block
305
+ def _sa_block(
306
+ self,
307
+ x: Tensor,
308
+ attn_mask: Optional[Tensor],
309
+ key_padding_mask: Optional[Tensor],
310
+ cache=None,
311
+ ) -> Tensor:
312
+ # print(x.shape,attn_mask.shape,key_padding_mask)
313
+ # torch.Size([1, 188, 512]) torch.Size([188, 188]) None
314
+ # import os
315
+ # os._exit(23333)
316
+ x = self.self_attn(
317
+ x,
318
+ x,
319
+ x,
320
+ attn_mask=attn_mask,
321
+ key_padding_mask=key_padding_mask,
322
+ need_weights=False,
323
+ cache=cache,
324
+ )[0]
325
+ return self.dropout1(x)
326
+
327
+ # feed forward block
328
+ def _ff_block(self, x: Tensor) -> Tensor:
329
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
330
+ return self.dropout2(x)
331
+
332
+
333
+ class AdaptiveLayerNorm(nn.Module):
334
+ r"""Adaptive Layer Normalization"""
335
+
336
+ def __init__(self, d_model, norm) -> None:
337
+ super(AdaptiveLayerNorm, self).__init__()
338
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
339
+ self.norm = norm
340
+ self.d_model = d_model
341
+ self.eps = self.norm.eps
342
+
343
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
344
+ if isinstance(input, tuple):
345
+ input, embedding = input
346
+ weight, bias = torch.split(
347
+ self.project_layer(embedding),
348
+ split_size_or_sections=self.d_model,
349
+ dim=-1,
350
+ )
351
+ return (weight * self.norm(input) + bias, embedding)
352
+
353
+ weight, bias = torch.split(
354
+ self.project_layer(embedding),
355
+ split_size_or_sections=self.d_model,
356
+ dim=-1,
357
+ )
358
+ return weight * self.norm(input) + bias
359
+
360
+
361
+ def _get_clones(module, N):
362
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
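Note: `TransformerEncoderLayer.__init__` above references `_get_activation_fn` (when `activation` is passed as a string) and `BalancedBasicNorm` (when `layer_norm_cls` is `IdentityNorm`); neither name is defined or imported in this file, so those non-default paths would fail with a `NameError`. A minimal stand-in for the activation lookup, assuming the usual PyTorch `relu`/`gelu` semantics, could look like:

    from torch.nn import functional as F

    def _get_activation_fn(activation: str):
        # map the legacy string spellings onto the functional activations
        if activation == "relu":
            return F.relu
        if activation == "gelu":
            return F.gelu
        raise RuntimeError(f"activation should be relu/gelu, not {activation}")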
GPT_SoVITS/AR/modules/transformer_onnx.py ADDED
@@ -0,0 +1,281 @@
1
+ # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any
6
+ from typing import Callable
7
+ from typing import List
8
+ from typing import Optional
9
+ from typing import Tuple
10
+ from typing import Union
11
+
12
+ import torch
13
+ from AR.modules.activation_onnx import MultiheadAttention
14
+ from AR.modules.scaling import BalancedDoubleSwish
15
+ from torch import nn
16
+ from torch import Tensor
17
+ from torch.nn import functional as F
18
+
19
+ _shape_t = Union[int, List[int], torch.Size]
20
+
21
+
22
+ class LayerNorm(nn.Module):
23
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
24
+ normalized_shape: Tuple[int, ...]
25
+ eps: float
26
+ elementwise_affine: bool
27
+
28
+ def __init__(
29
+ self,
30
+ normalized_shape: _shape_t,
31
+ eps: float = 1e-5,
32
+ elementwise_affine: bool = True,
33
+ device=None,
34
+ dtype=None,
35
+ ) -> None:
36
+ factory_kwargs = {"device": device, "dtype": dtype}
37
+ super(LayerNorm, self).__init__()
38
+ if isinstance(normalized_shape, numbers.Integral):
39
+ # mypy error: incompatible types in assignment
40
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
41
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
42
+ self.eps = eps
43
+ self.elementwise_affine = elementwise_affine
44
+ if self.elementwise_affine:
45
+ self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
46
+ self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
47
+ else:
48
+ self.register_parameter("weight", None)
49
+ self.register_parameter("bias", None)
50
+
51
+ self.reset_parameters()
52
+
53
+ def reset_parameters(self) -> None:
54
+ if self.elementwise_affine:
55
+ nn.init.ones_(self.weight)
56
+ nn.init.zeros_(self.bias)
57
+
58
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
59
+ if isinstance(input, tuple):
60
+ input, embedding = input
61
+ return (
62
+ F.layer_norm(
63
+ input,
64
+ self.normalized_shape,
65
+ self.weight,
66
+ self.bias,
67
+ self.eps,
68
+ ),
69
+ embedding,
70
+ )
71
+
72
+ assert embedding is None
73
+ return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
74
+
75
+ def extra_repr(self) -> str:
76
+ return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
77
+
78
+
79
+ class IdentityNorm(nn.Module):
80
+ def __init__(
81
+ self,
82
+ d_model: int,
83
+ eps: float = 1e-5,
84
+ device=None,
85
+ dtype=None,
86
+ ) -> None:
87
+ super(IdentityNorm, self).__init__()
88
+
89
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
90
+ if isinstance(input, tuple):
91
+ return input
92
+
93
+ assert embedding is None
94
+ return input
95
+
96
+
97
+ class TransformerEncoder(nn.Module):
98
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
99
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
100
+
101
+ Args:
102
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
103
+ num_layers: the number of sub-encoder-layers in the encoder (required).
104
+ norm: the layer normalization component (optional).
105
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
106
+ (and convert back on output). This will improve the overall performance of
107
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
108
+
109
+ Examples::
110
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
111
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
112
+ >>> src = torch.rand(10, 32, 512)
113
+ >>> out = transformer_encoder(src)
114
+ """
115
+
116
+ __constants__ = ["norm"]
117
+
118
+ def __init__(self, encoder_layer, num_layers, norm=None):
119
+ super(TransformerEncoder, self).__init__()
120
+ self.layers = _get_clones(encoder_layer, num_layers)
121
+ self.num_layers = num_layers
122
+ self.norm = norm
123
+
124
+ def forward(
125
+ self,
126
+ src: Tensor,
127
+ mask: Optional[Tensor] = None,
128
+ src_key_padding_mask: Optional[Tensor] = None,
129
+ return_layer_states: bool = False,
130
+ cache=None,
131
+ ) -> Tensor:
132
+ output = src
133
+ for mod in self.layers:
134
+ output = mod(
135
+ output,
136
+ src_mask=mask,
137
+ src_key_padding_mask=src_key_padding_mask,
138
+ cache=cache,
139
+ )
140
+
141
+ if self.norm is not None:
142
+ output = self.norm(output)
143
+
144
+ return output
145
+
146
+
147
+ class TransformerEncoderLayer(nn.Module):
148
+ __constants__ = ["batch_first", "norm_first"]
149
+
150
+ def __init__(
151
+ self,
152
+ d_model: int,
153
+ nhead: int,
154
+ dim_feedforward: int = 2048,
155
+ dropout: float = 0.1,
156
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
157
+ batch_first: bool = False,
158
+ norm_first: bool = False,
159
+ device=None,
160
+ dtype=None,
161
+ linear1_self_attention_cls: nn.Module = nn.Linear,
162
+ linear2_self_attention_cls: nn.Module = nn.Linear,
163
+ linear1_feedforward_cls: nn.Module = nn.Linear,
164
+ linear2_feedforward_cls: nn.Module = nn.Linear,
165
+ layer_norm_cls: nn.Module = LayerNorm,
166
+ layer_norm_eps: float = 1e-5,
167
+ adaptive_layer_norm=False,
168
+ ) -> None:
169
+ factory_kwargs = {"device": device, "dtype": dtype}
170
+ super(TransformerEncoderLayer, self).__init__()
171
+ self.self_attn = MultiheadAttention(
172
+ d_model, # 512 16
173
+ nhead,
174
+ dropout=dropout,
175
+ batch_first=batch_first,
176
+ linear1_cls=linear1_self_attention_cls,
177
+ linear2_cls=linear2_self_attention_cls,
178
+ **factory_kwargs,
179
+ )
180
+ self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
181
+ self.dropout = nn.Dropout(dropout)
182
+ self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
183
+ self.norm_first = norm_first
184
+ self.dropout1 = nn.Dropout(dropout)
185
+ self.dropout2 = nn.Dropout(dropout)
186
+ if isinstance(activation, str):
187
+ activation = _get_activation_fn(activation)
188
+ elif isinstance(activation, partial):
189
+ activation = activation(d_model)
190
+ elif activation == BalancedDoubleSwish:
191
+ activation = BalancedDoubleSwish(d_model)
192
+ self.activation = activation
193
+
194
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
195
+ if layer_norm_cls == IdentityNorm:
196
+ norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
197
+ else:
198
+ norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
199
+
200
+ if adaptive_layer_norm:
201
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
202
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
203
+ else:
204
+ self.norm1 = norm1
205
+ self.norm2 = norm2
206
+
207
+ def __setstate__(self, state):
208
+ super(TransformerEncoderLayer, self).__setstate__(state)
209
+ if not hasattr(self, "activation"):
210
+ self.activation = F.relu
211
+
212
+ def forward(
213
+ self,
214
+ src: Tensor,
215
+ src_mask: Optional[Tensor] = None,
216
+ src_key_padding_mask: Optional[Tensor] = None,
217
+ cache=None,
218
+ ) -> Tensor:
219
+ x = src
220
+ stage_embedding = None
221
+ x = self.norm1(
222
+ x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
223
+ stage_embedding,
224
+ )
225
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
226
+
227
+ return x
228
+
229
+ def _sa_block(
230
+ self,
231
+ x: Tensor,
232
+ attn_mask: Optional[Tensor],
233
+ key_padding_mask: Optional[Tensor],
234
+ cache=None,
235
+ ) -> Tensor:
236
+ x = self.self_attn(
237
+ x,
238
+ x,
239
+ x,
240
+ attn_mask=attn_mask,
241
+ key_padding_mask=key_padding_mask,
242
+ need_weights=False,
243
+ cache=cache,
244
+ )
245
+ return self.dropout1(x)
246
+
247
+ def _ff_block(self, x: Tensor) -> Tensor:
248
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
249
+ return self.dropout2(x)
250
+
251
+
252
+ class AdaptiveLayerNorm(nn.Module):
253
+ r"""Adaptive Layer Normalization"""
254
+
255
+ def __init__(self, d_model, norm) -> None:
256
+ super(AdaptiveLayerNorm, self).__init__()
257
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
258
+ self.norm = norm
259
+ self.d_model = d_model
260
+ self.eps = self.norm.eps
261
+
262
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
263
+ if isinstance(input, tuple):
264
+ input, embedding = input
265
+ weight, bias = torch.split(
266
+ self.project_layer(embedding),
267
+ split_size_or_sections=self.d_model,
268
+ dim=-1,
269
+ )
270
+ return (weight * self.norm(input) + bias, embedding)
271
+
272
+ weight, bias = torch.split(
273
+ self.project_layer(embedding),
274
+ split_size_or_sections=self.d_model,
275
+ dim=-1,
276
+ )
277
+ return weight * self.norm(input) + bias
278
+
279
+
280
+ def _get_clones(module, N):
281
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
GPT_SoVITS/AR/text_processing/__init__.py ADDED
File without changes
GPT_SoVITS/AR/text_processing/phonemizer.py ADDED
@@ -0,0 +1,72 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import itertools
4
+ import re
5
+ from typing import Dict
6
+ from typing import List
7
+
8
+ import regex
9
+ from gruut import sentences
10
+ from gruut.const import Sentence
11
+ from gruut.const import Word
12
+ from AR.text_processing.symbols import SYMBOL_TO_ID
13
+
14
+
15
+ class GruutPhonemizer:
16
+ def __init__(self, language: str):
17
+ self._phonemizer = sentences
18
+ self.lang = language
19
+ self.symbol_to_id = SYMBOL_TO_ID
20
+ self._special_cases_dict: Dict[str, str] = {
21
+ r"\.\.\.": "... ",
22
+ ";": "; ",
23
+ ":": ": ",
24
+ ",": ", ",
25
+ r"\.": ". ",
26
+ "!": "! ",
27
+ r"\?": "? ",
28
+ "—": "—",
29
+ "…": "… ",
30
+ "«": "«",
31
+ "»": "»",
32
+ }
33
+ self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
34
+
35
+ def _normalize_punctuation(self, text: str) -> str:
36
+ text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
37
+ text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
38
+ text = regex.sub(r"\pZ+", r" ", text)
39
+ return text.strip()
40
+
41
+ def _convert_punctuation(self, word: Word) -> str:
42
+ if not word.phonemes:
43
+ return ""
44
+ if word.phonemes[0] in ["‖", "|"]:
45
+ return word.text.strip()
46
+
47
+ phonemes = "".join(word.phonemes)
48
+ # remove modifier characters ˈˌː with regex
49
+ phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
50
+ return phonemes.strip()
51
+
52
+ def phonemize(self, text: str, espeak: bool = False) -> str:
53
+ text_to_phonemize: str = self._normalize_punctuation(text)
54
+ sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
55
+ words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
56
+ return " ".join(words)
57
+
58
+ def transform(self, phonemes):
59
+ # convert phonemes to ids
60
+ # dictionary is in symbols.py
61
+ return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
62
+
63
+
64
+ if __name__ == "__main__":
65
+ phonemizer = GruutPhonemizer("en-us")
66
+ # text -> IPA
67
+ phonemes = phonemizer.phonemize("Hello, wor-ld ?")
68
+ print("phonemes:", phonemes)
69
+ print("len(phonemes):", len(phonemes))
70
+ phoneme_ids = phonemizer.transform(phonemes)
71
+ print("phoneme_ids:", phoneme_ids)
72
+ print("len(phoneme_ids):", len(phoneme_ids))
GPT_SoVITS/AR/text_processing/symbols.py ADDED
@@ -0,0 +1,12 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ PAD = "_"
4
+ PUNCTUATION = ';:,.!?¡¿—…"«»“” '
5
+ LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
6
+ IPA_LETTERS = (
7
+ "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
8
+ )
9
+ SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
10
+ SPACE_ID = SYMBOLS.index(" ")
11
+ SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
12
+ ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
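A small round-trip sketch for the tables above (the phoneme string is just an example; run with the GPT_SoVITS directory on sys.path):

    from AR.text_processing.symbols import ID_TO_SYMBOL, SPACE_ID, SYMBOL_TO_ID

    ids = [SYMBOL_TO_ID[ch] for ch in "hə loʊ" if ch in SYMBOL_TO_ID]   # phonemes -> ids
    back = "".join(ID_TO_SYMBOL[i] for i in ids)                        # ids -> phonemes
    print(SPACE_ID, ids, back)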
GPT_SoVITS/AR/utils/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ import re
2
+
3
+
4
+ def str2bool(str):
5
+ return True if str.lower() == "true" else False
6
+
7
+
8
+ def get_newest_ckpt(string_list):
9
+ # regex pattern matching the epoch/step numbers embedded in each filename
10
+ pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
11
+
12
+ # extract the numbers from each string with the regex and collect them as tuples
13
+ extracted_info = []
14
+ for string in string_list:
15
+ match = re.match(pattern, string)
16
+ if match:
17
+ epoch = int(match.group(1))
18
+ step = int(match.group(2))
19
+ extracted_info.append((epoch, step, string))
20
+ # sort by the epoch number first, then by the step number
21
+ sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
22
+ # 获取最新的 ckpt 文件名
23
+ newest_ckpt = sorted_info[0][2]
24
+ return newest_ckpt
25
+
26
+
27
+ # return the first line of the text file if it exists and is non-empty, otherwise False
28
+ def check_txt_file(file_path):
29
+ try:
30
+ with open(file_path, "r") as file:
31
+ text = file.readline().strip()
32
+ assert text.strip() != ""
33
+ return text
34
+ except Exception:
35
+ return False
36
+ return False
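`get_newest_ckpt` only considers filenames of the exact form `epoch=<n>-step=<m>.ckpt` and picks the highest `(epoch, step)` pair; a usage sketch with made-up names (GPT_SoVITS on sys.path):

    from AR.utils import get_newest_ckpt

    ckpts = ["epoch=10-step=4000.ckpt", "epoch=12-step=4800.ckpt", "epoch=12-step=5200.ckpt"]
    print(get_newest_ckpt(ckpts))   # -> epoch=12-step=5200.ckpt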
GPT_SoVITS/AR/utils/initialize.py ADDED
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/env python3
2
+ """Initialize modules for espnet2 neural networks."""
3
+
4
+ import torch
5
+ from typeguard import check_argument_types
6
+
7
+
8
+ def initialize(model: torch.nn.Module, init: str):
9
+ """Initialize weights of a neural network module.
10
+
11
+ Parameters are initialized using the given method or distribution.
12
+
13
+ Custom initialization routines can be implemented into submodules
14
+ as function `espnet_initialization_fn` within the custom module.
15
+
16
+ Args:
17
+ model: Target.
18
+ init: Method of initialization.
19
+ """
20
+ assert check_argument_types()
21
+ print("init with", init)
22
+
23
+ # weight init
24
+ for p in model.parameters():
25
+ if p.dim() > 1:
26
+ if init == "xavier_uniform":
27
+ torch.nn.init.xavier_uniform_(p.data)
28
+ elif init == "xavier_normal":
29
+ torch.nn.init.xavier_normal_(p.data)
30
+ elif init == "kaiming_uniform":
31
+ torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
32
+ elif init == "kaiming_normal":
33
+ torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
34
+ else:
35
+ raise ValueError("Unknown initialization: " + init)
36
+ # bias init
37
+ for name, p in model.named_parameters():
38
+ if ".bias" in name and p.dim() == 1:
39
+ p.data.zero_()
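A minimal usage sketch for `initialize` (the model is a throwaway placeholder; GPT_SoVITS on sys.path and typeguard installed):

    import torch
    from AR.utils.initialize import initialize

    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
    initialize(model, "xavier_uniform")   # >1-dim weights re-initialised, biases zeroed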
GPT_SoVITS/AR/utils/io.py ADDED
@@ -0,0 +1,30 @@
1
+ import sys
2
+
3
+ import torch
4
+ import yaml
5
+
6
+
7
+ def load_yaml_config(path):
8
+ with open(path) as f:
9
+ config = yaml.full_load(f)
10
+ return config
11
+
12
+
13
+ def save_config_to_yaml(config, path):
14
+ assert path.endswith(".yaml")
15
+ with open(path, "w") as f:
16
+ f.write(yaml.dump(config))
17
+ f.close()
18
+
19
+
20
+ def write_args(args, path):
21
+ args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
22
+ with open(path, "a") as args_file:
23
+ args_file.write("==> torch version: {}\n".format(torch.__version__))
24
+ args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
25
+ args_file.write("==> Cmd:\n")
26
+ args_file.write(str(sys.argv))
27
+ args_file.write("\n==> args:\n")
28
+ for k, v in sorted(args_dict.items()):
29
+ args_file.write(" %s: %s\n" % (str(k), str(v)))
30
+ args_file.close()
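A quick round-trip sketch for the YAML helpers above (the path is illustrative; GPT_SoVITS on sys.path):

    from AR.utils.io import load_yaml_config, save_config_to_yaml

    cfg = {"train": {"batch_size": 8, "epochs": 15}}
    save_config_to_yaml(cfg, "demo_config.yaml")
    assert load_yaml_config("demo_config.yaml") == cfg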
GPT_SoVITS/download.py ADDED
@@ -0,0 +1,13 @@
1
+ import os
2
+ import sys
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.insert(0, now_dir)
6
+ from text.g2pw import G2PWPinyin
7
+
8
+ g2pw = G2PWPinyin(
9
+ model_dir="GPT_SoVITS/text/G2PWModel",
10
+ model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
11
+ v_to_u=False,
12
+ neutral_tone_with_five=True,
13
+ )
GPT_SoVITS/export_torch_script.py ADDED
@@ -0,0 +1,861 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import argparse
4
+ from typing import Optional
5
+ from my_utils import load_audio
6
+ import torch
7
+ import torchaudio
8
+
9
+ from torch import IntTensor, LongTensor, Tensor, nn
10
+ from torch.nn import functional as F
11
+
12
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
13
+ from feature_extractor import cnhubert
14
+
15
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
16
+ from module.models_onnx import SynthesizerTrn
17
+
18
+ from inference_webui import get_phones_and_bert
19
+
20
+ import os
21
+ import soundfile
22
+
23
+ default_config = {
24
+ "embedding_dim": 512,
25
+ "hidden_dim": 512,
26
+ "num_head": 8,
27
+ "num_layers": 12,
28
+ "num_codebook": 8,
29
+ "p_dropout": 0.0,
30
+ "vocab_size": 1024 + 1,
31
+ "phoneme_vocab_size": 512,
32
+ "EOS": 1024,
33
+ }
34
+
35
+
36
+ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
37
+ config = dict_s1["config"]
38
+ config["model"]["dropout"] = float(config["model"]["dropout"])
39
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
40
+ t2s_model.load_state_dict(dict_s1["weight"])
41
+ t2s_model = t2s_model.eval()
42
+ return t2s_model
43
+
44
+
45
+ @torch.jit.script
46
+ def logits_to_probs(
47
+ logits,
48
+ previous_tokens: Optional[torch.Tensor] = None,
49
+ temperature: float = 1.0,
50
+ top_k: Optional[int] = None,
51
+ top_p: Optional[int] = None,
52
+ repetition_penalty: float = 1.0,
53
+ ):
54
+ # if previous_tokens is not None:
55
+ # previous_tokens = previous_tokens.squeeze()
56
+ # print(logits.shape,previous_tokens.shape)
57
+ # pdb.set_trace()
58
+ if previous_tokens is not None and repetition_penalty != 1.0:
59
+ previous_tokens = previous_tokens.long()
60
+ score = torch.gather(logits, dim=1, index=previous_tokens)
61
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
62
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
63
+
64
+ if top_p is not None and top_p < 1.0:
65
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
66
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
67
+ sorted_indices_to_remove = cum_probs > top_p
68
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
69
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
70
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
71
+
72
+ logits = logits / max(temperature, 1e-5)
73
+
74
+ if top_k is not None:
75
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
76
+ pivot = v[:, -1].unsqueeze(-1)
77
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
78
+
79
+ probs = torch.nn.functional.softmax(logits, dim=-1)
80
+ return probs
81
+
82
+
83
+ @torch.jit.script
84
+ def multinomial_sample_one_no_sync(probs_sort):
85
+ # Does multinomial sampling without a cuda synchronization
86
+ q = torch.randn_like(probs_sort)
87
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
88
+
89
+
90
+ @torch.jit.script
91
+ def sample(
92
+ logits,
93
+ previous_tokens,
94
+ temperature: float = 1.0,
95
+ top_k: Optional[int] = None,
96
+ top_p: Optional[int] = None,
97
+ repetition_penalty: float = 1.0,
98
+ ):
99
+ probs = logits_to_probs(
100
+ logits=logits,
101
+ previous_tokens=previous_tokens,
102
+ temperature=temperature,
103
+ top_k=top_k,
104
+ top_p=top_p,
105
+ repetition_penalty=repetition_penalty,
106
+ )
107
+ idx_next = multinomial_sample_one_no_sync(probs)
108
+ return idx_next, probs
109
+
110
+
111
+ @torch.jit.script
112
+ def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
113
+ hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
114
+ y = torch.nn.functional.pad(
115
+ y.unsqueeze(1),
116
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
117
+ mode="reflect",
118
+ )
119
+ y = y.squeeze(1)
120
+ spec = torch.stft(
121
+ y,
122
+ n_fft,
123
+ hop_length=hop_size,
124
+ win_length=win_size,
125
+ window=hann_window,
126
+ center=center,
127
+ pad_mode="reflect",
128
+ normalized=False,
129
+ onesided=True,
130
+ return_complex=False,
131
+ )
132
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
133
+ return spec
134
+
135
+
136
+ class DictToAttrRecursive(dict):
137
+ def __init__(self, input_dict):
138
+ super().__init__(input_dict)
139
+ for key, value in input_dict.items():
140
+ if isinstance(value, dict):
141
+ value = DictToAttrRecursive(value)
142
+ self[key] = value
143
+ setattr(self, key, value)
144
+
145
+ def __getattr__(self, item):
146
+ try:
147
+ return self[item]
148
+ except KeyError:
149
+ raise AttributeError(f"Attribute {item} not found")
150
+
151
+ def __setattr__(self, key, value):
152
+ if isinstance(value, dict):
153
+ value = DictToAttrRecursive(value)
154
+ super(DictToAttrRecursive, self).__setitem__(key, value)
155
+ super().__setattr__(key, value)
156
+
157
+ def __delattr__(self, item):
158
+ try:
159
+ del self[item]
160
+ except KeyError:
161
+ raise AttributeError(f"Attribute {item} not found")
162
+
163
+
164
+ @torch.jit.script
165
+ class T2SMLP:
166
+ def __init__(self, w1, b1, w2, b2):
167
+ self.w1 = w1
168
+ self.b1 = b1
169
+ self.w2 = w2
170
+ self.b2 = b2
171
+
172
+ def forward(self, x):
173
+ x = F.relu(F.linear(x, self.w1, self.b1))
174
+ x = F.linear(x, self.w2, self.b2)
175
+ return x
176
+
177
+
178
+ @torch.jit.script
179
+ class T2SBlock:
180
+ def __init__(
181
+ self,
182
+ num_heads: int,
183
+ hidden_dim: int,
184
+ mlp: T2SMLP,
185
+ qkv_w,
186
+ qkv_b,
187
+ out_w,
188
+ out_b,
189
+ norm_w1,
190
+ norm_b1,
191
+ norm_eps1: float,
192
+ norm_w2,
193
+ norm_b2,
194
+ norm_eps2: float,
195
+ ):
196
+ self.num_heads = num_heads
197
+ self.mlp = mlp
198
+ self.hidden_dim: int = hidden_dim
199
+ self.qkv_w = qkv_w
200
+ self.qkv_b = qkv_b
201
+ self.out_w = out_w
202
+ self.out_b = out_b
203
+ self.norm_w1 = norm_w1
204
+ self.norm_b1 = norm_b1
205
+ self.norm_eps1 = norm_eps1
206
+ self.norm_w2 = norm_w2
207
+ self.norm_b2 = norm_b2
208
+ self.norm_eps2 = norm_eps2
209
+
210
+ self.false = torch.tensor(False, dtype=torch.bool)
211
+
212
+ @torch.jit.ignore
213
+ def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
214
+ if padding_mask is None:
215
+ return x
216
+
217
+ if padding_mask.dtype == torch.bool:
218
+ return x.masked_fill(padding_mask, 0)
219
+ else:
220
+ return x * padding_mask
221
+
222
+ def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
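+ # Full self-attention pass over the whole prompt sequence; also returns the
+ # masked K/V projections so they can seed the cache used by decode_next_token.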
223
+ q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
224
+
225
+ batch_size = q.shape[0]
226
+ q_len = q.shape[1]
227
+ kv_len = k.shape[1]
228
+
229
+ q = self.to_mask(q, padding_mask)
230
+ k_cache = self.to_mask(k, padding_mask)
231
+ v_cache = self.to_mask(v, padding_mask)
232
+
233
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
234
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
235
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
236
+
237
+ attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
238
+
239
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
240
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
241
+ attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
242
+
243
+ if padding_mask is not None:
244
+ for i in range(batch_size):
245
+ # mask = padding_mask[i,:,0]
246
+ if self.false.device != padding_mask.device:
247
+ self.false = self.false.to(padding_mask.device)
248
+ idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
249
+ x_item = x[i, idx, :].unsqueeze(0)
250
+ attn_item = attn[i, idx, :].unsqueeze(0)
251
+ x_item = x_item + attn_item
252
+ x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
253
+ x_item = x_item + self.mlp.forward(x_item)
254
+ x_item = F.layer_norm(
255
+ x_item,
256
+ [self.hidden_dim],
257
+ self.norm_w2,
258
+ self.norm_b2,
259
+ self.norm_eps2,
260
+ )
261
+ x[i, idx, :] = x_item.squeeze(0)
262
+ x = self.to_mask(x, padding_mask)
263
+ else:
264
+ x = x + attn
265
+ x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
266
+ x = x + self.mlp.forward(x)
267
+ x = F.layer_norm(
268
+ x,
269
+ [self.hidden_dim],
270
+ self.norm_w2,
271
+ self.norm_b2,
272
+ self.norm_eps2,
273
+ )
274
+ return x, k_cache, v_cache
275
+
276
+ def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
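+ # Incremental decoding step: project the newest token, append its K/V to the
+ # running cache, and attend the new query over the full cached sequence.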
277
+ q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
278
+
279
+ k_cache = torch.cat([k_cache, k], dim=1)
280
+ v_cache = torch.cat([v_cache, v], dim=1)
281
+
282
+ batch_size = q.shape[0]
283
+ q_len = q.shape[1]
284
+ kv_len = k_cache.shape[1]
285
+
286
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
287
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
288
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
289
+
290
+ attn = F.scaled_dot_product_attention(q, k, v)
291
+
292
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
293
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
294
+ attn = F.linear(attn, self.out_w, self.out_b)
295
+
296
+ x = x + attn
297
+ x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
298
+ x = x + self.mlp.forward(x)
299
+ x = F.layer_norm(
300
+ x,
301
+ [self.hidden_dim],
302
+ self.norm_w2,
303
+ self.norm_b2,
304
+ self.norm_eps2,
305
+ )
306
+ return x, k_cache, v_cache
307
+
308
+
309
+ @torch.jit.script
310
+ class T2STransformer:
311
+ def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
312
+ self.num_blocks: int = num_blocks
313
+ self.blocks = blocks
314
+
315
+ def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
316
+ k_cache: list[torch.Tensor] = []
317
+ v_cache: list[torch.Tensor] = []
318
+ for i in range(self.num_blocks):
319
+ x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
320
+ k_cache.append(k_cache_)
321
+ v_cache.append(v_cache_)
322
+ return x, k_cache, v_cache
323
+
324
+ def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
325
+ for i in range(self.num_blocks):
326
+ x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
327
+ return x, k_cache, v_cache
328
+
329
+
330
+ class VitsModel(nn.Module):
331
+ def __init__(self, vits_path):
332
+ super().__init__()
333
+ # dict_s2 = torch.load(vits_path,map_location="cpu")
334
+ dict_s2 = torch.load(vits_path)
335
+ self.hps = dict_s2["config"]
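+ # Infer the symbol-table version from the phoneme embedding size: 322 rows -> v1, otherwise v2.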
336
+ if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
337
+ self.hps["model"]["version"] = "v1"
338
+ else:
339
+ self.hps["model"]["version"] = "v2"
340
+
341
+ self.hps = DictToAttrRecursive(self.hps)
342
+ self.hps.model.semantic_frame_rate = "25hz"
343
+ self.vq_model = SynthesizerTrn(
344
+ self.hps.data.filter_length // 2 + 1,
345
+ self.hps.train.segment_size // self.hps.data.hop_length,
346
+ n_speakers=self.hps.data.n_speakers,
347
+ **self.hps.model,
348
+ )
349
+ self.vq_model.eval()
350
+ self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
351
+
352
+ def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
353
+ refer = spectrogram_torch(
354
+ ref_audio,
355
+ self.hps.data.filter_length,
356
+ self.hps.data.sampling_rate,
357
+ self.hps.data.hop_length,
358
+ self.hps.data.win_length,
359
+ center=False,
360
+ )
361
+ return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
362
+
363
+
364
+ class T2SModel(nn.Module):
365
+ def __init__(self, raw_t2s: Text2SemanticLightningModule):
366
+ super(T2SModel, self).__init__()
367
+ self.model_dim = raw_t2s.model.model_dim
368
+ self.embedding_dim = raw_t2s.model.embedding_dim
369
+ self.num_head = raw_t2s.model.num_head
370
+ self.num_layers = raw_t2s.model.num_layers
371
+ self.vocab_size = raw_t2s.model.vocab_size
372
+ self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
373
+ # self.p_dropout = float(raw_t2s.model.p_dropout)
374
+ self.EOS: int = int(raw_t2s.model.EOS)
375
+ self.norm_first = raw_t2s.model.norm_first
376
+ assert self.EOS == self.vocab_size - 1
377
+ self.hz = 50
378
+
379
+ self.bert_proj = raw_t2s.model.bert_proj
380
+ self.ar_text_embedding = raw_t2s.model.ar_text_embedding
381
+ self.ar_text_position = raw_t2s.model.ar_text_position
382
+ self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
383
+ self.ar_audio_position = raw_t2s.model.ar_audio_position
384
+
385
+ # self.t2s_transformer = T2STransformer(self.num_layers, blocks)
386
+ # self.t2s_transformer = raw_t2s.model.t2s_transformer
387
+
388
+ blocks = []
389
+ h = raw_t2s.model.h
390
+
391
+ for i in range(self.num_layers):
392
+ layer = h.layers[i]
393
+ t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
394
+
395
+ block = T2SBlock(
396
+ self.num_head,
397
+ self.model_dim,
398
+ t2smlp,
399
+ layer.self_attn.in_proj_weight,
400
+ layer.self_attn.in_proj_bias,
401
+ layer.self_attn.out_proj.weight,
402
+ layer.self_attn.out_proj.bias,
403
+ layer.norm1.weight,
404
+ layer.norm1.bias,
405
+ layer.norm1.eps,
406
+ layer.norm2.weight,
407
+ layer.norm2.bias,
408
+ layer.norm2.eps,
409
+ )
410
+
411
+ blocks.append(block)
412
+
413
+ self.t2s_transformer = T2STransformer(self.num_layers, blocks)
414
+
415
+ # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
416
+ self.ar_predict_layer = raw_t2s.model.ar_predict_layer
417
+ # self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
418
+ self.max_sec = raw_t2s.config["data"]["max_sec"]
419
+ self.top_k = int(raw_t2s.config["inference"]["top_k"])
420
+ self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
421
+
422
+ def forward(
423
+ self,
424
+ prompts: LongTensor,
425
+ ref_seq: LongTensor,
426
+ text_seq: LongTensor,
427
+ ref_bert: torch.Tensor,
428
+ text_bert: torch.Tensor,
429
+ top_k: LongTensor,
430
+ ):
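+ # Autoregressive semantic-token generation: embed phonemes plus BERT features,
+ # run one full pass over the prompt to build the KV cache, then sample one
+ # token per step until EOS or the early-stop budget is reached.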
431
+ bert = torch.cat([ref_bert.T, text_bert.T], 1)
432
+ all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
433
+ bert = bert.unsqueeze(0)
434
+
435
+ x = self.ar_text_embedding(all_phoneme_ids)
436
+ x = x + self.bert_proj(bert.transpose(1, 2))
437
+ x: torch.Tensor = self.ar_text_position(x)
438
+
439
+ early_stop_num = self.early_stop_num
440
+
441
+ # [1,N,512] [1,N]
442
+ # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
443
+ y = prompts
444
+ # x_example = x[:,:,0] * 0.0
445
+
446
+ x_len = x.shape[1]
447
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
448
+
449
+ y_emb = self.ar_audio_embedding(y)
450
+ y_len = y_emb.shape[1]
451
+ prefix_len = y.shape[1]
452
+ y_pos = self.ar_audio_position(y_emb)
453
+ xy_pos = torch.concat([x, y_pos], dim=1)
454
+
455
+ bsz = x.shape[0]
456
+ src_len = x_len + y_len
457
+ x_attn_mask_pad = F.pad(
458
+ x_attn_mask,
459
+ (0, y_len), ### extend the all-zero xx mask to xx zeros + xy ones, (x, x+y)
460
+ value=True,
461
+ )
462
+ y_attn_mask = F.pad( ### extend yy's upper-triangular ones with zeros on the xy (left) side, (y, x+y)
463
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
464
+ (x_len, 0),
465
+ value=False,
466
+ )
467
+ xy_attn_mask = (
468
+ torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
469
+ .unsqueeze(0)
470
+ .expand(bsz * self.num_head, -1, -1)
471
+ .view(bsz, self.num_head, src_len, src_len)
472
+ .to(device=x.device, dtype=torch.bool)
473
+ )
474
+
475
+ idx = 0
476
+ top_k = int(top_k)
477
+
478
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
479
+
480
+ logits = self.ar_predict_layer(xy_dec[:, -1])
481
+ logits = logits[:, :-1]
482
+ samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
483
+ y = torch.concat([y, samples], dim=1)
484
+ y_emb = self.ar_audio_embedding(y[:, -1:])
485
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
486
+ :, y_len + idx
487
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
488
+
489
+ stop = False
490
+ # for idx in range(1, 50):
491
+ for idx in range(1, 1500):
492
+ # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
493
+ # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
494
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
495
+ logits = self.ar_predict_layer(xy_dec[:, -1])
496
+
497
+ if idx < 11: ### require at least 10 predicted tokens before allowing EOS (~0.4 s)
498
+ logits = logits[:, :-1]
499
+
500
+ samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
501
+
502
+ y = torch.concat([y, samples], dim=1)
503
+
504
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
505
+ stop = True
506
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
507
+ stop = True
508
+ if stop:
509
+ if y.shape[1] == 0:
510
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
511
+ break
512
+
513
+ y_emb = self.ar_audio_embedding(y[:, -1:])
514
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
515
+ :, y_len + idx
516
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
517
+
518
+ y[0, -1] = 0
519
+
520
+ return y[:, -idx:].unsqueeze(0)
521
+
522
+
523
+ bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
524
+ cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
525
+ cnhubert.cnhubert_base_path = cnhubert_base_path
526
+
527
+
528
+ @torch.jit.script
529
+ def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
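+ # Repeat each character-level BERT feature word2ph[i] times so the features
+ # line up with the phoneme sequence.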
530
+ phone_level_feature = []
531
+ for i in range(word2ph.shape[0]):
532
+ repeat_feature = res[i].repeat(word2ph[i].item(), 1)
533
+ phone_level_feature.append(repeat_feature)
534
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
535
+ # [sum(word2ph), 1024]
536
+ return phone_level_feature
537
+
538
+
539
+ class MyBertModel(torch.nn.Module):
540
+ def __init__(self, bert_model):
541
+ super(MyBertModel, self).__init__()
542
+ self.bert = bert_model
543
+
544
+ def forward(
545
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
546
+ ):
547
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
548
+ # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
549
+ res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
550
+ return build_phone_level_feature(res, word2ph)
551
+
552
+
553
+ class SSLModel(torch.nn.Module):
554
+ def __init__(self):
555
+ super().__init__()
556
+ self.ssl = cnhubert.get_model().model
557
+
558
+ def forward(self, ref_audio_16k) -> torch.Tensor:
559
+ ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
560
+ return ssl_content
561
+
562
+
563
+ class ExportSSLModel(torch.nn.Module):
564
+ def __init__(self, ssl: SSLModel):
565
+ super().__init__()
566
+ self.ssl = ssl
567
+
568
+ def forward(self, ref_audio: torch.Tensor):
569
+ return self.ssl(ref_audio)
570
+
571
+ @torch.jit.export
572
+ def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
573
+ audio = resamplex(ref_audio, src_sr, dst_sr).float()
574
+ return audio
575
+
576
+
577
+ def export_bert(output_path):
578
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
579
+
580
+ text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
581
+ ref_bert_inputs = tokenizer(text, return_tensors="pt")
582
+ word2ph = []
583
+ for c in text:
584
+ if c in [",", "。", ":", "?", ",", ".", "?"]:
585
+ word2ph.append(1)
586
+ else:
587
+ word2ph.append(2)
588
+ ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
589
+
590
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
591
+ my_bert_model = MyBertModel(bert_model)
592
+
593
+ ref_bert_inputs = {
594
+ "input_ids": ref_bert_inputs["input_ids"],
595
+ "attention_mask": ref_bert_inputs["attention_mask"],
596
+ "token_type_ids": ref_bert_inputs["token_type_ids"],
597
+ "word2ph": ref_bert_inputs["word2ph"],
598
+ }
599
+
600
+ torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
601
+ torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
602
+ torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
603
+ torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
604
+
605
+ my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
606
+ output_path = os.path.join(output_path, "bert_model.pt")
607
+ my_bert_model.save(output_path)
608
+ print("#### exported bert ####")
609
+
610
+
611
+ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
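+ # Export pipeline: optionally trace the SSL/BERT front-ends, script the T2S
+ # model, wrap it together with the VITS decoder in GPT_SoVITS, and trace the
+ # combined module to gpt_sovits_model.pt.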
612
+ if not os.path.exists(output_path):
613
+ os.makedirs(output_path)
614
+ print(f"目录已创建: {output_path}")
615
+ else:
616
+ print(f"目录已存在: {output_path}")
617
+
618
+ ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
619
+ ssl = SSLModel()
620
+ if export_bert_and_ssl:
621
+ s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
622
+ ssl_path = os.path.join(output_path, "ssl_model.pt")
623
+ torch.jit.script(s).save(ssl_path)
624
+ print("#### exported ssl ####")
625
+ export_bert(output_path)
626
+ else:
627
+ s = ExportSSLModel(ssl)
628
+
629
+ print(f"device: {device}")
630
+
631
+ ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
632
+ ref_seq = torch.LongTensor([ref_seq_id]).to(device)
633
+ ref_bert = ref_bert_T.T.to(ref_seq.device)
634
+ text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
635
+ "这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
636
+ )
637
+ text_seq = torch.LongTensor([text_seq_id]).to(device)
638
+ text_bert = text_bert_T.T.to(text_seq.device)
639
+
640
+ ssl_content = ssl(ref_audio).to(device)
641
+
642
+ # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
643
+ vits = VitsModel(vits_path).to(device)
644
+ vits.eval()
645
+
646
+ # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
647
+ # dict_s1 = torch.load(gpt_path, map_location=device)
648
+ dict_s1 = torch.load(gpt_path)
649
+ raw_t2s = get_raw_t2s_model(dict_s1).to(device)
650
+ print("#### get_raw_t2s_model ####")
651
+ print(raw_t2s.config)
652
+ t2s_m = T2SModel(raw_t2s)
653
+ t2s_m.eval()
654
+ t2s = torch.jit.script(t2s_m).to(device)
655
+ print("#### script t2s_m ####")
656
+
657
+ print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
658
+ gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
659
+ gpt_sovits.eval()
660
+
661
+ ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
662
+
663
+ torch._dynamo.mark_dynamic(ssl_content, 2)
664
+ torch._dynamo.mark_dynamic(ref_audio_sr, 1)
665
+ torch._dynamo.mark_dynamic(ref_seq, 1)
666
+ torch._dynamo.mark_dynamic(text_seq, 1)
667
+ torch._dynamo.mark_dynamic(ref_bert, 0)
668
+ torch._dynamo.mark_dynamic(text_bert, 0)
669
+
670
+ top_k = torch.LongTensor([5]).to(device)
671
+
672
+ with torch.no_grad():
673
+ gpt_sovits_export = torch.jit.trace(
674
+ gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
675
+ )
676
+
677
+ gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
678
+ gpt_sovits_export.save(gpt_sovits_path)
679
+ print("#### exported gpt_sovits ####")
680
+
681
+
682
+ @torch.jit.script
683
+ def parse_audio(ref_audio):
684
+ ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
685
+ ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
686
+ return ref_audio_16k, ref_audio_sr
687
+
688
+
689
+ @torch.jit.script
690
+ def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
691
+ return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
692
+
693
+
694
+ class GPT_SoVITS(nn.Module):
695
+ def __init__(self, t2s: T2SModel, vits: VitsModel):
696
+ super().__init__()
697
+ self.t2s = t2s
698
+ self.vits = vits
699
+
700
+ def forward(
701
+ self,
702
+ ssl_content: torch.Tensor,
703
+ ref_audio_sr: torch.Tensor,
704
+ ref_seq: Tensor,
705
+ text_seq: Tensor,
706
+ ref_bert: Tensor,
707
+ text_bert: Tensor,
708
+ top_k: LongTensor,
709
+ speed=1.0,
710
+ ):
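+ # Full pipeline: quantize the reference SSL features into semantic tokens,
+ # predict the target semantic tokens with the T2S model, then vocode them
+ # with the VITS decoder conditioned on the reference spectrogram.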
711
+ codes = self.vits.vq_model.extract_latent(ssl_content)
712
+ prompt_semantic = codes[0, 0]
713
+ prompts = prompt_semantic.unsqueeze(0)
714
+
715
+ pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
716
+ audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
717
+ return audio
718
+
719
+
720
+ def test():
721
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
722
+ parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
723
+ parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
724
+ parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
725
+ parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
726
+ parser.add_argument("--output_path", required=True, help="Path to the output directory")
727
+
728
+ args = parser.parse_args()
729
+ gpt_path = args.gpt_model
730
+ vits_path = args.sovits_model
731
+ ref_audio_path = args.ref_audio
732
+ ref_text = args.ref_text
733
+
734
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
735
+ # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
736
+ # bert = MyBertModel(bert_model)
737
+ my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
738
+
739
+ # dict_s1 = torch.load(gpt_path, map_location="cuda")
740
+ # raw_t2s = get_raw_t2s_model(dict_s1)
741
+ # t2s = T2SModel(raw_t2s)
742
+ # t2s.eval()
743
+ # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
744
+
745
+ # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
746
+ # vits = VitsModel(vits_path)
747
+ # vits.eval()
748
+
749
+ # ssl = ExportSSLModel(SSLModel()).to('cuda')
750
+ # ssl.eval()
751
+ ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
752
+
753
+ # gpt_sovits = GPT_SoVITS(t2s,vits)
754
+ gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
755
+
756
+ ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
757
+ ref_seq = torch.LongTensor([ref_seq_id])
758
+ ref_bert = ref_bert_T.T.to(ref_seq.device)
759
+ # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
760
+ text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
761
+
762
+ text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
763
+
764
+ test_bert = tokenizer(text, return_tensors="pt")
765
+ word2ph = []
766
+ for c in text:
767
+ if c in [",", "。", ":", "?", "?", ",", "."]:
768
+ word2ph.append(1)
769
+ else:
770
+ word2ph.append(2)
771
+ test_bert["word2ph"] = torch.Tensor(word2ph).int()
772
+
773
+ test_bert = my_bert(
774
+ test_bert["input_ids"].to("cuda"),
775
+ test_bert["attention_mask"].to("cuda"),
776
+ test_bert["token_type_ids"].to("cuda"),
777
+ test_bert["word2ph"].to("cuda"),
778
+ )
779
+
780
+ text_seq = torch.LongTensor([text_seq_id])
781
+ text_bert = text_bert_T.T.to(text_seq.device)
782
+
783
+ print("text_bert:", text_bert.shape, text_bert)
784
+ print("test_bert:", test_bert.shape, test_bert)
785
+ print(torch.allclose(text_bert.to("cuda"), test_bert))
786
+
787
+ print("text_seq:", text_seq.shape)
788
+ print("text_bert:", text_bert.shape, text_bert.type())
789
+
790
+ # [1,N]
791
+ ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
792
+ print("ref_audio:", ref_audio.shape)
793
+
794
+ ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
795
+ print("start ssl")
796
+ ssl_content = ssl(ref_audio)
797
+
798
+ print("start gpt_sovits:")
799
+ print("ssl_content:", ssl_content.shape)
800
+ print("ref_audio_sr:", ref_audio_sr.shape)
801
+ print("ref_seq:", ref_seq.shape)
802
+ ref_seq = ref_seq.to("cuda")
803
+ print("text_seq:", text_seq.shape)
804
+ text_seq = text_seq.to("cuda")
805
+ print("ref_bert:", ref_bert.shape)
806
+ ref_bert = ref_bert.to("cuda")
807
+ print("text_bert:", text_bert.shape)
808
+ text_bert = text_bert.to("cuda")
809
+
810
+ top_k = torch.LongTensor([5]).to("cuda")
811
+
812
+ with torch.no_grad():
813
+ audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
814
+ print("start write wav")
815
+ soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
816
+
817
+
818
+ import text
819
+ import json
820
+
821
+
822
+ def export_symbel(version="v2"):
823
+ if version == "v1":
824
+ symbols = text._symbol_to_id_v1
825
+ with open("onnx/symbols_v1.json", "w") as file:
826
+ json.dump(symbols, file, indent=4)
827
+ else:
828
+ symbols = text._symbol_to_id_v2
829
+ with open("onnx/symbols_v2.json", "w") as file:
830
+ json.dump(symbols, file, indent=4)
831
+
832
+
833
+ def main():
834
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
835
+ parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
836
+ parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
837
+ parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
838
+ parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
839
+ parser.add_argument("--output_path", required=True, help="Path to the output directory")
840
+ parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
841
+ parser.add_argument("--device", help="Device to use")
842
+
843
+ args = parser.parse_args()
844
+ export(
845
+ gpt_path=args.gpt_model,
846
+ vits_path=args.sovits_model,
847
+ ref_audio_path=args.ref_audio,
848
+ ref_text=args.ref_text,
849
+ output_path=args.output_path,
850
+ device=args.device,
851
+ export_bert_and_ssl=args.export_common_model,
852
+ )
853
+
854
+
855
+ import inference_webui
856
+
857
+ if __name__ == "__main__":
858
+ inference_webui.is_half = False
859
+ inference_webui.dtype = torch.float32
860
+ main()
861
+ # test()
GPT_SoVITS/export_torch_script_v3.py ADDED
@@ -0,0 +1,1035 @@
1
+ import os
2
+ from export_torch_script import (
3
+ T2SModel,
4
+ get_raw_t2s_model,
5
+ resamplex,
6
+ spectrogram_torch,
7
+ )
8
+ from f5_tts.model.backbones.dit import DiT
9
+ from inference_webui import get_phones_and_bert
10
+ import librosa
11
+ from module import commons
12
+ from module.mel_processing import mel_spectrogram_torch
13
+ from module.models_onnx import CFM, SynthesizerTrnV3
14
+ import numpy as np
15
+ import torch._dynamo.config
16
+ import torchaudio
17
+ import logging
18
+ import uvicorn
19
+ import torch
20
+ import soundfile
21
+ from librosa.filters import mel as librosa_mel_fn
22
+
23
+
24
+ from inference_webui import get_spepc, norm_spec, resample, ssl_model
25
+
26
+ logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
27
+ logger = logging.getLogger("uvicorn")
28
+
29
+ is_half = True
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ now_dir = os.getcwd()
32
+
33
+
34
+ class MelSpectrgram(torch.nn.Module):
35
+ def __init__(
36
+ self,
37
+ dtype,
38
+ device,
39
+ n_fft,
40
+ num_mels,
41
+ sampling_rate,
42
+ hop_size,
43
+ win_size,
44
+ fmin,
45
+ fmax,
46
+ center=False,
47
+ ):
48
+ super().__init__()
49
+ self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
50
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
51
+ self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
52
+ self.n_fft: int = n_fft
53
+ self.hop_size: int = hop_size
54
+ self.win_size: int = win_size
55
+ self.center: bool = center
56
+
57
+ def forward(self, y):
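+ # Reflect-pad, STFT with the cached Hann window, take the magnitude, project
+ # onto the mel filterbank, then log-compress with a 1e-5 floor.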
58
+ y = torch.nn.functional.pad(
59
+ y.unsqueeze(1),
60
+ (
61
+ int((self.n_fft - self.hop_size) / 2),
62
+ int((self.n_fft - self.hop_size) / 2),
63
+ ),
64
+ mode="reflect",
65
+ )
66
+ y = y.squeeze(1)
67
+ spec = torch.stft(
68
+ y,
69
+ self.n_fft,
70
+ hop_length=self.hop_size,
71
+ win_length=self.win_size,
72
+ window=self.hann_window,
73
+ center=self.center,
74
+ pad_mode="reflect",
75
+ normalized=False,
76
+ onesided=True,
77
+ return_complex=False,
78
+ )
79
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
80
+ spec = torch.matmul(self.mel_basis, spec)
81
+ # spec = spectral_normalize_torch(spec)
82
+ spec = torch.log(torch.clamp(spec, min=1e-5))
83
+ return spec
84
+
85
+
86
+ class ExportDitBlocks(torch.nn.Module):
87
+ def __init__(self, dit: DiT):
88
+ super().__init__()
89
+ self.transformer_blocks = dit.transformer_blocks
90
+ self.norm_out = dit.norm_out
91
+ self.proj_out = dit.proj_out
92
+ self.depth = dit.depth
93
+
94
+ def forward(self, x, t, mask, rope):
95
+ for block in self.transformer_blocks:
96
+ x = block(x, t, mask=mask, rope=(rope, 1.0))
97
+ x = self.norm_out(x, t)
98
+ output = self.proj_out(x)
99
+ return output
100
+
101
+
102
+ class ExportDitEmbed(torch.nn.Module):
103
+ def __init__(self, dit: DiT):
104
+ super().__init__()
105
+ self.time_embed = dit.time_embed
106
+ self.d_embed = dit.d_embed
107
+ self.text_embed = dit.text_embed
108
+ self.input_embed = dit.input_embed
109
+ self.rotary_embed = dit.rotary_embed
110
+ self.rotary_embed.inv_freq.to(device)
111
+
112
+ def forward(
113
+ self,
114
+ x0: torch.Tensor, # noised input audio # noqa: F722
115
+ cond0: torch.Tensor, # masked cond audio # noqa: F722
116
+ x_lens: torch.Tensor,
117
+ time: torch.Tensor, # time step # noqa: F821 F722
118
+ dt_base_bootstrap: torch.Tensor,
119
+ text0: torch.Tensor, # noqa: F722 ##### condition feature
120
+ ):
121
+ x = x0.transpose(2, 1)
122
+ cond = cond0.transpose(2, 1)
123
+ text = text0.transpose(2, 1)
124
+ mask = commons.sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
125
+
126
+ t = self.time_embed(time) + self.d_embed(dt_base_bootstrap)
127
+ text_embed = self.text_embed(text, x.shape[1])
128
+ rope_t = torch.arange(x.shape[1], device=device)
129
+ rope, _ = self.rotary_embed(rope_t)
130
+ x = self.input_embed(x, cond, text_embed)
131
+ return x, t, mask, rope
132
+
133
+
134
+ class ExportDiT(torch.nn.Module):
135
+ def __init__(self, dit: DiT):
136
+ super().__init__()
137
+ if dit != None:
138
+ self.embed = ExportDitEmbed(dit)
139
+ self.blocks = ExportDitBlocks(dit)
140
+ else:
141
+ self.embed = None
142
+ self.blocks = None
143
+
144
+ def forward( # x, prompt_x, x_lens, t, style,cond
145
+ self, # d is channel,n is T
146
+ x0: torch.Tensor, # noised input audio # noqa: F722
147
+ cond0: torch.Tensor, # masked cond audio # noqa: F722
148
+ x_lens: torch.Tensor,
149
+ time: torch.Tensor, # time step # noqa: F821 F722
150
+ dt_base_bootstrap: torch.Tensor,
151
+ text0: torch.Tensor, # noqa: F722 ##### condition feature
152
+ ):
153
+ x, t, mask, rope = self.embed(x0, cond0, x_lens, time, dt_base_bootstrap, text0)
154
+ output = self.blocks(x, t, mask, rope)
155
+ return output
156
+
157
+
158
+ class ExportCFM(torch.nn.Module):
159
+ def __init__(self, cfm: CFM):
160
+ super().__init__()
161
+ self.cfm = cfm
162
+
163
+ def forward(
164
+ self,
165
+ fea_ref: torch.Tensor,
166
+ fea_todo_chunk: torch.Tensor,
167
+ mel2: torch.Tensor,
168
+ sample_steps: torch.LongTensor,
169
+ ):
170
+ T_min = fea_ref.size(2)
171
+ fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
172
+ cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
173
+ cfm_res = cfm_res[:, :, mel2.shape[2] :]
174
+ mel2 = cfm_res[:, :, -T_min:]
175
+ fea_ref = fea_todo_chunk[:, :, -T_min:]
176
+ return cfm_res, fea_ref, mel2
177
+
178
+
179
+ mel_fn = lambda x: mel_spectrogram_torch(
180
+ x,
181
+ **{
182
+ "n_fft": 1024,
183
+ "win_size": 1024,
184
+ "hop_size": 256,
185
+ "num_mels": 100,
186
+ "sampling_rate": 24000,
187
+ "fmin": 0,
188
+ "fmax": None,
189
+ "center": False,
190
+ },
191
+ )
192
+
193
+ spec_min = -12
194
+ spec_max = 2
195
+
196
+
197
+ @torch.jit.script
198
+ def norm_spec(x):
199
+ spec_min = -12
200
+ spec_max = 2
201
+ return (x - spec_min) / (spec_max - spec_min) * 2 - 1
202
+
203
+
204
+ def denorm_spec(x):
205
+ spec_min = -12
206
+ spec_max = 2
207
+ return (x + 1) / 2 * (spec_max - spec_min) + spec_min
208
+
209
+
210
+ class ExportGPTSovitsHalf(torch.nn.Module):
211
+ def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
212
+ super().__init__()
213
+ self.hps = hps
214
+ self.t2s_m = t2s_m
215
+ self.vq_model = vq_model
216
+ self.mel2 = MelSpectrgram(
217
+ dtype=torch.float32,
218
+ device=device,
219
+ n_fft=1024,
220
+ num_mels=100,
221
+ sampling_rate=24000,
222
+ hop_size=256,
223
+ win_size=1024,
224
+ fmin=0,
225
+ fmax=None,
226
+ center=False,
227
+ )
228
+ # self.dtype = dtype
229
+ self.filter_length: int = hps.data.filter_length
230
+ self.sampling_rate: int = hps.data.sampling_rate
231
+ self.hop_length: int = hps.data.hop_length
232
+ self.win_length: int = hps.data.win_length
233
+
234
+ def forward(
235
+ self,
236
+ ssl_content,
237
+ ref_audio_32k: torch.FloatTensor,
238
+ phoneme_ids0,
239
+ phoneme_ids1,
240
+ bert1,
241
+ bert2,
242
+ top_k,
243
+ ):
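+ # First half of the v3 pipeline: reference spectrogram + prompt semantics ->
+ # T2S prediction, speaker embedding ge, prompt features (fea_ref), target
+ # features (fea_todo), and the normalized 24 kHz reference mel (capped at 468 frames).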
244
+ refer = spectrogram_torch(
245
+ ref_audio_32k,
246
+ self.filter_length,
247
+ self.sampling_rate,
248
+ self.hop_length,
249
+ self.win_length,
250
+ center=False,
251
+ ).to(ssl_content.dtype)
252
+
253
+ codes = self.vq_model.extract_latent(ssl_content)
254
+ prompt_semantic = codes[0, 0]
255
+ prompt = prompt_semantic.unsqueeze(0)
256
+ # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
257
+
258
+ pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
259
+ # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
260
+
261
+ ge = self.vq_model.create_ge(refer)
262
+ # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
263
+
264
+ prompt_ = prompt.unsqueeze(0)
265
+ fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
266
+ # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
267
+ # print(prompt_.shape, phoneme_ids0.shape, ge.shape)
268
+ # print(fea_ref.shape)
269
+
270
+ ref_24k = resamplex(ref_audio_32k, 32000, 24000)
271
+ mel2 = norm_spec(self.mel2(ref_24k)).to(ssl_content.dtype)
272
+ T_min = min(mel2.shape[2], fea_ref.shape[2])
273
+ mel2 = mel2[:, :, :T_min]
274
+ fea_ref = fea_ref[:, :, :T_min]
275
+ if T_min > 468:
276
+ mel2 = mel2[:, :, -468:]
277
+ fea_ref = fea_ref[:, :, -468:]
278
+ T_min = 468
279
+
280
+ fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
281
+ # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
282
+ # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
283
+ # print(fea_todo.shape)
284
+
285
+ return fea_ref, fea_todo, mel2
286
+
287
+
288
+ class GPTSoVITSV3(torch.nn.Module):
289
+ def __init__(self, gpt_sovits_half, cfm, bigvgan):
290
+ super().__init__()
291
+ self.gpt_sovits_half = gpt_sovits_half
292
+ self.cfm = cfm
293
+ self.bigvgan = bigvgan
294
+
295
+ def forward(
296
+ self,
297
+ ssl_content,
298
+ ref_audio_32k: torch.FloatTensor,
299
+ phoneme_ids0: torch.LongTensor,
300
+ phoneme_ids1: torch.LongTensor,
301
+ bert1,
302
+ bert2,
303
+ top_k: torch.LongTensor,
304
+ sample_steps: torch.LongTensor,
305
+ ):
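+ # Second half of the v3 pipeline: split fea_todo into chunks so that
+ # reference + chunk stays at 934 frames, run CFM flow matching per chunk,
+ # vocode each chunk with BigVGAN, then concatenate and trim the waveform
+ # to fea_todo.shape[2] * 256 samples.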
306
+ # current_time = datetime.now()
307
+ # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
308
+ fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
309
+ ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
310
+ )
311
+ chunk_len = 934 - fea_ref.shape[2]
312
+ wav_gen_list = []
313
+ idx = 0
314
+ wav_gen_length = fea_todo.shape[2] * 256
315
+ while 1:
316
+ # current_time = datetime.now()
317
+ # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
318
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
319
+ if fea_todo_chunk.shape[-1] == 0:
320
+ break
321
+
322
+ # The exported model seems to recompile when the input shape changes, which stalls for ~10s,
323
+ # so pad with zeros here to keep the chunk shape constant.
324
+ # This makes the generated audio too long, so trim it at the end;
325
+ # after bigvgan the audio length is fea_todo.shape[2] * 256.
326
+ complete_len = chunk_len - fea_todo_chunk.shape[-1]
327
+ if complete_len != 0:
328
+ fea_todo_chunk = torch.cat(
329
+ [
330
+ fea_todo_chunk,
331
+ torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
332
+ ],
333
+ 2,
334
+ )
335
+
336
+ cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
337
+ idx += chunk_len
338
+
339
+ cfm_res = denorm_spec(cfm_res)
340
+ bigvgan_res = self.bigvgan(cfm_res)
341
+ wav_gen_list.append(bigvgan_res)
342
+
343
+ wav_gen = torch.cat(wav_gen_list, 2)
344
+ return wav_gen[0][0][:wav_gen_length]
345
+
346
+
347
+ def init_bigvgan():
348
+ global bigvgan_model
349
+ from BigVGAN import bigvgan
350
+
351
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained(
352
+ "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
353
+ use_cuda_kernel=False,
354
+ ) # if True, RuntimeError: Ninja is required to load C++ extensions
355
+ # remove weight norm in the model and set to eval mode
356
+ bigvgan_model.remove_weight_norm()
357
+ bigvgan_model = bigvgan_model.eval()
358
+ if is_half == True:
359
+ bigvgan_model = bigvgan_model.half().to(device)
360
+ else:
361
+ bigvgan_model = bigvgan_model.to(device)
362
+
363
+
364
+ class Sovits:
365
+ def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
366
+ self.vq_model = vq_model
367
+ self.hps = hps
368
+ cfm.estimator = ExportDiT(cfm.estimator)
369
+ self.cfm = cfm
370
+
371
+
372
+ class DictToAttrRecursive(dict):
373
+ def __init__(self, input_dict):
374
+ super().__init__(input_dict)
375
+ for key, value in input_dict.items():
376
+ if isinstance(value, dict):
377
+ value = DictToAttrRecursive(value)
378
+ self[key] = value
379
+ setattr(self, key, value)
380
+
381
+ def __getattr__(self, item):
382
+ try:
383
+ return self[item]
384
+ except KeyError:
385
+ raise AttributeError(f"Attribute {item} not found")
386
+
387
+ def __setattr__(self, key, value):
388
+ if isinstance(value, dict):
389
+ value = DictToAttrRecursive(value)
390
+ super(DictToAttrRecursive, self).__setitem__(key, value)
391
+ super().__setattr__(key, value)
392
+
393
+ def __delattr__(self, item):
394
+ try:
395
+ del self[item]
396
+ except KeyError:
397
+ raise AttributeError(f"Attribute {item} not found")
398
+
399
+
400
+ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
401
+
402
+
403
+ def get_sovits_weights(sovits_path):
404
+ path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
405
+ is_exist_s2gv3 = os.path.exists(path_sovits_v3)
406
+
407
+ version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
408
+ if if_lora_v3 == True and is_exist_s2gv3 == False:
409
+ logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
410
+
411
+ dict_s2 = load_sovits_new(sovits_path)
412
+ hps = dict_s2["config"]
413
+ hps = DictToAttrRecursive(hps)
414
+ hps.model.semantic_frame_rate = "25hz"
415
+ if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
416
+ hps.model.version = "v2" # v3model,v2sybomls
417
+ elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
418
+ hps.model.version = "v1"
419
+ else:
420
+ hps.model.version = "v2"
421
+
422
+ if model_version == "v3":
423
+ hps.model.version = "v3"
424
+
425
+ logger.info(f"hps: {hps}")
426
+
427
+ vq_model = SynthesizerTrnV3(
428
+ hps.data.filter_length // 2 + 1,
429
+ hps.train.segment_size // hps.data.hop_length,
430
+ n_speakers=hps.data.n_speakers,
431
+ **hps.model,
432
+ )
433
+ # init_bigvgan()
434
+ model_version = hps.model.version
435
+ logger.info(f"模型版本: {model_version}")
436
+
437
+ if is_half == True:
438
+ vq_model = vq_model.half().to(device)
439
+ else:
440
+ vq_model = vq_model.to(device)
441
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
442
+ vq_model.eval()
443
+
444
+ cfm = vq_model.cfm
445
+ del vq_model.cfm
446
+
447
+ sovits = Sovits(vq_model, cfm, hps)
448
+ return sovits
449
+
450
+
451
+ logger.info(f"torch version {torch.__version__}")
452
+ # ssl_model = cnhubert.get_model()
453
+ # if is_half:
454
+ # ssl_model = ssl_model.half().to(device)
455
+ # else:
456
+ # ssl_model = ssl_model.to(device)
457
+
458
+
459
+ def export_cfm(
460
+ e_cfm: ExportCFM,
461
+ mu: torch.Tensor,
462
+ x_lens: torch.LongTensor,
463
+ prompt: torch.Tensor,
464
+ n_timesteps: torch.IntTensor,
465
+ temperature=1.0,
466
+ ):
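+ # Build example noise/prompt tensors for a single Euler step, trace the DiT
+ # estimator to onnx/ad/estimator.pt, then script the ExportCFM wrapper and
+ # save it to onnx/ad/cfm.pt.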
467
+ cfm = e_cfm.cfm
468
+
469
+ B, T = mu.size(0), mu.size(1)
470
+ x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
471
+ print("x:", x.shape, x.dtype)
472
+ prompt_len = prompt.size(-1)
473
+ prompt_x = torch.zeros_like(x, dtype=mu.dtype)
474
+ prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
475
+ x[..., :prompt_len] = 0.0
476
+ mu = mu.transpose(2, 1)
477
+
478
+ ntimestep = int(n_timesteps)
479
+
480
+ t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
481
+ d = torch.tensor(1.0 / ntimestep, dtype=x.dtype, device=x.device)
482
+
483
+ t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
484
+ d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
485
+
486
+ print(
487
+ "cfm input shapes:",
488
+ x.shape,
489
+ prompt_x.shape,
490
+ x_lens.shape,
491
+ t_tensor.shape,
492
+ d_tensor.shape,
493
+ mu.shape,
494
+ )
495
+
496
+ print("cfm input dtypes:", x.dtype, prompt_x.dtype, x_lens.dtype, t_tensor.dtype, d_tensor.dtype, mu.dtype)
497
+
498
+ estimator: ExportDiT = torch.jit.trace(
499
+ cfm.estimator,
500
+ optimize=True,
501
+ example_inputs=(x, prompt_x, x_lens, t_tensor, d_tensor, mu),
502
+ )
503
+ estimator.save("onnx/ad/estimator.pt")
504
+ # torch.onnx.export(
505
+ # cfm.estimator,
506
+ # (x, prompt_x, x_lens, t_tensor, d_tensor, mu),
507
+ # "onnx/ad/dit.onnx",
508
+ # input_names=["x", "prompt_x", "x_lens", "t", "d", "mu"],
509
+ # output_names=["output"],
510
+ # dynamic_axes={
511
+ # "x": [2],
512
+ # "prompt_x": [2],
513
+ # "mu": [2],
514
+ # },
515
+ # )
516
+ print("save estimator ok")
517
+ cfm.estimator = estimator
518
+ export_cfm = torch.jit.script(e_cfm)
519
+ export_cfm.save("onnx/ad/cfm.pt")
520
+ # sovits.cfm = cfm
521
+ # cfm.save("onnx/ad/cfm.pt")
522
+ return export_cfm
523
+
524
+
525
+ def export():
526
+ sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
527
+
528
+ init_bigvgan()
529
+
530
+ dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
531
+ raw_t2s = get_raw_t2s_model(dict_s1).to(device)
532
+ print("#### get_raw_t2s_model ####")
533
+ print(raw_t2s.config)
534
+
535
+ if is_half:
536
+ raw_t2s = raw_t2s.half().to(device)
537
+
538
+ t2s_m = T2SModel(raw_t2s)
539
+ t2s_m.eval()
540
+ script_t2s = torch.jit.script(t2s_m).to(device)
541
+
542
+ hps = sovits.hps
543
+ ref_wav_path = "onnx/ad/ref.wav"
544
+ speed = 1.0
545
+ sample_steps = 32
546
+ dtype = torch.float16 if is_half == True else torch.float32
547
+ refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
548
+ zero_wav = np.zeros(
549
+ int(hps.data.sampling_rate * 0.3),
550
+ dtype=np.float16 if is_half == True else np.float32,
551
+ )
552
+
553
+ with torch.no_grad():
554
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
555
+ wav16k = torch.from_numpy(wav16k)
556
+ zero_wav_torch = torch.from_numpy(zero_wav)
557
+
558
+ if is_half == True:
559
+ wav16k = wav16k.half().to(device)
560
+ zero_wav_torch = zero_wav_torch.half().to(device)
561
+ else:
562
+ wav16k = wav16k.to(device)
563
+ zero_wav_torch = zero_wav_torch.to(device)
564
+ wav16k = torch.cat([wav16k, zero_wav_torch])
565
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
566
+ codes = sovits.vq_model.extract_latent(ssl_content)
567
+ prompt_semantic = codes[0, 0]
568
+ prompt = prompt_semantic.unsqueeze(0).to(device)
569
+
570
+ phones1, bert1, norm_text1 = get_phones_and_bert(
571
+ "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
572
+ )
573
+ phones2, bert2, norm_text2 = get_phones_and_bert(
574
+ "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
575
+ "auto",
576
+ "v3",
577
+ )
578
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
579
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
580
+
581
+ # codes = sovits.vq_model.extract_latent(ssl_content)
582
+ # prompt_semantic = codes[0, 0]
583
+ # prompts = prompt_semantic.unsqueeze(0)
584
+
585
+ top_k = torch.LongTensor([15]).to(device)
586
+ print("topk", top_k)
587
+
588
+ bert1 = bert1.T.to(device)
589
+ bert2 = bert2.T.to(device)
590
+ print(
591
+ prompt.dtype,
592
+ phoneme_ids0.dtype,
593
+ phoneme_ids1.dtype,
594
+ bert1.dtype,
595
+ bert2.dtype,
596
+ top_k.dtype,
597
+ )
598
+ print(
599
+ prompt.shape,
600
+ phoneme_ids0.shape,
601
+ phoneme_ids1.shape,
602
+ bert1.shape,
603
+ bert2.shape,
604
+ top_k.shape,
605
+ )
606
+ pred_semantic = t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
607
+
608
+ ge = sovits.vq_model.create_ge(refer)
609
+ prompt_ = prompt.unsqueeze(0)
610
+
611
+ torch._dynamo.mark_dynamic(prompt_, 2)
612
+ torch._dynamo.mark_dynamic(phoneme_ids0, 1)
613
+
614
+ fea_ref = sovits.vq_model(prompt_, phoneme_ids0, ge)
615
+
616
+ inputs = {
617
+ "forward": (prompt_, phoneme_ids0, ge),
618
+ "extract_latent": ssl_content,
619
+ "create_ge": refer,
620
+ }
621
+
622
+ trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
623
+ trace_vq_model.save("onnx/ad/vq_model.pt")
624
+
625
+ print(fea_ref.shape, fea_ref.dtype, ge.shape)
626
+ print(prompt_.shape, phoneme_ids0.shape, ge.shape)
627
+
628
+ # vq_model = torch.jit.trace(
629
+ # sovits.vq_model,
630
+ # optimize=True,
631
+ # # strict=False,
632
+ # example_inputs=(prompt_, phoneme_ids0, ge),
633
+ # )
634
+ # vq_model = sovits.vq_model
635
+ vq_model = trace_vq_model
636
+
637
+ gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
638
+ torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
639
+
640
+ ref_audio, sr = torchaudio.load(ref_wav_path)
641
+ ref_audio = ref_audio.to(device).float()
642
+ if ref_audio.shape[0] == 2:
643
+ ref_audio = ref_audio.mean(0).unsqueeze(0)
644
+ if sr != 24000:
645
+ ref_audio = resample(ref_audio, sr)
646
+ # mel2 = mel_fn(ref_audio)
647
+ mel2 = norm_spec(mel_fn(ref_audio))
648
+ T_min = min(mel2.shape[2], fea_ref.shape[2])
649
+ fea_ref = fea_ref[:, :, :T_min]
650
+ print("fea_ref:", fea_ref.shape, T_min)
651
+ if T_min > 468:
652
+ mel2 = mel2[:, :, -468:]
653
+ fea_ref = fea_ref[:, :, -468:]
654
+ T_min = 468
655
+ chunk_len = 934 - T_min
656
+ mel2 = mel2.to(dtype)
657
+
658
+ # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
659
+ fea_todo = vq_model(pred_semantic, phoneme_ids1, ge)
660
+
661
+ cfm_resss = []
662
+ idx = 0
663
+ sample_steps = torch.LongTensor([sample_steps]).to(device)
664
+ export_cfm_ = ExportCFM(sovits.cfm)
665
+ while 1:
666
+ print("idx:", idx)
667
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
668
+ if fea_todo_chunk.shape[-1] == 0:
669
+ break
670
+
671
+ print(
672
+ "export_cfm:",
673
+ fea_ref.shape,
674
+ fea_todo_chunk.shape,
675
+ mel2.shape,
676
+ sample_steps.shape,
677
+ )
678
+ if idx == 0:
679
+ fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
680
+ export_cfm_ = export_cfm(
681
+ export_cfm_,
682
+ fea,
683
+ torch.LongTensor([fea.size(1)]).to(fea.device),
684
+ mel2,
685
+ sample_steps,
686
+ )
687
+ # torch.onnx.export(
688
+ # export_cfm_,
689
+ # (
690
+ # fea_ref,
691
+ # fea_todo_chunk,
692
+ # mel2,
693
+ # sample_steps,
694
+ # ),
695
+ # "onnx/ad/cfm.onnx",
696
+ # input_names=["fea_ref", "fea_todo_chunk", "mel2", "sample_steps"],
697
+ # output_names=["cfm_res", "fea_ref_", "mel2_"],
698
+ # dynamic_axes={
699
+ # "fea_ref": [2],
700
+ # "fea_todo_chunk": [2],
701
+ # "mel2": [2],
702
+ # },
703
+ # )
704
+
705
+ idx += chunk_len
706
+
707
+ cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
708
+ cfm_resss.append(cfm_res)
709
+ continue
710
+
711
+ cmf_res = torch.cat(cfm_resss, 2)
712
+ cmf_res = denorm_spec(cmf_res).to(device)
713
+ print("cmf_res:", cmf_res.shape, cmf_res.dtype)
714
+ with torch.inference_mode():
715
+ cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
716
+ torch._dynamo.mark_dynamic(cmf_res_rand, 2)
717
+ bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
718
+ bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
719
+ wav_gen = bigvgan_model(cmf_res)
720
+ print("wav_gen:", wav_gen.shape, wav_gen.dtype)
721
+ audio = wav_gen[0][0].cpu().detach().numpy()
722
+
723
+ sr = 24000
724
+ soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
725
+
726
+
727
+ from datetime import datetime
728
+
729
+
730
+ def test_export(
731
+ todo_text,
732
+ gpt_sovits_v3_half,
733
+ cfm,
734
+ bigvgan,
735
+ output,
736
+ ):
737
+ # hps = sovits.hps
738
+ ref_wav_path = "onnx/ad/ref.wav"
739
+ speed = 1.0
740
+ sample_steps = 8
741
+
742
+ dtype = torch.float16 if is_half == True else torch.float32
743
+
744
+ zero_wav = np.zeros(
745
+ int(16000 * 0.3),
746
+ dtype=np.float16 if is_half == True else np.float32,
747
+ )
748
+
749
+ with torch.no_grad():
750
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
751
+ wav16k = torch.from_numpy(wav16k)
752
+ zero_wav_torch = torch.from_numpy(zero_wav)
753
+
754
+ if is_half == True:
755
+ wav16k = wav16k.half().to(device)
756
+ zero_wav_torch = zero_wav_torch.half().to(device)
757
+ else:
758
+ wav16k = wav16k.to(device)
759
+ zero_wav_torch = zero_wav_torch.to(device)
760
+ wav16k = torch.cat([wav16k, zero_wav_torch])
761
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
762
+
763
+ ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
764
+ ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
765
+
766
+ phones1, bert1, norm_text1 = get_phones_and_bert(
767
+ "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
768
+ )
769
+ phones2, bert2, norm_text2 = get_phones_and_bert(
770
+ todo_text,
771
+ "zh",
772
+ "v3",
773
+ )
774
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
775
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
776
+
777
+ bert1 = bert1.T.to(device)
778
+ bert2 = bert2.T.to(device)
779
+ top_k = torch.LongTensor([15]).to(device)
780
+
781
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
782
+ logger.info("start inference %s", current_time)
783
+ print(
784
+ ssl_content.shape,
785
+ ref_audio_32k.shape,
786
+ phoneme_ids0.shape,
787
+ phoneme_ids1.shape,
788
+ bert1.shape,
789
+ bert2.shape,
790
+ top_k.shape,
791
+ )
792
+ fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
793
+ ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
794
+ )
795
+ chunk_len = 934 - fea_ref.shape[2]
796
+ print(fea_ref.shape, fea_todo.shape, mel2.shape)
797
+
798
+ cfm_resss = []
799
+ sample_steps = torch.LongTensor([sample_steps])
800
+ idx = 0
801
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
802
+ logger.info("start cfm %s", current_time)
803
+ wav_gen_length = fea_todo.shape[2] * 256
804
+
805
+ while 1:
806
+ current_time = datetime.now()
807
+ print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
808
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
809
+ if fea_todo_chunk.shape[-1] == 0:
810
+ break
811
+
812
+ complete_len = chunk_len - fea_todo_chunk.shape[-1]
813
+ if complete_len != 0:
814
+ fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(device).to(dtype)], 2)
815
+
816
+ cfm_res, fea_ref, mel2 = cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
817
+ # if complete_len > 0 :
818
+ # cfm_res = cfm_res[:, :, :-complete_len]
819
+ # fea_ref = fea_ref[:, :, :-complete_len]
820
+ # mel2 = mel2[:, :, :-complete_len]
821
+
822
+ idx += chunk_len
823
+
824
+ current_time = datetime.now()
825
+ print("cfm end", current_time.strftime("%Y-%m-%d %H:%M:%S"))
826
+ cfm_res = denorm_spec(cfm_res).to(device)
827
+ bigvgan_res = bigvgan(cfm_res)
828
+ cfm_resss.append(bigvgan_res)
829
+
830
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
831
+ logger.info("start bigvgan %s", current_time)
832
+ wav_gen = torch.cat(cfm_resss, 2)
833
+ # cmf_res = denorm_spec(cmf_res)
834
+ # cmf_res = cmf_res.to(device)
835
+ # print("cmf_res:", cmf_res.shape)
836
+
837
+ # cmf_res = torch.cat([cmf_res,torch.zeros([1,100,2000-cmf_res.size(2)],device=device,dtype=cmf_res.dtype)], 2)
838
+
839
+ # wav_gen = bigvgan(cmf_res)
840
+ print("wav_gen:", wav_gen.shape, wav_gen.dtype)
841
+ wav_gen = wav_gen[:, :, :wav_gen_length]
842
+
843
+ audio = wav_gen[0][0].cpu().detach().numpy()
844
+ logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
845
+ sr = 24000
846
+ soundfile.write(output, (audio * 32768).astype(np.int16), sr)
847
+
848
+
849
+ def test_export1(
850
+ todo_text,
851
+ gpt_sovits_v3,
852
+ output,
853
+ ):
854
+ # hps = sovits.hps
855
+ ref_wav_path = "onnx/ad/ref.wav"
856
+ speed = 1.0
857
+ sample_steps = torch.LongTensor([16])
858
+
859
+ dtype = torch.float16 if is_half == True else torch.float32
860
+
861
+ zero_wav = np.zeros(
862
+ int(24000 * 0.3),
863
+ dtype=np.float16 if is_half == True else np.float32,
864
+ )
865
+
866
+ with torch.no_grad():
867
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
868
+ wav16k = torch.from_numpy(wav16k)
869
+ zero_wav_torch = torch.from_numpy(zero_wav)
870
+
871
+ if is_half == True:
872
+ wav16k = wav16k.half().to(device)
873
+ zero_wav_torch = zero_wav_torch.half().to(device)
874
+ else:
875
+ wav16k = wav16k.to(device)
876
+ zero_wav_torch = zero_wav_torch.to(device)
877
+ wav16k = torch.cat([wav16k, zero_wav_torch])
878
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
879
+ print("ssl_content:", ssl_content.shape, ssl_content.dtype)
880
+
881
+ ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
882
+ ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
883
+
884
+ phones1, bert1, norm_text1 = get_phones_and_bert(
885
+ "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
886
+ )
887
+ phones2, bert2, norm_text2 = get_phones_and_bert(
888
+ todo_text,
889
+ "zh",
890
+ "v3",
891
+ )
892
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
893
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
894
+
895
+ bert1 = bert1.T.to(device)
896
+ bert2 = bert2.T.to(device)
897
+ top_k = torch.LongTensor([15]).to(device)
898
+
899
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
900
+ logger.info("start inference %s", current_time)
901
+ print(
902
+ ssl_content.shape,
903
+ ref_audio_32k.shape,
904
+ phoneme_ids0.shape,
905
+ phoneme_ids1.shape,
906
+ bert1.shape,
907
+ bert2.shape,
908
+ top_k.shape,
909
+ )
910
+ wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
911
+ print("wav_gen:", wav_gen.shape, wav_gen.dtype)
912
+
913
+ wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
914
+
915
+ audio = wav_gen.cpu().detach().numpy()
916
+ logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
917
+ sr = 24000
918
+ soundfile.write(output, (audio * 32768).astype(np.int16), sr)
919
+
920
+
921
+ import time
922
+
923
+
924
+ def test_():
925
+ sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
926
+
927
+ # cfm = ExportCFM(sovits.cfm)
928
+ # cfm.cfm.estimator = dit
929
+ sovits.cfm = None
930
+
931
+ cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device)
932
+ # cfm = torch.jit.optimize_for_inference(cfm)
933
+ cfm = cfm.half().to(device)
934
+
935
+ cfm.eval()
936
+
937
+ logger.info("cfm ok")
938
+
939
+ dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
940
+ # v2 的 gpt 也可以用
941
+ # dict_s1 = torch.load("GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
942
+ raw_t2s = get_raw_t2s_model(dict_s1).to(device)
943
+ print("#### get_raw_t2s_model ####")
944
+ print(raw_t2s.config)
945
+ if is_half:
946
+ raw_t2s = raw_t2s.half().to(device)
947
+ t2s_m = T2SModel(raw_t2s).half().to(device)
948
+ t2s_m.eval()
949
+ t2s_m = torch.jit.script(t2s_m)
950
+ t2s_m.eval()
951
+ # t2s_m.top_k = 15
952
+ logger.info("t2s_m ok")
953
+
954
+ vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
955
+ # vq_model = torch.jit.optimize_for_inference(vq_model)
956
+ # vq_model = vq_model.half().to(device)
957
+ vq_model.eval()
958
+ # vq_model = sovits.vq_model
959
+ logger.info("vq_model ok")
960
+
961
+ # gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt")
962
+ # gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half)
963
+ # gpt_sovits_v3_half = gpt_sovits_v3_half.half()
964
+ # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
965
+ # gpt_sovits_v3_half.eval()
966
+ gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
967
+ logger.info("gpt_sovits_v3_half ok")
968
+
969
+ # init_bigvgan()
970
+ # global bigvgan_model
971
+ bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
972
+ # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
973
+ bigvgan_model = bigvgan_model.half()
974
+ bigvgan_model = bigvgan_model.cuda()
975
+ bigvgan_model.eval()
976
+
977
+ logger.info("bigvgan ok")
978
+
979
+ gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
980
+ gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
981
+ gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
982
+ gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
983
+ gpt_sovits_v3.eval()
984
+ print("save gpt_sovits_v3 ok")
985
+
986
+ time.sleep(5)
987
+ # print("thread:", torch.get_num_threads())
988
+ # print("thread:", torch.get_num_interop_threads())
989
+ # torch.set_num_interop_threads(1)
990
+ # torch.set_num_threads(1)
991
+
992
+ test_export1(
993
+ "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
994
+ gpt_sovits_v3,
995
+ "out.wav",
996
+ )
997
+
998
+ test_export1(
999
+ "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
1000
+ gpt_sovits_v3,
1001
+ "out2.wav",
1002
+ )
1003
+
1004
+ # test_export(
1005
+ # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP. 哈哈哈...",
1006
+ # gpt_sovits_v3_half,
1007
+ # cfm,
1008
+ # bigvgan_model,
1009
+ # "out2.wav",
1010
+ # )
1011
+
1012
+
1013
+ def test_export_gpt_sovits_v3():
1014
+ gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
1015
+ # test_export1(
1016
+ # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
1017
+ # gpt_sovits_v3,
1018
+ # "out3.wav",
1019
+ # )
1020
+ # test_export1(
1021
+ # "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
1022
+ # gpt_sovits_v3,
1023
+ # "out4.wav",
1024
+ # )
1025
+ test_export1(
1026
+ "风萧萧兮易水寒,壮士一去兮不复还.",
1027
+ gpt_sovits_v3,
1028
+ "out5.wav",
1029
+ )
1030
+
1031
+
1032
+ with torch.no_grad():
1033
+ # export()
1034
+ test_()
1035
+ # test_export_gpt_sovits_v3()
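For reference, a minimal sketch of consuming the TorchScript bundle that test_() saves above; the artifact path mirrors test_export_gpt_sovits_v3, and the input tensors are assumed to be produced exactly as in test_export1 (cnhubert features, 32 kHz reference audio, phoneme ids and BERT features).

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the scripted GPTSoVITSV3 module saved by test_() (assumes onnx/ad/gpt_sovits_v3.pt exists).
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device).half().to(device).eval()
top_k = torch.LongTensor([15]).to(device)
sample_steps = torch.LongTensor([16])
# ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2 come from the same
# preprocessing shown in test_export1; with them prepared, one forward call yields 24 kHz audio:
# wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)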
GPT_SoVITS/inference_cli.py ADDED
@@ -0,0 +1,86 @@
1
+ import argparse
2
+ import os
3
+ import soundfile as sf
4
+
5
+ from tools.i18n.i18n import I18nAuto
6
+ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
7
+
8
+ i18n = I18nAuto()
9
+
10
+
11
+ def synthesize(
12
+ GPT_model_path,
13
+ SoVITS_model_path,
14
+ ref_audio_path,
15
+ ref_text_path,
16
+ ref_language,
17
+ target_text_path,
18
+ target_language,
19
+ output_path,
20
+ ):
21
+ # Read reference text
22
+ with open(ref_text_path, "r", encoding="utf-8") as file:
23
+ ref_text = file.read()
24
+
25
+ # Read target text
26
+ with open(target_text_path, "r", encoding="utf-8") as file:
27
+ target_text = file.read()
28
+
29
+ # Change model weights
30
+ change_gpt_weights(gpt_path=GPT_model_path)
31
+ change_sovits_weights(sovits_path=SoVITS_model_path)
32
+
33
+ # Synthesize audio
34
+ synthesis_result = get_tts_wav(
35
+ ref_wav_path=ref_audio_path,
36
+ prompt_text=ref_text,
37
+ prompt_language=i18n(ref_language),
38
+ text=target_text,
39
+ text_language=i18n(target_language),
40
+ top_p=1,
41
+ temperature=1,
42
+ )
43
+
44
+ result_list = list(synthesis_result)
45
+
46
+ if result_list:
47
+ last_sampling_rate, last_audio_data = result_list[-1]
48
+ output_wav_path = os.path.join(output_path, "output.wav")
49
+ sf.write(output_wav_path, last_audio_data, last_sampling_rate)
50
+ print(f"Audio saved to {output_wav_path}")
51
+
52
+
53
+ def main():
54
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
55
+ parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
56
+ parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
57
+ parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
58
+ parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
59
+ parser.add_argument(
60
+ "--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
61
+ )
62
+ parser.add_argument("--target_text", required=True, help="Path to the target text file")
63
+ parser.add_argument(
64
+ "--target_language",
65
+ required=True,
66
+ choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
67
+ help="Language of the target text",
68
+ )
69
+ parser.add_argument("--output_path", required=True, help="Path to the output directory")
70
+
71
+ args = parser.parse_args()
72
+
73
+ synthesize(
74
+ args.gpt_model,
75
+ args.sovits_model,
76
+ args.ref_audio,
77
+ args.ref_text,
78
+ args.ref_language,
79
+ args.target_text,
80
+ args.target_language,
81
+ args.output_path,
82
+ )
83
+
84
+
85
+ if __name__ == "__main__":
86
+ main()
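A minimal usage sketch for the CLI defined above; every path below is a placeholder, the two text files are plain UTF-8 text, and output.wav is written into output_path as in synthesize(). The import assumes the repo root is on sys.path.

from GPT_SoVITS.inference_cli import synthesize

synthesize(
    GPT_model_path="GPT_weights_v2/my_voice.ckpt",       # placeholder GPT checkpoint
    SoVITS_model_path="SoVITS_weights_v2/my_voice.pth",   # placeholder SoVITS weights
    ref_audio_path="ref.wav",                             # 3-10 s reference recording
    ref_text_path="ref_text.txt",                         # transcript of the reference audio
    ref_language="中文",
    target_text_path="target_text.txt",                   # text to synthesize
    target_language="中英混合",
    output_path="output",                                 # directory; output.wav is written here
)
# Equivalent command line: python GPT_SoVITS/inference_cli.py --gpt_model GPT_weights_v2/my_voice.ckpt --sovits_model SoVITS_weights_v2/my_voice.pth --ref_audio ref.wav --ref_text ref_text.txt --ref_language 中文 --target_text target_text.txt --target_language 中英混合 --output_path output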
GPT_SoVITS/inference_gui.py ADDED
@@ -0,0 +1,316 @@
1
+ import os
2
+ import sys
3
+ from PyQt5.QtCore import QEvent
4
+ from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
5
+ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
6
+ import soundfile as sf
7
+
8
+ from tools.i18n.i18n import I18nAuto
9
+
10
+ i18n = I18nAuto()
11
+
12
+ from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
13
+
14
+
15
+ class GPTSoVITSGUI(QMainWindow):
16
+ GPT_Path = gpt_path
17
+ SoVITS_Path = sovits_path
18
+
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ self.setWindowTitle("GPT-SoVITS GUI")
23
+ self.setGeometry(800, 450, 950, 850)
24
+
25
+ self.setStyleSheet("""
26
+ QWidget {
27
+ background-color: #a3d3b1;
28
+ }
29
+
30
+ QTabWidget::pane {
31
+ background-color: #a3d3b1;
32
+ }
33
+
34
+ QTabWidget::tab-bar {
35
+ alignment: left;
36
+ }
37
+
38
+ QTabBar::tab {
39
+ background: #8da4bf;
40
+ color: #ffffff;
41
+ padding: 8px;
42
+ }
43
+
44
+ QTabBar::tab:selected {
45
+ background: #2a3f54;
46
+ }
47
+
48
+ QLabel {
49
+ color: #000000;
50
+ }
51
+
52
+ QPushButton {
53
+ background-color: #4CAF50;
54
+ color: white;
55
+ padding: 8px;
56
+ border: 1px solid #4CAF50;
57
+ border-radius: 4px;
58
+ }
59
+
60
+ QPushButton:hover {
61
+ background-color: #45a049;
62
+ border: 1px solid #45a049;
63
+ box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
64
+ }
65
+ """)
66
+
67
+ license_text = (
68
+ "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
69
+ "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
70
+ )
71
+ license_label = QLabel(license_text)
72
+ license_label.setWordWrap(True)
73
+
74
+ self.GPT_model_label = QLabel("选择GPT模型:")
75
+ self.GPT_model_input = QLineEdit()
76
+ self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
77
+ self.GPT_model_input.setText(self.GPT_Path)
78
+ self.GPT_model_input.setReadOnly(True)
79
+ self.GPT_model_button = QPushButton("选择GPT模型文件")
80
+ self.GPT_model_button.clicked.connect(self.select_GPT_model)
81
+
82
+ self.SoVITS_model_label = QLabel("选择SoVITS模型:")
83
+ self.SoVITS_model_input = QLineEdit()
84
+ self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
85
+ self.SoVITS_model_input.setText(self.SoVITS_Path)
86
+ self.SoVITS_model_input.setReadOnly(True)
87
+ self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
88
+ self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
89
+
90
+ self.ref_audio_label = QLabel("上传参考音频:")
91
+ self.ref_audio_input = QLineEdit()
92
+ self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
93
+ self.ref_audio_input.setReadOnly(True)
94
+ self.ref_audio_button = QPushButton("选择音频文件")
95
+ self.ref_audio_button.clicked.connect(self.select_ref_audio)
96
+
97
+ self.ref_text_label = QLabel("参考音频文本:")
98
+ self.ref_text_input = QLineEdit()
99
+ self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
100
+ self.ref_text_button = QPushButton("上传文本")
101
+ self.ref_text_button.clicked.connect(self.upload_ref_text)
102
+
103
+ self.ref_language_label = QLabel("参考音频语言:")
104
+ self.ref_language_combobox = QComboBox()
105
+ self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
106
+ self.ref_language_combobox.setCurrentText("多语种混合")
107
+
108
+ self.target_text_label = QLabel("合成目标文本:")
109
+ self.target_text_input = QLineEdit()
110
+ self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
111
+ self.target_text_button = QPushButton("上传文本")
112
+ self.target_text_button.clicked.connect(self.upload_target_text)
113
+
114
+ self.target_language_label = QLabel("合成音频语言:")
115
+ self.target_language_combobox = QComboBox()
116
+ self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
117
+ self.target_language_combobox.setCurrentText("多语种混合")
118
+
119
+ self.output_label = QLabel("输出音频路径:")
120
+ self.output_input = QLineEdit()
121
+ self.output_input.setPlaceholderText("拖拽或选择文件")
122
+ self.output_input.setReadOnly(True)
123
+ self.output_button = QPushButton("选择文件夹")
124
+ self.output_button.clicked.connect(self.select_output_path)
125
+
126
+ self.output_text = QTextEdit()
127
+ self.output_text.setReadOnly(True)
128
+
129
+ self.add_drag_drop_events(
130
+ [
131
+ self.GPT_model_input,
132
+ self.SoVITS_model_input,
133
+ self.ref_audio_input,
134
+ self.ref_text_input,
135
+ self.target_text_input,
136
+ self.output_input,
137
+ ]
138
+ )
139
+
140
+ self.synthesize_button = QPushButton("合成")
141
+ self.synthesize_button.clicked.connect(self.synthesize)
142
+
143
+ self.clear_output_button = QPushButton("清空输出")
144
+ self.clear_output_button.clicked.connect(self.clear_output)
145
+
146
+ self.status_bar = QStatusBar()
147
+
148
+ main_layout = QVBoxLayout()
149
+
150
+ input_layout = QGridLayout(self)
151
+ input_layout.setSpacing(10)
152
+
153
+ input_layout.addWidget(license_label, 0, 0, 1, 3)
154
+
155
+ input_layout.addWidget(self.GPT_model_label, 1, 0)
156
+ input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
157
+ input_layout.addWidget(self.GPT_model_button, 2, 2)
158
+
159
+ input_layout.addWidget(self.SoVITS_model_label, 3, 0)
160
+ input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
161
+ input_layout.addWidget(self.SoVITS_model_button, 4, 2)
162
+
163
+ input_layout.addWidget(self.ref_audio_label, 5, 0)
164
+ input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
165
+ input_layout.addWidget(self.ref_audio_button, 6, 2)
166
+
167
+ input_layout.addWidget(self.ref_language_label, 7, 0)
168
+ input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
169
+ input_layout.addWidget(self.ref_text_label, 9, 0)
170
+ input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
171
+ input_layout.addWidget(self.ref_text_button, 10, 2)
172
+
173
+ input_layout.addWidget(self.target_language_label, 11, 0)
174
+ input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
175
+ input_layout.addWidget(self.target_text_label, 13, 0)
176
+ input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
177
+ input_layout.addWidget(self.target_text_button, 14, 2)
178
+
179
+ input_layout.addWidget(self.output_label, 15, 0)
180
+ input_layout.addWidget(self.output_input, 16, 0, 1, 2)
181
+ input_layout.addWidget(self.output_button, 16, 2)
182
+
183
+ main_layout.addLayout(input_layout)
184
+
185
+ output_layout = QVBoxLayout()
186
+ output_layout.addWidget(self.output_text)
187
+ main_layout.addLayout(output_layout)
188
+
189
+ main_layout.addWidget(self.synthesize_button)
190
+
191
+ main_layout.addWidget(self.clear_output_button)
192
+
193
+ main_layout.addWidget(self.status_bar)
194
+
195
+ self.central_widget = QWidget()
196
+ self.central_widget.setLayout(main_layout)
197
+ self.setCentralWidget(self.central_widget)
198
+
199
+ def dragEnterEvent(self, event):
200
+ if event.mimeData().hasUrls():
201
+ event.acceptProposedAction()
202
+
203
+ def dropEvent(self, event):
204
+ if event.mimeData().hasUrls():
205
+ file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
206
+ if len(file_paths) == 1:
207
+ self.update_ref_audio(file_paths[0])
208
+ else:
209
+ self.update_ref_audio(", ".join(file_paths))
210
+
211
+ def add_drag_drop_events(self, widgets):
212
+ for widget in widgets:
213
+ widget.setAcceptDrops(True)
214
+ widget.installEventFilter(self)
215
+
216
+ def eventFilter(self, obj, event):
217
+ if event.type() in (QEvent.DragEnter, QEvent.Drop):
218
+ mime_data = event.mimeData()
219
+ if mime_data.hasUrls():
220
+ event.acceptProposedAction()
221
+
222
+ return super().eventFilter(obj, event)
223
+
224
+ def select_GPT_model(self):
225
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
226
+ if file_path:
227
+ self.GPT_model_input.setText(file_path)
228
+
229
+ def select_SoVITS_model(self):
230
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
231
+ if file_path:
232
+ self.SoVITS_model_input.setText(file_path)
233
+
234
+ def select_ref_audio(self):
235
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
236
+ if file_path:
237
+ self.update_ref_audio(file_path)
238
+
239
+ def upload_ref_text(self):
240
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
241
+ if file_path:
242
+ with open(file_path, "r", encoding="utf-8") as file:
243
+ content = file.read()
244
+ self.ref_text_input.setText(content)
245
+
246
+ def upload_target_text(self):
247
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
248
+ if file_path:
249
+ with open(file_path, "r", encoding="utf-8") as file:
250
+ content = file.read()
251
+ self.target_text_input.setText(content)
252
+
253
+ def select_output_path(self):
254
+ options = QFileDialog.Options()
255
+ options |= QFileDialog.DontUseNativeDialog
256
+ options |= QFileDialog.ShowDirsOnly
257
+
258
+ folder_dialog = QFileDialog()
259
+ folder_dialog.setOptions(options)
260
+ folder_dialog.setFileMode(QFileDialog.Directory)
261
+
262
+ if folder_dialog.exec_():
263
+ folder_path = folder_dialog.selectedFiles()[0]
264
+ self.output_input.setText(folder_path)
265
+
266
+ def update_ref_audio(self, file_path):
267
+ self.ref_audio_input.setText(file_path)
268
+
269
+ def clear_output(self):
270
+ self.output_text.clear()
271
+
272
+ def synthesize(self):
273
+ GPT_model_path = self.GPT_model_input.text()
274
+ SoVITS_model_path = self.SoVITS_model_input.text()
275
+ ref_audio_path = self.ref_audio_input.text()
276
+ language_combobox = self.ref_language_combobox.currentText()
277
+ language_combobox = i18n(language_combobox)
278
+ ref_text = self.ref_text_input.text()
279
+ target_language_combobox = self.target_language_combobox.currentText()
280
+ target_language_combobox = i18n(target_language_combobox)
281
+ target_text = self.target_text_input.text()
282
+ output_path = self.output_input.text()
283
+
284
+ if GPT_model_path != self.GPT_Path:
285
+ change_gpt_weights(gpt_path=GPT_model_path)
286
+ self.GPT_Path = GPT_model_path
287
+ if SoVITS_model_path != self.SoVITS_Path:
288
+ change_sovits_weights(sovits_path=SoVITS_model_path)
289
+ self.SoVITS_Path = SoVITS_model_path
290
+
291
+ synthesis_result = get_tts_wav(
292
+ ref_wav_path=ref_audio_path,
293
+ prompt_text=ref_text,
294
+ prompt_language=language_combobox,
295
+ text=target_text,
296
+ text_language=target_language_combobox,
297
+ )
298
+
299
+ result_list = list(synthesis_result)
300
+
301
+ if result_list:
302
+ last_sampling_rate, last_audio_data = result_list[-1]
303
+ output_wav_path = os.path.join(output_path, "output.wav")
304
+ sf.write(output_wav_path, last_audio_data, last_sampling_rate)
305
+
306
+ result = "Audio saved to " + output_wav_path
307
+
308
+ self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
309
+ self.output_text.append("处理结果:\n" + result)
310
+
311
+
312
+ if __name__ == "__main__":
313
+ app = QApplication(sys.argv)
314
+ mainWin = GPTSoVITSGUI()
315
+ mainWin.show()
316
+ sys.exit(app.exec_())
GPT_SoVITS/inference_webui.py ADDED
@@ -0,0 +1,1281 @@
1
+ """
2
+ 按中英混合识别
3
+ 按日英混合识别
4
+ 多语种启动切分识别语种
5
+ 全部按中文识别
6
+ 全部按英文识别
7
+ 全部按日文识别
8
+ """
9
+
10
+ import logging
11
+ import traceback
12
+ import warnings
13
+
14
+ import torchaudio
15
+
16
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
17
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
18
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
19
+ logging.getLogger("httpx").setLevel(logging.ERROR)
20
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
21
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
22
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
23
+ logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
24
+ warnings.simplefilter(action="ignore", category=FutureWarning)
25
+
26
+ import json
27
+ import os
28
+ import re
29
+ import sys
30
+
31
+ import torch
32
+ from text.LangSegmenter import LangSegmenter
33
+
34
+ try:
35
+ import gradio.analytics as analytics
36
+
37
+ analytics.version_check = lambda: None
38
+ except:
39
+ ...
40
+ version = model_version = os.environ.get("version", "v2")
41
+ path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
42
+ path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
43
+ is_exist_s2gv3 = os.path.exists(path_sovits_v3)
44
+ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
45
+ pretrained_sovits_name = [
46
+ "GPT_SoVITS/pretrained_models/s2G488k.pth",
47
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
48
+ "GPT_SoVITS/pretrained_models/s2Gv3.pth",
49
+ "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
50
+ ]
51
+ pretrained_gpt_name = [
52
+ "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
53
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
54
+ "GPT_SoVITS/pretrained_models/s1v3.ckpt",
55
+ "GPT_SoVITS/pretrained_models/s1v3.ckpt",
56
+ ]
57
+
58
+
59
+ _ = [[], []]
60
+ for i in range(4):
61
+ if os.path.exists(pretrained_gpt_name[i]):
62
+ _[0].append(pretrained_gpt_name[i])
63
+ if os.path.exists(pretrained_sovits_name[i]):
64
+ _[-1].append(pretrained_sovits_name[i])
65
+ pretrained_gpt_name, pretrained_sovits_name = _
66
+
67
+
68
+ if os.path.exists("./weight.json"):
69
+ pass
70
+ else:
71
+ with open("./weight.json", "w", encoding="utf-8") as file:
72
+ json.dump({"GPT": {}, "SoVITS": {}}, file)
73
+
74
+ with open("./weight.json", "r", encoding="utf-8") as file:
75
+ weight_data = file.read()
76
+ weight_data = json.loads(weight_data)
77
+ gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
78
+ sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
79
+ if isinstance(gpt_path, list):
80
+ gpt_path = gpt_path[0]
81
+ if isinstance(sovits_path, list):
82
+ sovits_path = sovits_path[0]
83
+
84
+ # gpt_path = os.environ.get(
85
+ # "gpt_path", pretrained_gpt_name
86
+ # )
87
+ # sovits_path = os.environ.get("sovits_path", pretrained_sovits_name)
88
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base")
89
+ bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
90
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
91
+ infer_ttswebui = int(infer_ttswebui)
92
+ is_share = os.environ.get("is_share", "False")
93
+ is_share = eval(is_share)
94
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
95
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
96
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
97
+ # is_half=False
98
+ punctuation = set(["!", "?", "…", ",", ".", "-", " "])
99
+ import gradio as gr
100
+ import librosa
101
+ import numpy as np
102
+ from feature_extractor import cnhubert
103
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
104
+
105
+ cnhubert.cnhubert_base_path = cnhubert_base_path
106
+
107
+ import random
108
+
109
+ from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
110
+
111
+
112
+ def set_seed(seed):
113
+ if seed == -1:
114
+ seed = random.randint(0, 1000000)
115
+ seed = int(seed)
116
+ random.seed(seed)
117
+ os.environ["PYTHONHASHSEED"] = str(seed)
118
+ np.random.seed(seed)
119
+ torch.manual_seed(seed)
120
+ torch.cuda.manual_seed(seed)
121
+
122
+
123
+ # set_seed(42)
124
+
125
+ from time import time as ttime
126
+
127
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
128
+ from peft import LoraConfig, get_peft_model
129
+ from text import cleaned_text_to_sequence
130
+ from text.cleaner import clean_text
131
+
132
+ from tools.i18n.i18n import I18nAuto, scan_language_list
133
+
134
+ language = os.environ.get("language", "Auto")
135
+ language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
136
+ i18n = I18nAuto(language=language)
137
+
138
+ # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
139
+
140
+ if torch.cuda.is_available():
141
+ device = "cuda"
142
+ else:
143
+ device = "cpu"
144
+
145
+ dict_language_v1 = {
146
+ i18n("中文"): "all_zh", # 全部按中文识别
147
+ i18n("英文"): "en", # 全部按英文识别#######不变
148
+ i18n("日文"): "all_ja", # 全部按日文识别
149
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
150
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
151
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
152
+ }
153
+ dict_language_v2 = {
154
+ i18n("中文"): "all_zh", # 全部按中文识别
155
+ i18n("英文"): "en", # 全部按英文识别#######不变
156
+ i18n("日文"): "all_ja", # 全部按日文识别
157
+ i18n("粤语"): "all_yue", # 全部按中文识别
158
+ i18n("韩文"): "all_ko", # 全部按韩文识别
159
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
160
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
161
+ i18n("粤英混合"): "yue", # 按粤英混合识别####不变
162
+ i18n("韩英混合"): "ko", # 按韩英混合识别####不变
163
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
164
+ i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
165
+ }
166
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
167
+
168
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
169
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
170
+ if is_half == True:
171
+ bert_model = bert_model.half().to(device)
172
+ else:
173
+ bert_model = bert_model.to(device)
174
+
175
+
176
+ def get_bert_feature(text, word2ph):
177
+ with torch.no_grad():
178
+ inputs = tokenizer(text, return_tensors="pt")
179
+ for i in inputs:
180
+ inputs[i] = inputs[i].to(device)
181
+ res = bert_model(**inputs, output_hidden_states=True)
182
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
183
+ assert len(word2ph) == len(text)
184
+ phone_level_feature = []
185
+ for i in range(len(word2ph)):
186
+ repeat_feature = res[i].repeat(word2ph[i], 1)
187
+ phone_level_feature.append(repeat_feature)
188
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
189
+ return phone_level_feature.T
190
+
191
+
192
+ class DictToAttrRecursive(dict):
193
+ def __init__(self, input_dict):
194
+ super().__init__(input_dict)
195
+ for key, value in input_dict.items():
196
+ if isinstance(value, dict):
197
+ value = DictToAttrRecursive(value)
198
+ self[key] = value
199
+ setattr(self, key, value)
200
+
201
+ def __getattr__(self, item):
202
+ try:
203
+ return self[item]
204
+ except KeyError:
205
+ raise AttributeError(f"Attribute {item} not found")
206
+
207
+ def __setattr__(self, key, value):
208
+ if isinstance(value, dict):
209
+ value = DictToAttrRecursive(value)
210
+ super(DictToAttrRecursive, self).__setitem__(key, value)
211
+ super().__setattr__(key, value)
212
+
213
+ def __delattr__(self, item):
214
+ try:
215
+ del self[item]
216
+ except KeyError:
217
+ raise AttributeError(f"Attribute {item} not found")
218
+
219
+
220
+ ssl_model = cnhubert.get_model()
221
+ if is_half == True:
222
+ ssl_model = ssl_model.half().to(device)
223
+ else:
224
+ ssl_model = ssl_model.to(device)
225
+
226
+ resample_transform_dict = {}
227
+
228
+
229
+ def resample(audio_tensor, sr0,sr1):
230
+ global resample_transform_dict
231
+ key="%s-%s"%(sr0,sr1)
232
+ if key not in resample_transform_dict:
233
+ resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
234
+ return resample_transform_dict[key](audio_tensor)
235
+
236
+
237
+ ###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
238
+ # symbol_version-model_version-if_lora_v3
239
+ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
240
+
241
+ v3v4set={"v3","v4"}
242
+ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
243
+ global vq_model, hps, version, model_version, dict_language, if_lora_v3
244
+ version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
245
+ print(sovits_path,version, model_version, if_lora_v3)
246
+ is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
247
+ if if_lora_v3 == True and is_exist == False:
248
+ info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
249
+ gr.Warning(info)
250
+ raise FileExistsError(info)
251
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
252
+ if prompt_language is not None and text_language is not None:
253
+ if prompt_language in list(dict_language.keys()):
254
+ prompt_text_update, prompt_language_update = (
255
+ {"__type__": "update"},
256
+ {"__type__": "update", "value": prompt_language},
257
+ )
258
+ else:
259
+ prompt_text_update = {"__type__": "update", "value": ""}
260
+ prompt_language_update = {"__type__": "update", "value": i18n("中文")}
261
+ if text_language in list(dict_language.keys()):
262
+ text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
263
+ else:
264
+ text_update = {"__type__": "update", "value": ""}
265
+ text_language_update = {"__type__": "update", "value": i18n("中文")}
266
+ if model_version in v3v4set:
267
+ visible_sample_steps = True
268
+ visible_inp_refs = False
269
+ else:
270
+ visible_sample_steps = False
271
+ visible_inp_refs = True
272
+ yield (
273
+ {"__type__": "update", "choices": list(dict_language.keys())},
274
+ {"__type__": "update", "choices": list(dict_language.keys())},
275
+ prompt_text_update,
276
+ prompt_language_update,
277
+ text_update,
278
+ text_language_update,
279
+ {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version == "v3" else 8, "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32]},
280
+ {"__type__": "update", "visible": visible_inp_refs},
281
+ {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
282
+ {"__type__": "update", "visible": True if model_version =="v3" else False},
283
+ {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
284
+ )
285
+
286
+ dict_s2 = load_sovits_new(sovits_path)
287
+ hps = dict_s2["config"]
288
+ hps = DictToAttrRecursive(hps)
289
+ hps.model.semantic_frame_rate = "25hz"
290
+ if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
291
+ hps.model.version = "v2" # v3model,v2sybomls
292
+ elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
293
+ hps.model.version = "v1"
294
+ else:
295
+ hps.model.version = "v2"
296
+ version = hps.model.version
297
+ # print("sovits版本:",hps.model.version)
298
+ if model_version not in v3v4set:
299
+ vq_model = SynthesizerTrn(
300
+ hps.data.filter_length // 2 + 1,
301
+ hps.train.segment_size // hps.data.hop_length,
302
+ n_speakers=hps.data.n_speakers,
303
+ **hps.model,
304
+ )
305
+ model_version = version
306
+ else:
307
+ hps.model.version=model_version
308
+ vq_model = SynthesizerTrnV3(
309
+ hps.data.filter_length // 2 + 1,
310
+ hps.train.segment_size // hps.data.hop_length,
311
+ n_speakers=hps.data.n_speakers,
312
+ **hps.model,
313
+ )
314
+ if "pretrained" not in sovits_path:
315
+ try:
316
+ del vq_model.enc_q
317
+ except:
318
+ pass
319
+ if is_half == True:
320
+ vq_model = vq_model.half().to(device)
321
+ else:
322
+ vq_model = vq_model.to(device)
323
+ vq_model.eval()
324
+ if if_lora_v3 == False:
325
+ print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False))
326
+ else:
327
+ path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
328
+ print(
329
+ "loading sovits_%spretrained_G"%model_version,
330
+ vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False),
331
+ )
332
+ lora_rank = dict_s2["lora_rank"]
333
+ lora_config = LoraConfig(
334
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
335
+ r=lora_rank,
336
+ lora_alpha=lora_rank,
337
+ init_lora_weights=True,
338
+ )
339
+ vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
340
+ print("loading sovits_%s_lora%s" % (model_version,lora_rank))
341
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
342
+ vq_model.cfm = vq_model.cfm.merge_and_unload()
343
+ # torch.save(vq_model.state_dict(),"merge_win.pth")
344
+ vq_model.eval()
345
+
346
+ yield (
347
+ {"__type__": "update", "choices": list(dict_language.keys())},
348
+ {"__type__": "update", "choices": list(dict_language.keys())},
349
+ prompt_text_update,
350
+ prompt_language_update,
351
+ text_update,
352
+ text_language_update,
353
+ {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version == "v3" else 8, "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32]},
354
+ {"__type__": "update", "visible": visible_inp_refs},
355
+ {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
356
+ {"__type__": "update", "visible": True if model_version =="v3" else False},
357
+ {"__type__": "update", "value": i18n("合成语音"), "interactive": True},
358
+ )
359
+ with open("./weight.json") as f:
360
+ data = f.read()
361
+ data = json.loads(data)
362
+ data["SoVITS"][version] = sovits_path
363
+ with open("./weight.json", "w") as f:
364
+ f.write(json.dumps(data))
365
+
366
+
367
+ try:
368
+ next(change_sovits_weights(sovits_path))
369
+ except:
370
+ pass
371
+
372
+
373
+ def change_gpt_weights(gpt_path):
374
+ global hz, max_sec, t2s_model, config
375
+ hz = 50
376
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
377
+ config = dict_s1["config"]
378
+ max_sec = config["data"]["max_sec"]
379
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
380
+ t2s_model.load_state_dict(dict_s1["weight"])
381
+ if is_half == True:
382
+ t2s_model = t2s_model.half()
383
+ t2s_model = t2s_model.to(device)
384
+ t2s_model.eval()
385
+ # total = sum([param.nelement() for param in t2s_model.parameters()])
386
+ # print("Number of parameter: %.2fM" % (total / 1e6))
387
+ with open("./weight.json") as f:
388
+ data = f.read()
389
+ data = json.loads(data)
390
+ data["GPT"][version] = gpt_path
391
+ with open("./weight.json", "w") as f:
392
+ f.write(json.dumps(data))
393
+
394
+
395
+ change_gpt_weights(gpt_path)
396
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
397
+ import torch
398
+
399
+ now_dir = os.getcwd()
400
+
401
+
402
+ def init_bigvgan():
403
+ global bigvgan_model,hifigan_model
404
+ from BigVGAN import bigvgan
405
+
406
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained(
407
+ "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
408
+ use_cuda_kernel=False,
409
+ ) # if True, RuntimeError: Ninja is required to load C++ extensions
410
+ # remove weight norm in the model and set to eval mode
411
+ bigvgan_model.remove_weight_norm()
412
+ bigvgan_model = bigvgan_model.eval()
413
+ if hifigan_model:
414
+ hifigan_model=hifigan_model.cpu()
415
+ hifigan_model=None
416
+ try:torch.cuda.empty_cache()
417
+ except:pass
418
+ if is_half == True:
419
+ bigvgan_model = bigvgan_model.half().to(device)
420
+ else:
421
+ bigvgan_model = bigvgan_model.to(device)
422
+
423
+ def init_hifigan():
424
+ global hifigan_model,bigvgan_model
425
+ hifigan_model = Generator(
426
+ initial_channel=100,
427
+ resblock="1",
428
+ resblock_kernel_sizes=[3, 7, 11],
429
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
430
+ upsample_rates=[10, 6, 2, 2, 2],
431
+ upsample_initial_channel=512,
432
+ upsample_kernel_sizes=[20, 12, 4, 4, 4],
433
+ gin_channels=0, is_bias=True
434
+ )
435
+ hifigan_model.eval()
436
+ hifigan_model.remove_weight_norm()
437
+ state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
438
+ print("loading vocoder",hifigan_model.load_state_dict(state_dict_g))
439
+ if bigvgan_model:
440
+ bigvgan_model=bigvgan_model.cpu()
441
+ bigvgan_model=None
442
+ try:torch.cuda.empty_cache()
443
+ except:pass
444
+ if is_half == True:
445
+ hifigan_model = hifigan_model.half().to(device)
446
+ else:
447
+ hifigan_model = hifigan_model.to(device)
448
+
449
+ bigvgan_model=hifigan_model=None
450
+ if model_version=="v3":
451
+ init_bigvgan()
452
+ if model_version=="v4":
453
+ init_hifigan()
454
+
455
+
456
+ def get_spepc(hps, filename):
457
+ # audio = load_audio(filename, int(hps.data.sampling_rate))
458
+ audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
459
+ audio = torch.FloatTensor(audio)
460
+ maxx = audio.abs().max()
461
+ if maxx > 1:
462
+ audio /= min(2, maxx)
463
+ audio_norm = audio
464
+ audio_norm = audio_norm.unsqueeze(0)
465
+ spec = spectrogram_torch(
466
+ audio_norm,
467
+ hps.data.filter_length,
468
+ hps.data.sampling_rate,
469
+ hps.data.hop_length,
470
+ hps.data.win_length,
471
+ center=False,
472
+ )
473
+ return spec
474
+
475
+
476
+ def clean_text_inf(text, language, version):
477
+ language = language.replace("all_", "")
478
+ phones, word2ph, norm_text = clean_text(text, language, version)
479
+ phones = cleaned_text_to_sequence(phones, version)
480
+ return phones, word2ph, norm_text
481
+
482
+
483
+ dtype = torch.float16 if is_half == True else torch.float32
484
+
485
+
486
+ def get_bert_inf(phones, word2ph, norm_text, language):
487
+ language = language.replace("all_", "")
488
+ if language == "zh":
489
+ bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
490
+ else:
491
+ bert = torch.zeros(
492
+ (1024, len(phones)),
493
+ dtype=torch.float16 if is_half == True else torch.float32,
494
+ ).to(device)
495
+
496
+ return bert
497
+
498
+
499
+ splits = {
500
+ ",",
501
+ "。",
502
+ "?",
503
+ "!",
504
+ ",",
505
+ ".",
506
+ "?",
507
+ "!",
508
+ "~",
509
+ ":",
510
+ ":",
511
+ "—",
512
+ "…",
513
+ }
514
+
515
+
516
+ def get_first(text):
517
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
518
+ text = re.split(pattern, text)[0].strip()
519
+ return text
520
+
521
+
522
+ from text import chinese
523
+
524
+
525
+ def get_phones_and_bert(text, language, version, final=False):
526
+ if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
527
+ formattext = text
528
+ while " " in formattext:
529
+ formattext = formattext.replace(" ", " ")
530
+ if language == "all_zh":
531
+ if re.search(r"[A-Za-z]", formattext):
532
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
533
+ formattext = chinese.mix_text_normalize(formattext)
534
+ return get_phones_and_bert(formattext, "zh", version)
535
+ else:
536
+ phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
537
+ bert = get_bert_feature(norm_text, word2ph).to(device)
538
+ elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
539
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
540
+ formattext = chinese.mix_text_normalize(formattext)
541
+ return get_phones_and_bert(formattext, "yue", version)
542
+ else:
543
+ phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
544
+ bert = torch.zeros(
545
+ (1024, len(phones)),
546
+ dtype=torch.float16 if is_half == True else torch.float32,
547
+ ).to(device)
548
+ elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
549
+ textlist = []
550
+ langlist = []
551
+ if language == "auto":
552
+ for tmp in LangSegmenter.getTexts(text):
553
+ langlist.append(tmp["lang"])
554
+ textlist.append(tmp["text"])
555
+ elif language == "auto_yue":
556
+ for tmp in LangSegmenter.getTexts(text):
557
+ if tmp["lang"] == "zh":
558
+ tmp["lang"] = "yue"
559
+ langlist.append(tmp["lang"])
560
+ textlist.append(tmp["text"])
561
+ else:
562
+ for tmp in LangSegmenter.getTexts(text):
563
+ if tmp["lang"] == "en":
564
+ langlist.append(tmp["lang"])
565
+ else:
566
+ # 因无法区别中日韩文汉字,以用户输入为准
567
+ langlist.append(language)
568
+ textlist.append(tmp["text"])
569
+ print(textlist)
570
+ print(langlist)
571
+ phones_list = []
572
+ bert_list = []
573
+ norm_text_list = []
574
+ for i in range(len(textlist)):
575
+ lang = langlist[i]
576
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
577
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
578
+ phones_list.append(phones)
579
+ norm_text_list.append(norm_text)
580
+ bert_list.append(bert)
581
+ bert = torch.cat(bert_list, dim=1)
582
+ phones = sum(phones_list, [])
583
+ norm_text = "".join(norm_text_list)
584
+
585
+ if not final and len(phones) < 6:
586
+ return get_phones_and_bert("." + text, language, version, final=True)
587
+
588
+ return phones, bert.to(dtype), norm_text
589
+
590
+
591
+ from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
592
+
593
+ spec_min = -12
594
+ spec_max = 2
595
+
596
+
597
+ def norm_spec(x):
598
+ return (x - spec_min) / (spec_max - spec_min) * 2 - 1
599
+
600
+
601
+ def denorm_spec(x):
602
+ return (x + 1) / 2 * (spec_max - spec_min) + spec_min
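As a quick check of the two helpers above: norm_spec maps [spec_min, spec_max] = [-12, 2] onto [-1, 1], and denorm_spec inverts it exactly. A standalone sketch (the lambdas restate the formulas rather than importing this module):

import torch

spec_min, spec_max = -12, 2
norm = lambda x: (x - spec_min) / (spec_max - spec_min) * 2 - 1    # same formula as norm_spec
denorm = lambda x: (x + 1) / 2 * (spec_max - spec_min) + spec_min  # same formula as denorm_spec
x = torch.linspace(spec_min, spec_max, 8)
assert torch.allclose(denorm(norm(x)), x)  # the round trip is the identity on the mel range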
603
+
604
+
605
+ mel_fn = lambda x: mel_spectrogram_torch(
606
+ x,
607
+ **{
608
+ "n_fft": 1024,
609
+ "win_size": 1024,
610
+ "hop_size": 256,
611
+ "num_mels": 100,
612
+ "sampling_rate": 24000,
613
+ "fmin": 0,
614
+ "fmax": None,
615
+ "center": False,
616
+ },
617
+ )
618
+ mel_fn_v4 = lambda x: mel_spectrogram_torch(
619
+ x,
620
+ **{
621
+ "n_fft": 1280,
622
+ "win_size": 1280,
623
+ "hop_size": 320,
624
+ "num_mels": 100,
625
+ "sampling_rate": 32000,
626
+ "fmin": 0,
627
+ "fmax": None,
628
+ "center": False,
629
+ },
630
+ )
631
+
632
+
633
+ def merge_short_text_in_array(texts, threshold):
634
+ if (len(texts)) < 2:
635
+ return texts
636
+ result = []
637
+ text = ""
638
+ for ele in texts:
639
+ text += ele
640
+ if len(text) >= threshold:
641
+ result.append(text)
642
+ text = ""
643
+ if len(text) > 0:
644
+ if len(result) == 0:
645
+ result.append(text)
646
+ else:
647
+ result[len(result) - 1] += text
648
+ return result
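A small illustration of merge_short_text_in_array above (get_tts_wav calls it with threshold=5): segments are concatenated greedily until they reach the threshold, and a short trailing remainder is folded into the previous chunk. Assuming the function is in scope:

texts = ["你好。", "早。", "今天天气不错。", "嗯。"]
print(merge_short_text_in_array(texts, 5))
# -> ['你好。早。', '今天天气不错。嗯。']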
649
+
650
+
651
+ sr_model = None
652
+
653
+
654
+ def audio_sr(audio, sr):
655
+ global sr_model
656
+ if sr_model == None:
657
+ from tools.audio_sr import AP_BWE
658
+
659
+ try:
660
+ sr_model = AP_BWE(device, DictToAttrRecursive)
661
+ except FileNotFoundError:
662
+ gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
663
+ return audio.cpu().detach().numpy(), sr
664
+ return sr_model(audio, sr)
665
+
666
+
667
+ ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
668
+ # cache_tokens={}#暂未实现清理机制
669
+ cache = {}
670
+
671
+
672
+ def get_tts_wav(
673
+ ref_wav_path,
674
+ prompt_text,
675
+ prompt_language,
676
+ text,
677
+ text_language,
678
+ how_to_cut=i18n("不切"),
679
+ top_k=20,
680
+ top_p=0.6,
681
+ temperature=0.6,
682
+ ref_free=False,
683
+ speed=1,
684
+ if_freeze=False,
685
+ inp_refs=None,
686
+ sample_steps=8,
687
+ if_sr=False,
688
+ pause_second=0.3,
689
+ ):
690
+ global cache
691
+ if ref_wav_path:
692
+ pass
693
+ else:
694
+ gr.Warning(i18n("请上传参考音频"))
695
+ if text:
696
+ pass
697
+ else:
698
+ gr.Warning(i18n("请填入推理文本"))
699
+ t = []
700
+ if prompt_text is None or len(prompt_text) == 0:
701
+ ref_free = True
702
+ if model_version in v3v4set:
703
+ ref_free = False # s2v3暂不支持ref_free
704
+ else:
705
+ if_sr = False
706
+ t0 = ttime()
707
+ prompt_language = dict_language[prompt_language]
708
+ text_language = dict_language[text_language]
709
+
710
+ if not ref_free:
711
+ prompt_text = prompt_text.strip("\n")
712
+ if prompt_text[-1] not in splits:
713
+ prompt_text += "。" if prompt_language != "en" else "."
714
+ print(i18n("实际输入的参考文本:"), prompt_text)
715
+ text = text.strip("\n")
716
+ # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
717
+
718
+ print(i18n("实际输入的目标文本:"), text)
719
+ zero_wav = np.zeros(
720
+ int(hps.data.sampling_rate * pause_second),
721
+ dtype=np.float16 if is_half == True else np.float32,
722
+ )
723
+ zero_wav_torch = torch.from_numpy(zero_wav)
724
+ if is_half == True:
725
+ zero_wav_torch = zero_wav_torch.half().to(device)
726
+ else:
727
+ zero_wav_torch = zero_wav_torch.to(device)
728
+ if not ref_free:
729
+ with torch.no_grad():
730
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
731
+ if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
732
+ gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
733
+ raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
734
+ wav16k = torch.from_numpy(wav16k)
735
+ if is_half == True:
736
+ wav16k = wav16k.half().to(device)
737
+ else:
738
+ wav16k = wav16k.to(device)
739
+ wav16k = torch.cat([wav16k, zero_wav_torch])
740
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
741
+ codes = vq_model.extract_latent(ssl_content)
742
+ prompt_semantic = codes[0, 0]
743
+ prompt = prompt_semantic.unsqueeze(0).to(device)
744
+
745
+ t1 = ttime()
746
+ t.append(t1 - t0)
747
+
748
+ if how_to_cut == i18n("凑四句一切"):
749
+ text = cut1(text)
750
+ elif how_to_cut == i18n("凑50字一切"):
751
+ text = cut2(text)
752
+ elif how_to_cut == i18n("按中文句号。切"):
753
+ text = cut3(text)
754
+ elif how_to_cut == i18n("按英文句号.切"):
755
+ text = cut4(text)
756
+ elif how_to_cut == i18n("按标点符号切"):
757
+ text = cut5(text)
758
+ while "\n\n" in text:
759
+ text = text.replace("\n\n", "\n")
760
+ print(i18n("实际输入的目标文本(切句后):"), text)
761
+ texts = text.split("\n")
762
+ texts = process_text(texts)
763
+ texts = merge_short_text_in_array(texts, 5)
764
+ audio_opt = []
765
+ ###s2v3暂不支持ref_free
766
+ if not ref_free:
767
+ phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
768
+
769
+ for i_text, text in enumerate(texts):
770
+ # 解决输入目标文本的空行导致报错的问题
771
+ if len(text.strip()) == 0:
772
+ continue
773
+ if text[-1] not in splits:
774
+ text += "。" if text_language != "en" else "."
775
+ print(i18n("实际输入的目标文本(每句):"), text)
776
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
777
+ print(i18n("前端处理后的文本(每句):"), norm_text2)
778
+ if not ref_free:
779
+ bert = torch.cat([bert1, bert2], 1)
780
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
781
+ else:
782
+ bert = bert2
783
+ all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
784
+
785
+ bert = bert.to(device).unsqueeze(0)
786
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
787
+
788
+ t2 = ttime()
789
+ # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
790
+ # print(cache.keys(),if_freeze)
791
+ if i_text in cache and if_freeze == True:
792
+ pred_semantic = cache[i_text]
793
+ else:
794
+ with torch.no_grad():
795
+ pred_semantic, idx = t2s_model.model.infer_panel(
796
+ all_phoneme_ids,
797
+ all_phoneme_len,
798
+ None if ref_free else prompt,
799
+ bert,
800
+ # prompt_phone_len=ph_offset,
801
+ top_k=top_k,
802
+ top_p=top_p,
803
+ temperature=temperature,
804
+ early_stop_num=hz * max_sec,
805
+ )
806
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
807
+ cache[i_text] = pred_semantic
808
+ t3 = ttime()
809
+ ###v3不存在以下逻辑和inp_refs
810
+ if model_version not in v3v4set:
811
+ refers = []
812
+ if inp_refs:
813
+ for path in inp_refs:
814
+ try:
815
+ refer = get_spepc(hps, path.name).to(dtype).to(device)
816
+ refers.append(refer)
817
+ except:
818
+ traceback.print_exc()
819
+ if len(refers) == 0:
820
+ refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
821
+ audio = vq_model.decode(
822
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
823
+ )[0][0] # .cpu().detach().numpy()
824
+ else:
825
+ refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
826
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
827
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
828
+ # print(11111111, phoneme_ids0, phoneme_ids1)
829
+ fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
830
+ ref_audio, sr = torchaudio.load(ref_wav_path)
831
+ ref_audio = ref_audio.to(device).float()
832
+ if ref_audio.shape[0] == 2:
833
+ ref_audio = ref_audio.mean(0).unsqueeze(0)
834
+ tgt_sr = 24000 if model_version == "v3" else 32000
835
+ if sr != tgt_sr:
836
+ ref_audio = resample(ref_audio, sr, tgt_sr)
837
+ # print("ref_audio",ref_audio.abs().mean())
838
+ mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
839
+ mel2 = norm_spec(mel2)
840
+ T_min = min(mel2.shape[2], fea_ref.shape[2])
841
+ mel2 = mel2[:, :, :T_min]
842
+ fea_ref = fea_ref[:, :, :T_min]
843
+ Tref = 468 if model_version == "v3" else 500
844
+ Tchunk = 934 if model_version == "v3" else 1000
845
+ if T_min > Tref:
846
+ mel2 = mel2[:, :, -Tref:]
847
+ fea_ref = fea_ref[:, :, -Tref:]
848
+ T_min = Tref
849
+ chunk_len = Tchunk - T_min
850
+ mel2 = mel2.to(dtype)
851
+ fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
852
+ cfm_resss = []
853
+ idx = 0
854
+ while 1:
855
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
856
+ if fea_todo_chunk.shape[-1] == 0:
857
+ break
858
+ idx += chunk_len
859
+ fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
860
+ cfm_res = vq_model.cfm.inference(
861
+ fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
862
+ )
863
+ cfm_res = cfm_res[:, :, mel2.shape[2] :]
864
+ mel2 = cfm_res[:, :, -T_min:]
865
+ fea_ref = fea_todo_chunk[:, :, -T_min:]
866
+ cfm_resss.append(cfm_res)
867
+ cfm_res = torch.cat(cfm_resss, 2)
868
+ cfm_res = denorm_spec(cfm_res)
869
+ if model_version=="v3":
870
+ if bigvgan_model == None:
871
+ init_bigvgan()
872
+ else:#v4
873
+ if hifigan_model == None:
874
+ init_hifigan()
875
+ vocoder_model = bigvgan_model if model_version == "v3" else hifigan_model
876
+ with torch.inference_mode():
877
+ wav_gen = vocoder_model(cfm_res)
878
+ audio = wav_gen[0][0] # .cpu().detach().numpy()
879
+ max_audio = torch.abs(audio).max() # 简单防止16bit爆音
880
+ if max_audio > 1:
881
+ audio = audio / max_audio
882
+ audio_opt.append(audio)
883
+ audio_opt.append(zero_wav_torch) # zero_wav
884
+ t4 = ttime()
885
+ t.extend([t2 - t1, t3 - t2, t4 - t3])
886
+ t1 = ttime()
887
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
888
+ audio_opt = torch.cat(audio_opt, 0) # np.concatenate
889
+ if model_version in {"v1", "v2"}: opt_sr = 32000
890
+ elif model_version == "v3": opt_sr = 24000
891
+ else: opt_sr = 48000  # v4
892
+ if if_sr == True and opt_sr == 24000:
893
+ print(i18n("音频超分中"))
894
+ audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
895
+ max_audio = np.abs(audio_opt).max()
896
+ if max_audio > 1:
897
+ audio_opt /= max_audio
898
+ else:
899
+ audio_opt = audio_opt.cpu().detach().numpy()
900
+ yield opt_sr, (audio_opt * 32767).astype(np.int16)
901
+
902
+
903
+ def split(todo_text):
904
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
905
+ if todo_text[-1] not in splits:
906
+ todo_text += "。"
907
+ i_split_head = i_split_tail = 0
908
+ len_text = len(todo_text)
909
+ todo_texts = []
910
+ while 1:
911
+ if i_split_head >= len_text:
912
+ break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
913
+ if todo_text[i_split_head] in splits:
914
+ i_split_head += 1
915
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
916
+ i_split_tail = i_split_head
917
+ else:
918
+ i_split_head += 1
919
+ return todo_texts
920
+
921
+
922
+ def cut1(inp):
923
+ inp = inp.strip("\n")
924
+ inps = split(inp)
925
+ split_idx = list(range(0, len(inps), 4))
926
+ split_idx[-1] = None
927
+ if len(split_idx) > 1:
928
+ opts = []
929
+ for idx in range(len(split_idx) - 1):
930
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
931
+ else:
932
+ opts = [inp]
933
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
934
+ return "\n".join(opts)
935
+
936
+
937
+ def cut2(inp):
938
+ inp = inp.strip("\n")
939
+ inps = split(inp)
940
+ if len(inps) < 2:
941
+ return inp
942
+ opts = []
943
+ summ = 0
944
+ tmp_str = ""
945
+ for i in range(len(inps)):
946
+ summ += len(inps[i])
947
+ tmp_str += inps[i]
948
+ if summ > 50:
949
+ summ = 0
950
+ opts.append(tmp_str)
951
+ tmp_str = ""
952
+ if tmp_str != "":
953
+ opts.append(tmp_str)
954
+ # print(opts)
955
+ if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
956
+ opts[-2] = opts[-2] + opts[-1]
957
+ opts = opts[:-1]
958
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
959
+ return "\n".join(opts)
960
+
961
+
962
+ def cut3(inp):
963
+ inp = inp.strip("\n")
964
+ opts = ["%s" % item for item in inp.strip("。").split("。")]
965
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
966
+ return "\n".join(opts)
967
+
968
+
969
+ def cut4(inp):
970
+ inp = inp.strip("\n")
971
+ opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
972
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
973
+ return "\n".join(opts)
974
+
975
+
976
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
977
+ def cut5(inp):
978
+ inp = inp.strip("\n")
979
+ punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
980
+ mergeitems = []
981
+ items = []
982
+
983
+ for i, char in enumerate(inp):
984
+ if char in punds:
985
+ if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
986
+ items.append(char)
987
+ else:
988
+ items.append(char)
989
+ mergeitems.append("".join(items))
990
+ items = []
991
+ else:
992
+ items.append(char)
993
+
994
+ if items:
995
+ mergeitems.append("".join(items))
996
+
997
+ opt = [item for item in mergeitems if not set(item).issubset(punds)]
998
+ return "\n".join(opt)
999
+
1000
+
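A minimal, standalone sketch of the punctuation split that cut5 performs, handy for checking its behavior outside the WebUI; the helper name and sample sentence are illustrative and not part of the committed file (cut5 itself additionally joins the pieces with newlines).

punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}

def split_by_punct(text):
    pieces, buf = [], []
    for i, ch in enumerate(text):
        buf.append(ch)
        # keep a decimal point such as "3.5" inside one piece, exactly as cut5 does
        is_decimal_dot = ch == "." and 0 < i < len(text) - 1 and text[i - 1].isdigit() and text[i + 1].isdigit()
        if ch in punds and not is_decimal_dot:
            pieces.append("".join(buf))
            buf = []
    if buf:
        pieces.append("".join(buf))
    return [p for p in pieces if not set(p).issubset(punds)]

print(split_by_punct("今天气温3.5度。明天呢?不知道!"))
# -> ['今天气温3.5度。', '明天呢?', '不知道!']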
1001
+ def custom_sort_key(s):
1002
+ # split the string into digit and non-digit parts with a regex
1003
+ parts = re.split(r"(\d+)", s)  # raw string avoids the invalid-escape warning
1004
+ # convert the digit parts to integers; keep the rest as strings
1005
+ parts = [int(part) if part.isdigit() else part for part in parts]
1006
+ return parts
1007
+
1008
+
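A quick, self-contained illustration of the natural-sort key used for the model dropdowns; the file names are hypothetical.

import re

def natural_key(s):  # same idea as custom_sort_key above
    return [int(p) if p.isdigit() else p for p in re.split(r"(\d+)", s)]

names = ["spk_e10.pth", "spk_e2.pth", "spk_e1.pth"]
print(sorted(names))                   # plain sort: ['spk_e1.pth', 'spk_e10.pth', 'spk_e2.pth']
print(sorted(names, key=natural_key))  # natural sort: ['spk_e1.pth', 'spk_e2.pth', 'spk_e10.pth']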
1009
+ def process_text(texts):
1010
+ _text = []
1011
+ if all(text in [None, " ", "\n", ""] for text in texts):
1012
+ raise ValueError(i18n("请输入有效文本"))
1013
+ for text in texts:
1014
+ if text in [None, " ", ""]:
1015
+ pass
1016
+ else:
1017
+ _text.append(text)
1018
+ return _text
1019
+
1020
+
1021
+ def change_choices():
1022
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
1023
+ return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
1024
+ "choices": sorted(GPT_names, key=custom_sort_key),
1025
+ "__type__": "update",
1026
+ }
1027
+
1028
+
1029
+ SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
1030
+ GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
1031
+ for path in SoVITS_weight_root + GPT_weight_root:
1032
+ os.makedirs(path, exist_ok=True)
1033
+
1034
+
1035
+ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
1036
+ SoVITS_names = [i for i in pretrained_sovits_name]
1037
+ for path in SoVITS_weight_root:
1038
+ for name in os.listdir(path):
1039
+ if name.endswith(".pth"):
1040
+ SoVITS_names.append("%s/%s" % (path, name))
1041
+ GPT_names = [i for i in pretrained_gpt_name]
1042
+ for path in GPT_weight_root:
1043
+ for name in os.listdir(path):
1044
+ if name.endswith(".ckpt"):
1045
+ GPT_names.append("%s/%s" % (path, name))
1046
+ return SoVITS_names, GPT_names
1047
+
1048
+
1049
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
1050
+
1051
+
1052
+ def html_center(text, label="p"):
1053
+ return f"""<div style="text-align: center; margin: 100; padding: 50;">
1054
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
1055
+ </div>"""
1056
+
1057
+
1058
+ def html_left(text, label="p"):
1059
+ return f"""<div style="text-align: left; margin: 0; padding: 0;">
1060
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
1061
+ </div>"""
1062
+
1063
+
1064
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
1065
+ gr.Markdown(
1066
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
1067
+ + "<br>"
1068
+ + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
1069
+ )
1070
+ with gr.Group():
1071
+ gr.Markdown(html_center(i18n("模型切换"), "h3"))
1072
+ with gr.Row():
1073
+ GPT_dropdown = gr.Dropdown(
1074
+ label=i18n("GPT模型列表"),
1075
+ choices=sorted(GPT_names, key=custom_sort_key),
1076
+ value=gpt_path,
1077
+ interactive=True,
1078
+ scale=14,
1079
+ )
1080
+ SoVITS_dropdown = gr.Dropdown(
1081
+ label=i18n("SoVITS模型列表"),
1082
+ choices=sorted(SoVITS_names, key=custom_sort_key),
1083
+ value=sovits_path,
1084
+ interactive=True,
1085
+ scale=14,
1086
+ )
1087
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14)
1088
+ refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
1089
+ gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
1090
+ with gr.Row():
1091
+ inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath", scale=13)
1092
+ with gr.Column(scale=13):
1093
+ ref_text_free = gr.Checkbox(
1094
+ label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")
1095
+ + i18n("v3暂不支持该模式,使用了会报错。"),
1096
+ value=False,
1097
+ interactive=True if model_version not in v3v4set else False,
1098
+ show_label=True,
1099
+ scale=1,
1100
+ )
1101
+ gr.Markdown(
1102
+ html_left(
1103
+ i18n("使用无参考文本模式时建议使用微调的GPT")
1104
+ + "<br>"
1105
+ + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
1106
+ )
1107
+ )
1108
+ prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=5, max_lines=5, scale=1)
1109
+ with gr.Column(scale=14):
1110
+ prompt_language = gr.Dropdown(
1111
+ label=i18n("参考音频的语种"),
1112
+ choices=list(dict_language.keys()),
1113
+ value=i18n("中文"),
1114
+ )
1115
+ inp_refs = (
1116
+ gr.File(
1117
+ label=i18n(
1118
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
1119
+ ),
1120
+ file_count="multiple",
1121
+ )
1122
+ if model_version not in v3v4set
1123
+ else gr.File(
1124
+ label=i18n(
1125
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
1126
+ ),
1127
+ file_count="multiple",
1128
+ visible=False,
1129
+ )
1130
+ )
1131
+ sample_steps = (
1132
+ gr.Radio(
1133
+ label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
1134
+ value=32 if model_version == "v3" else 8,
1135
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
1136
+ visible=True,
1137
+ )
1138
+ if model_version in v3v4set
1139
+ else gr.Radio(
1140
+ label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
1141
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
1142
+ visible=False,
1143
+ value=32 if model_version == "v3" else 8,
1144
+ )
1145
+ )
1146
+ if_sr_Checkbox = gr.Checkbox(
1147
+ label=i18n("v3输出如果觉得闷可以试试开超分"),
1148
+ value=False,
1149
+ interactive=True,
1150
+ show_label=True,
1151
+ visible=False if model_version != "v3" else True,
1152
+ )
1153
+ gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
1154
+ with gr.Row():
1155
+ with gr.Column(scale=13):
1156
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
1157
+ with gr.Column(scale=7):
1158
+ text_language = gr.Dropdown(
1159
+ label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
1160
+ choices=list(dict_language.keys()),
1161
+ value=i18n("中文"),
1162
+ scale=1,
1163
+ )
1164
+ how_to_cut = gr.Dropdown(
1165
+ label=i18n("怎么切"),
1166
+ choices=[
1167
+ i18n("不切"),
1168
+ i18n("凑四句一切"),
1169
+ i18n("凑50字一切"),
1170
+ i18n("按中文句号。切"),
1171
+ i18n("按英文句号.切"),
1172
+ i18n("按标点符号切"),
1173
+ ],
1174
+ value=i18n("凑四句一切"),
1175
+ interactive=True,
1176
+ scale=1,
1177
+ )
1178
+ gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
1179
+ if_freeze = gr.Checkbox(
1180
+ label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"),
1181
+ value=False,
1182
+ interactive=True,
1183
+ show_label=True,
1184
+ scale=1,
1185
+ )
1186
+ with gr.Row():
1187
+ speed = gr.Slider(
1188
+ minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True, scale=1
1189
+ )
1190
+ pause_second_slider = gr.Slider(
1191
+ minimum=0.1,
1192
+ maximum=0.5,
1193
+ step=0.01,
1194
+ label=i18n("句间停顿秒数"),
1195
+ value=0.3,
1196
+ interactive=True,
1197
+ scale=1,
1198
+ )
1199
+ gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
1200
+ top_k = gr.Slider(
1201
+ minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True, scale=1
1202
+ )
1203
+ top_p = gr.Slider(
1204
+ minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True, scale=1
1205
+ )
1206
+ temperature = gr.Slider(
1207
+ minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True, scale=1
1208
+ )
1209
+ # with gr.Column():
1210
+ # gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。"))
1211
+ # phoneme=gr.Textbox(label=i18n("音素框"), value="")
1212
+ # get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
1213
+ with gr.Row():
1214
+ inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
1215
+ output = gr.Audio(label=i18n("输出的语音"), scale=14)
1216
+
1217
+ inference_button.click(
1218
+ get_tts_wav,
1219
+ [
1220
+ inp_ref,
1221
+ prompt_text,
1222
+ prompt_language,
1223
+ text,
1224
+ text_language,
1225
+ how_to_cut,
1226
+ top_k,
1227
+ top_p,
1228
+ temperature,
1229
+ ref_text_free,
1230
+ speed,
1231
+ if_freeze,
1232
+ inp_refs,
1233
+ sample_steps,
1234
+ if_sr_Checkbox,
1235
+ pause_second_slider,
1236
+ ],
1237
+ [output],
1238
+ )
1239
+ SoVITS_dropdown.change(
1240
+ change_sovits_weights,
1241
+ [SoVITS_dropdown, prompt_language, text_language],
1242
+ [
1243
+ prompt_language,
1244
+ text_language,
1245
+ prompt_text,
1246
+ prompt_language,
1247
+ text,
1248
+ text_language,
1249
+ sample_steps,
1250
+ inp_refs,
1251
+ ref_text_free,
1252
+ if_sr_Checkbox,
1253
+ inference_button,
1254
+ ],
1255
+ )
1256
+ GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
1257
+
1258
+ # gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
1259
+ # with gr.Row():
1260
+ # text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
1261
+ # button1 = gr.Button(i18n("凑四句一切"), variant="primary")
1262
+ # button2 = gr.Button(i18n("凑50字一切"), variant="primary")
1263
+ # button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
1264
+ # button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
1265
+ # button5 = gr.Button(i18n("按标点符号切"), variant="primary")
1266
+ # text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
1267
+ # button1.click(cut1, [text_inp], [text_opt])
1268
+ # button2.click(cut2, [text_inp], [text_opt])
1269
+ # button3.click(cut3, [text_inp], [text_opt])
1270
+ # button4.click(cut4, [text_inp], [text_opt])
1271
+ # button5.click(cut5, [text_inp], [text_opt])
1272
+ # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")))
1273
+
1274
+ if __name__ == "__main__":
1275
+ app.queue().launch( # concurrency_count=511, max_size=1022
1276
+ server_name="0.0.0.0",
1277
+ inbrowser=True,
1278
+ share=is_share,
1279
+ server_port=infer_ttswebui,
1280
+ # quiet=True,
1281
+ )
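get_tts_wav yields Gradio-style (sample_rate, int16 ndarray) tuples; below is a hedged sketch of writing one such tuple to disk with SciPy, where the silent array stands in for a real synthesis result and the output path is arbitrary.

import numpy as np
from scipy.io import wavfile

opt_sr, pcm = 32000, np.zeros(32000, dtype=np.int16)  # placeholder for a yielded (sr, audio) pair
wavfile.write("demo_out.wav", opt_sr, pcm)            # one second of silence at 32 kHz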
GPT_SoVITS/inference_webui_fast.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 按中英混合识别
3
+ 按日英混合识别
4
+ 多语种启动切分识别语种
5
+ 全部按中文识别
6
+ 全部按英文识别
7
+ 全部按日文识别
8
+ """
9
+
10
+ import json
11
+ import logging
12
+ import os
13
+ import random
14
+ import re
15
+ import sys
16
+
17
+ now_dir = os.getcwd()
18
+ sys.path.append(now_dir)
19
+ sys.path.append("%s/GPT_SoVITS" % (now_dir))
20
+
21
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
22
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
23
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
24
+ logging.getLogger("httpx").setLevel(logging.ERROR)
25
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
26
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
27
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
28
+ import torch
29
+
30
+ try:
31
+ import gradio.analytics as analytics
32
+
33
+ analytics.version_check = lambda: None
34
+ except:
35
+ ...
36
+
37
+
38
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
39
+ infer_ttswebui = int(infer_ttswebui)
40
+ is_share = os.environ.get("is_share", "False")
41
+ is_share = eval(is_share)
42
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
43
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
44
+
45
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
46
+ gpt_path = os.environ.get("gpt_path", None)
47
+ sovits_path = os.environ.get("sovits_path", None)
48
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
49
+ bert_path = os.environ.get("bert_path", None)
50
+ version = model_version = os.environ.get("version", "v2")
51
+
52
+ import gradio as gr
53
+ from TTS_infer_pack.text_segmentation_method import get_method
54
+ from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
55
+
56
+ from tools.i18n.i18n import I18nAuto, scan_language_list
57
+
58
+ language = os.environ.get("language", "Auto")
59
+ language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
60
+ i18n = I18nAuto(language=language)
61
+
62
+
63
+ # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
64
+
65
+ if torch.cuda.is_available():
66
+ device = "cuda"
67
+ # elif torch.backends.mps.is_available():
68
+ # device = "mps"
69
+ else:
70
+ device = "cpu"
71
+
72
+ # is_half = False
73
+ # device = "cpu"
74
+
75
+ dict_language_v1 = {
76
+ i18n("中文"): "all_zh", # 全部按中文识别
77
+ i18n("英文"): "en", # 全部按英文识别#######不变
78
+ i18n("日文"): "all_ja", # 全部按日文识别
79
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
80
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
81
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
82
+ }
83
+ dict_language_v2 = {
84
+ i18n("中文"): "all_zh", # 全部按中文识别
85
+ i18n("英文"): "en", # 全部按英文识别#######不变
86
+ i18n("日文"): "all_ja", # 全部按日文识别
87
+ i18n("粤语"): "all_yue", # 全部按中文识别
88
+ i18n("韩文"): "all_ko", # 全部按韩文识别
89
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
90
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
91
+ i18n("粤英混合"): "yue", # 按粤英混合识别####不变
92
+ i18n("韩英混合"): "ko", # 按韩英混合识别####不变
93
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
94
+ i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
95
+ }
96
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
97
+
98
+ cut_method = {
99
+ i18n("不切"): "cut0",
100
+ i18n("凑四句一切"): "cut1",
101
+ i18n("凑50字一切"): "cut2",
102
+ i18n("按中文句号。切"): "cut3",
103
+ i18n("按英文句号.切"): "cut4",
104
+ i18n("按标点符号切"): "cut5",
105
+ }
106
+
107
+ tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
108
+ tts_config.device = device
109
+ tts_config.is_half = is_half
110
+ tts_config.version = version
111
+ if gpt_path is not None:
112
+ tts_config.t2s_weights_path = gpt_path
113
+ if sovits_path is not None:
114
+ tts_config.vits_weights_path = sovits_path
115
+ if cnhubert_base_path is not None:
116
+ tts_config.cnhuhbert_base_path = cnhubert_base_path
117
+ if bert_path is not None:
118
+ tts_config.bert_base_path = bert_path
119
+
120
+ print(tts_config)
121
+ tts_pipeline = TTS(tts_config)
122
+ gpt_path = tts_config.t2s_weights_path
123
+ sovits_path = tts_config.vits_weights_path
124
+ version = tts_config.version
125
+
126
+
127
+ def inference(
128
+ text,
129
+ text_lang,
130
+ ref_audio_path,
131
+ aux_ref_audio_paths,
132
+ prompt_text,
133
+ prompt_lang,
134
+ top_k,
135
+ top_p,
136
+ temperature,
137
+ text_split_method,
138
+ batch_size,
139
+ speed_factor,
140
+ ref_text_free,
141
+ split_bucket,
142
+ fragment_interval,
143
+ seed,
144
+ keep_random,
145
+ parallel_infer,
146
+ repetition_penalty,
147
+ sample_steps,
148
+ super_sampling,
149
+ ):
150
+ seed = -1 if keep_random else seed
151
+ actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
152
+ inputs = {
153
+ "text": text,
154
+ "text_lang": dict_language[text_lang],
155
+ "ref_audio_path": ref_audio_path,
156
+ "aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [],
157
+ "prompt_text": prompt_text if not ref_text_free else "",
158
+ "prompt_lang": dict_language[prompt_lang],
159
+ "top_k": top_k,
160
+ "top_p": top_p,
161
+ "temperature": temperature,
162
+ "text_split_method": cut_method[text_split_method],
163
+ "batch_size": int(batch_size),
164
+ "speed_factor": float(speed_factor),
165
+ "split_bucket": split_bucket,
166
+ "return_fragment": False,
167
+ "fragment_interval": fragment_interval,
168
+ "seed": actual_seed,
169
+ "parallel_infer": parallel_infer,
170
+ "repetition_penalty": repetition_penalty,
171
+ "sample_steps": int(sample_steps),
172
+ "super_sampling": super_sampling,
173
+ }
174
+ try:
175
+ for item in tts_pipeline.run(inputs):
176
+ yield item, actual_seed
177
+ except NO_PROMPT_ERROR:
178
+ gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
179
+
180
+
181
+ def custom_sort_key(s):
182
+ # 使用正则表达式提取字符串中的数字部分和非数字部分
183
+ parts = re.split("(\d+)", s)
184
+ # 将数字部分转换为整数,非数字部分保持不变
185
+ parts = [int(part) if part.isdigit() else part for part in parts]
186
+ return parts
187
+
188
+
189
+ def change_choices():
190
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
191
+ return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
192
+ "choices": sorted(GPT_names, key=custom_sort_key),
193
+ "__type__": "update",
194
+ }
195
+
196
+
197
+ path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
198
+ path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
199
+ is_exist_s2gv3 = os.path.exists(path_sovits_v3)
200
+ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
201
+ pretrained_sovits_name = [
202
+ "GPT_SoVITS/pretrained_models/s2G488k.pth",
203
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
204
+ "GPT_SoVITS/pretrained_models/s2Gv3.pth",
205
+ "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
206
+ ]
207
+ pretrained_gpt_name = [
208
+ "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
209
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
210
+ "GPT_SoVITS/pretrained_models/s1v3.ckpt",
211
+ "GPT_SoVITS/pretrained_models/s1v3.ckpt",
212
+ ]
213
+
214
+
215
+ _ = [[], []]
216
+ for i in range(4):
217
+ if os.path.exists(pretrained_gpt_name[i]):
218
+ _[0].append(pretrained_gpt_name[i])
219
+ if os.path.exists(pretrained_sovits_name[i]):
220
+ _[-1].append(pretrained_sovits_name[i])
221
+ pretrained_gpt_name, pretrained_sovits_name = _
222
+
223
+ if os.path.exists("./weight.json"):
224
+ pass
225
+ else:
226
+ with open("./weight.json", "w", encoding="utf-8") as file:
227
+ json.dump({"GPT": {}, "SoVITS": {}}, file)
228
+
229
+ with open("./weight.json", "r", encoding="utf-8") as file:
230
+ weight_data = file.read()
231
+ weight_data = json.loads(weight_data)
232
+ gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
233
+ sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
234
+ if isinstance(gpt_path, list):
235
+ gpt_path = gpt_path[0]
236
+ if isinstance(sovits_path, list):
237
+ sovits_path = sovits_path[0]
238
+
239
+
240
+ SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
241
+ GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
242
+ for path in SoVITS_weight_root + GPT_weight_root:
243
+ os.makedirs(path, exist_ok=True)
244
+
245
+
246
+ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
247
+ SoVITS_names = [i for i in pretrained_sovits_name]
248
+ for path in SoVITS_weight_root:
249
+ for name in os.listdir(path):
250
+ if name.endswith(".pth"):
251
+ SoVITS_names.append("%s/%s" % (path, name))
252
+ GPT_names = [i for i in pretrained_gpt_name]
253
+ for path in GPT_weight_root:
254
+ for name in os.listdir(path):
255
+ if name.endswith(".ckpt"):
256
+ GPT_names.append("%s/%s" % (path, name))
257
+ return SoVITS_names, GPT_names
258
+
259
+
260
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
261
+
262
+
263
+ from process_ckpt import get_sovits_version_from_path_fast
264
+
265
+ v3v4set={"v3","v4"}
266
+ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
267
+ global version, model_version, dict_language, if_lora_v3
268
+ version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
269
+ # print(sovits_path,version, model_version, if_lora_v3)
270
+ is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
271
+ path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
272
+ if if_lora_v3 == True and is_exist == False:
273
+ info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
274
+ gr.Warning(info)
275
+ raise FileExistsError(info)
276
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
277
+ if prompt_language is not None and text_language is not None:
278
+ if prompt_language in list(dict_language.keys()):
279
+ prompt_text_update, prompt_language_update = (
280
+ {"__type__": "update"},
281
+ {"__type__": "update", "value": prompt_language},
282
+ )
283
+ else:
284
+ prompt_text_update = {"__type__": "update", "value": ""}
285
+ prompt_language_update = {"__type__": "update", "value": i18n("中文")}
286
+ if text_language in list(dict_language.keys()):
287
+ text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
288
+ else:
289
+ text_update = {"__type__": "update", "value": ""}
290
+ text_language_update = {"__type__": "update", "value": i18n("中文")}
291
+ if model_version in v3v4set:
292
+ visible_sample_steps = True
293
+ visible_inp_refs = False
294
+ else:
295
+ visible_sample_steps = False
296
+ visible_inp_refs = True
297
+ yield (
298
+ {"__type__": "update", "choices": list(dict_language.keys())},
299
+ {"__type__": "update", "choices": list(dict_language.keys())},
300
+ prompt_text_update,
301
+ prompt_language_update,
302
+ text_update,
303
+ text_language_update,
304
+ {"__type__": "update", "interactive": visible_sample_steps, "value": 32},
305
+ {"__type__": "update", "visible": visible_inp_refs},
306
+ {"__type__": "update", "interactive": True if model_version not in v3v4set else False},
307
+ {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
308
+ )
309
+
310
+ tts_pipeline.init_vits_weights(sovits_path)
311
+ yield (
312
+ {"__type__": "update", "choices": list(dict_language.keys())},
313
+ {"__type__": "update", "choices": list(dict_language.keys())},
314
+ prompt_text_update,
315
+ prompt_language_update,
316
+ text_update,
317
+ text_language_update,
318
+ {"__type__": "update", "interactive": visible_sample_steps, "value": 32},
319
+ {"__type__": "update", "visible": visible_inp_refs},
320
+ {"__type__": "update", "interactive": True if model_version not in v3v4set else False},
321
+ {"__type__": "update", "value": i18n("合成语音"), "interactive": True},
322
+ )
323
+ with open("./weight.json") as f:
324
+ data = f.read()
325
+ data = json.loads(data)
326
+ data["SoVITS"][version] = sovits_path
327
+ with open("./weight.json", "w") as f:
328
+ f.write(json.dumps(data))
329
+
330
+
331
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
332
+ gr.Markdown(
333
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
334
+ + "<br>"
335
+ + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
336
+ )
337
+
338
+ with gr.Column():
339
+ # with gr.Group():
340
+ gr.Markdown(value=i18n("模型切换"))
341
+ with gr.Row():
342
+ GPT_dropdown = gr.Dropdown(
343
+ label=i18n("GPT模型列表"),
344
+ choices=sorted(GPT_names, key=custom_sort_key),
345
+ value=gpt_path,
346
+ interactive=True,
347
+ )
348
+ SoVITS_dropdown = gr.Dropdown(
349
+ label=i18n("SoVITS模型列表"),
350
+ choices=sorted(SoVITS_names, key=custom_sort_key),
351
+ value=sovits_path,
352
+ interactive=True,
353
+ )
354
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
355
+ refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
356
+
357
+ with gr.Row():
358
+ with gr.Column():
359
+ gr.Markdown(value=i18n("*请上传并填写参考信息"))
360
+ with gr.Row():
361
+ inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
362
+ inp_refs = gr.File(
363
+ label=i18n("辅参考音频(可选多个,或不选)"),
364
+ file_count="multiple",
365
+ visible=True if model_version != "v3" else False,
366
+ )
367
+ prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
368
+ with gr.Row():
369
+ prompt_language = gr.Dropdown(
370
+ label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
371
+ )
372
+ with gr.Column():
373
+ ref_text_free = gr.Checkbox(
374
+ label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
375
+ value=False,
376
+ interactive=True if model_version != "v3" else False,
377
+ show_label=True,
378
+ )
379
+ gr.Markdown(
380
+ i18n("使用无参考文本模式时建议使用微调的GPT")
381
+ + "<br>"
382
+ + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
383
+ )
384
+
385
+ with gr.Column():
386
+ gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
387
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
388
+ text_language = gr.Dropdown(
389
+ label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
390
+ )
391
+
392
+ with gr.Group():
393
+ gr.Markdown(value=i18n("推理设置"))
394
+ with gr.Row():
395
+ with gr.Column():
396
+ with gr.Row():
397
+ batch_size = gr.Slider(
398
+ minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
399
+ )
400
+ sample_steps = gr.Radio(
401
+ label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
402
+ )
403
+ with gr.Row():
404
+ fragment_interval = gr.Slider(
405
+ minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
406
+ )
407
+ speed_factor = gr.Slider(
408
+ minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
409
+ )
410
+ with gr.Row():
411
+ top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
412
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
413
+ with gr.Row():
414
+ temperature = gr.Slider(
415
+ minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
416
+ )
417
+ repetition_penalty = gr.Slider(
418
+ minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
419
+ )
420
+
421
+ with gr.Column():
422
+ with gr.Row():
423
+ how_to_cut = gr.Dropdown(
424
+ label=i18n("怎么切"),
425
+ choices=[
426
+ i18n("不切"),
427
+ i18n("凑四句一切"),
428
+ i18n("凑50字一切"),
429
+ i18n("按中文句号。切"),
430
+ i18n("按英文句号.切"),
431
+ i18n("按标点符号切"),
432
+ ],
433
+ value=i18n("凑四句一切"),
434
+ interactive=True,
435
+ scale=1,
436
+ )
437
+ super_sampling = gr.Checkbox(
438
+ label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
439
+ )
440
+
441
+ with gr.Row():
442
+ parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
443
+ split_bucket = gr.Checkbox(
444
+ label=i18n("数据分桶(并行推理时会降低一点计算量)"),
445
+ value=True,
446
+ interactive=True,
447
+ show_label=True,
448
+ )
449
+
450
+ with gr.Row():
451
+ seed = gr.Number(label=i18n("随机种子"), value=-1)
452
+ keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
453
+
454
+ output = gr.Audio(label=i18n("输出的语音"))
455
+ with gr.Row():
456
+ inference_button = gr.Button(i18n("合成语音"), variant="primary")
457
+ stop_infer = gr.Button(i18n("终止合成"), variant="primary")
458
+
459
+ inference_button.click(
460
+ inference,
461
+ [
462
+ text,
463
+ text_language,
464
+ inp_ref,
465
+ inp_refs,
466
+ prompt_text,
467
+ prompt_language,
468
+ top_k,
469
+ top_p,
470
+ temperature,
471
+ how_to_cut,
472
+ batch_size,
473
+ speed_factor,
474
+ ref_text_free,
475
+ split_bucket,
476
+ fragment_interval,
477
+ seed,
478
+ keep_random,
479
+ parallel_infer,
480
+ repetition_penalty,
481
+ sample_steps,
482
+ super_sampling,
483
+ ],
484
+ [output, seed],
485
+ )
486
+ stop_infer.click(tts_pipeline.stop, [], [])
487
+ SoVITS_dropdown.change(
488
+ change_sovits_weights,
489
+ [SoVITS_dropdown, prompt_language, text_language],
490
+ [
491
+ prompt_language,
492
+ text_language,
493
+ prompt_text,
494
+ prompt_language,
495
+ text,
496
+ text_language,
497
+ sample_steps,
498
+ inp_refs,
499
+ ref_text_free,
500
+ inference_button,
501
+ ],
502
+ ) #
503
+ GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
504
+
505
+ with gr.Group():
506
+ gr.Markdown(
507
+ value=i18n(
508
+ "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
509
+ )
510
+ )
511
+ with gr.Row():
512
+ text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
513
+ with gr.Column():
514
+ _how_to_cut = gr.Radio(
515
+ label=i18n("怎么切"),
516
+ choices=[
517
+ i18n("不切"),
518
+ i18n("凑四句一切"),
519
+ i18n("凑50字一切"),
520
+ i18n("按中文句号。切"),
521
+ i18n("按英文句号.切"),
522
+ i18n("按标点符号切"),
523
+ ],
524
+ value=i18n("凑四句一切"),
525
+ interactive=True,
526
+ )
527
+ cut_text = gr.Button(i18n("切分"), variant="primary")
528
+
529
+ def to_cut(text_inp, how_to_cut):
530
+ if len(text_inp.strip()) == 0 or text_inp == []:
531
+ return ""
532
+ method = get_method(cut_method[how_to_cut])
533
+ return method(text_inp)
534
+
535
+ text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
536
+ cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
537
+ gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
538
+
539
+ if __name__ == "__main__":
540
+ app.queue().launch( # concurrency_count=511, max_size=1022
541
+ server_name="0.0.0.0",
542
+ inbrowser=True,
543
+ share=is_share,
544
+ server_port=infer_ttswebui,
545
+ # quiet=True,
546
+ )
GPT_SoVITS/onnx_export.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
4
+ from feature_extractor import cnhubert
5
+ from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
6
+ from torch import nn
7
+
8
+ cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
9
+ cnhubert.cnhubert_base_path = cnhubert_base_path
10
+ ssl_model = cnhubert.get_model()
11
+ import json
12
+ import os
13
+
14
+ import soundfile
15
+ from text import cleaned_text_to_sequence
16
+
17
+
18
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
19
+ hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
20
+ y = torch.nn.functional.pad(
21
+ y.unsqueeze(1),
22
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
23
+ mode="reflect",
24
+ )
25
+ y = y.squeeze(1)
26
+ spec = torch.stft(
27
+ y,
28
+ n_fft,
29
+ hop_length=hop_size,
30
+ win_length=win_size,
31
+ window=hann_window,
32
+ center=center,
33
+ pad_mode="reflect",
34
+ normalized=False,
35
+ onesided=True,
36
+ return_complex=False,
37
+ )
38
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
39
+ return spec
40
+
41
+
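A quick shape check for the linear-spectrogram helper above; the STFT parameters are typical 32 kHz values used here only as assumptions, not read from this commit.

import torch

y = torch.randn(1, 32000)  # one second of 32 kHz audio
spec = spectrogram_torch(y, n_fft=2048, sampling_rate=32000, hop_size=640, win_size=2048)
print(spec.shape)          # torch.Size([1, 1025, 50]) -> (batch, n_fft // 2 + 1, frames)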
42
+ class DictToAttrRecursive(dict):
43
+ def __init__(self, input_dict):
44
+ super().__init__(input_dict)
45
+ for key, value in input_dict.items():
46
+ if isinstance(value, dict):
47
+ value = DictToAttrRecursive(value)
48
+ self[key] = value
49
+ setattr(self, key, value)
50
+
51
+ def __getattr__(self, item):
52
+ try:
53
+ return self[item]
54
+ except KeyError:
55
+ raise AttributeError(f"Attribute {item} not found")
56
+
57
+ def __setattr__(self, key, value):
58
+ if isinstance(value, dict):
59
+ value = DictToAttrRecursive(value)
60
+ super(DictToAttrRecursive, self).__setitem__(key, value)
61
+ super().__setattr__(key, value)
62
+
63
+ def __delattr__(self, item):
64
+ try:
65
+ del self[item]
66
+ except KeyError:
67
+ raise AttributeError(f"Attribute {item} not found")
68
+
69
+
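Minimal usage of DictToAttrRecursive with a made-up config dict: nested dicts become attribute-accessible while plain dict indexing keeps working.

cfg = DictToAttrRecursive({"data": {"sampling_rate": 32000}, "model": {"version": "v2"}})
print(cfg.data.sampling_rate)   # 32000, attribute-style access into nested dicts
print(cfg["model"]["version"])  # "v2", plain dict access still works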
70
+ class T2SEncoder(nn.Module):
71
+ def __init__(self, t2s, vits):
72
+ super().__init__()
73
+ self.encoder = t2s.onnx_encoder
74
+ self.vits = vits
75
+
76
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
77
+ codes = self.vits.extract_latent(ssl_content)
78
+ prompt_semantic = codes[0, 0]
79
+ bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
80
+ all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
81
+ bert = bert.unsqueeze(0)
82
+ prompt = prompt_semantic.unsqueeze(0)
83
+ return self.encoder(all_phoneme_ids, bert), prompt
84
+
85
+
86
+ class T2SModel(nn.Module):
87
+ def __init__(self, t2s_path, vits_model):
88
+ super().__init__()
89
+ dict_s1 = torch.load(t2s_path, map_location="cpu")
90
+ self.config = dict_s1["config"]
91
+ self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
92
+ self.t2s_model.load_state_dict(dict_s1["weight"])
93
+ self.t2s_model.eval()
94
+ self.vits_model = vits_model.vq_model
95
+ self.hz = 50
96
+ self.max_sec = self.config["data"]["max_sec"]
97
+ self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
98
+ self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
99
+ self.t2s_model = self.t2s_model.model
100
+ self.t2s_model.init_onnx()
101
+ self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
102
+ self.first_stage_decoder = self.t2s_model.first_stage_decoder
103
+ self.stage_decoder = self.t2s_model.stage_decoder
104
+ # self.t2s_model = torch.jit.script(self.t2s_model)
105
+
106
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
107
+ early_stop_num = self.t2s_model.early_stop_num
108
+
109
+ # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
110
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
111
+
112
+ prefix_len = prompts.shape[1]
113
+
114
+ # [1,N,512] [1,N]
115
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
116
+
117
+ stop = False
118
+ for idx in range(1, 1500):
119
+ # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
120
+ enco = self.stage_decoder(y, k, v, y_emb, x_example)
121
+ y, k, v, y_emb, logits, samples = enco
122
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
123
+ stop = True
124
+ if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
125
+ stop = True
126
+ if stop:
127
+ break
128
+ y[0, -1] = 0
129
+
130
+ return y[:, -idx:].unsqueeze(0)
131
+
132
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
133
+ # self.onnx_encoder = torch.jit.script(self.onnx_encoder)
134
+ if dynamo:
135
+ export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
136
+ onnx_encoder_export_output = torch.onnx.dynamo_export(
137
+ self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
138
+ )
139
+ onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
140
+ return
141
+
142
+ torch.onnx.export(
143
+ self.onnx_encoder,
144
+ (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
145
+ f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
146
+ input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
147
+ output_names=["x", "prompts"],
148
+ dynamic_axes={
149
+ "ref_seq": {1: "ref_length"},
150
+ "text_seq": {1: "text_length"},
151
+ "ref_bert": {0: "ref_length"},
152
+ "text_bert": {0: "text_length"},
153
+ "ssl_content": {2: "ssl_length"},
154
+ },
155
+ opset_version=16,
156
+ )
157
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
158
+
159
+ torch.onnx.export(
160
+ self.first_stage_decoder,
161
+ (x, prompts),
162
+ f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
163
+ input_names=["x", "prompts"],
164
+ output_names=["y", "k", "v", "y_emb", "x_example"],
165
+ dynamic_axes={
166
+ "x": {1: "x_length"},
167
+ "prompts": {1: "prompts_length"},
168
+ },
169
+ verbose=False,
170
+ opset_version=16,
171
+ )
172
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
173
+
174
+ torch.onnx.export(
175
+ self.stage_decoder,
176
+ (y, k, v, y_emb, x_example),
177
+ f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
178
+ input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
179
+ output_names=["y", "k", "v", "y_emb", "logits", "samples"],
180
+ dynamic_axes={
181
+ "iy": {1: "iy_length"},
182
+ "ik": {1: "ik_length"},
183
+ "iv": {1: "iv_length"},
184
+ "iy_emb": {1: "iy_emb_length"},
185
+ "ix_example": {1: "ix_example_length"},
186
+ },
187
+ verbose=False,
188
+ opset_version=16,
189
+ )
190
+
191
+
192
+ class VitsModel(nn.Module):
193
+ def __init__(self, vits_path):
194
+ super().__init__()
195
+ dict_s2 = torch.load(vits_path, map_location="cpu")
196
+ self.hps = dict_s2["config"]
197
+ if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
198
+ self.hps["model"]["version"] = "v1"
199
+ else:
200
+ self.hps["model"]["version"] = "v2"
201
+
202
+ self.hps = DictToAttrRecursive(self.hps)
203
+ self.hps.model.semantic_frame_rate = "25hz"
204
+ self.vq_model = SynthesizerTrn(
205
+ self.hps.data.filter_length // 2 + 1,
206
+ self.hps.train.segment_size // self.hps.data.hop_length,
207
+ n_speakers=self.hps.data.n_speakers,
208
+ **self.hps.model,
209
+ )
210
+ self.vq_model.eval()
211
+ self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
212
+
213
+ def forward(self, text_seq, pred_semantic, ref_audio):
214
+ refer = spectrogram_torch(
215
+ ref_audio,
216
+ self.hps.data.filter_length,
217
+ self.hps.data.sampling_rate,
218
+ self.hps.data.hop_length,
219
+ self.hps.data.win_length,
220
+ center=False,
221
+ )
222
+ return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
223
+
224
+
225
+ class GptSoVits(nn.Module):
226
+ def __init__(self, vits, t2s):
227
+ super().__init__()
228
+ self.vits = vits
229
+ self.t2s = t2s
230
+
231
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
232
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
233
+ audio = self.vits(text_seq, pred_semantic, ref_audio)
234
+ if debug:
235
+ import onnxruntime
236
+
237
+ sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
238
+ audio1 = sess.run(
239
+ None,
240
+ {
241
+ "text_seq": text_seq.detach().cpu().numpy(),
242
+ "pred_semantic": pred_semantic.detach().cpu().numpy(),
243
+ "ref_audio": ref_audio.detach().cpu().numpy(),
244
+ },
245
+ )
246
+ return audio, audio1
247
+ return audio
248
+
249
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
250
+ self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
251
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
252
+ torch.onnx.export(
253
+ self.vits,
254
+ (text_seq, pred_semantic, ref_audio),
255
+ f"onnx/{project_name}/{project_name}_vits.onnx",
256
+ input_names=["text_seq", "pred_semantic", "ref_audio"],
257
+ output_names=["audio"],
258
+ dynamic_axes={
259
+ "text_seq": {1: "text_length"},
260
+ "pred_semantic": {2: "pred_length"},
261
+ "ref_audio": {1: "audio_length"},
262
+ },
263
+ opset_version=17,
264
+ verbose=False,
265
+ )
266
+
267
+
268
+ class SSLModel(nn.Module):
269
+ def __init__(self):
270
+ super().__init__()
271
+ self.ssl = ssl_model
272
+
273
+ def forward(self, ref_audio_16k):
274
+ return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
275
+
276
+
277
+ def export(vits_path, gpt_path, project_name, vits_model="v2"):
278
+ vits = VitsModel(vits_path)
279
+ gpt = T2SModel(gpt_path, vits)
280
+ gpt_sovits = GptSoVits(vits, gpt)
281
+ ssl = SSLModel()
282
+ ref_seq = torch.LongTensor(
283
+ [
284
+ cleaned_text_to_sequence(
285
+ [
286
+ "n",
287
+ "i2",
288
+ "h",
289
+ "ao3",
290
+ ",",
291
+ "w",
292
+ "o3",
293
+ "sh",
294
+ "i4",
295
+ "b",
296
+ "ai2",
297
+ "y",
298
+ "e4",
299
+ ],
300
+ version=vits_model,
301
+ )
302
+ ]
303
+ )
304
+ text_seq = torch.LongTensor(
305
+ [
306
+ cleaned_text_to_sequence(
307
+ [
308
+ "w",
309
+ "o3",
310
+ "sh",
311
+ "i4",
312
+ "b",
313
+ "ai2",
314
+ "y",
315
+ "e4",
316
+ "w",
317
+ "o3",
318
+ "sh",
319
+ "i4",
320
+ "b",
321
+ "ai2",
322
+ "y",
323
+ "e4",
324
+ "w",
325
+ "o3",
326
+ "sh",
327
+ "i4",
328
+ "b",
329
+ "ai2",
330
+ "y",
331
+ "e4",
332
+ ],
333
+ version=vits_model,
334
+ )
335
+ ]
336
+ )
337
+ ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
338
+ text_bert = torch.randn((text_seq.shape[1], 1024)).float()
339
+ ref_audio = torch.randn((1, 48000 * 5)).float()
340
+ # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
341
+ ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
342
+ ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
343
+
344
+ try:
345
+ os.mkdir(f"onnx/{project_name}")
346
+ except:
347
+ pass
348
+
349
+ ssl_content = ssl(ref_audio_16k).float()
350
+
351
+ # debug = False
352
+ debug = True
353
+
354
+ # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
355
+
356
+ if debug:
357
+ a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
358
+ soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
359
+ soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
360
+ else:
361
+ a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
362
+ soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
363
+
364
+ if vits_model == "v1":
365
+ symbols = symbols_v1
366
+ else:
367
+ symbols = symbols_v2
368
+
369
+ MoeVSConf = {
370
+ "Folder": f"{project_name}",
371
+ "Name": f"{project_name}",
372
+ "Type": "GPT-SoVits",
373
+ "Rate": vits.hps.data.sampling_rate,
374
+ "NumLayers": gpt.t2s_model.num_layers,
375
+ "EmbeddingDim": gpt.t2s_model.embedding_dim,
376
+ "Dict": "BasicDict",
377
+ "BertPath": "chinese-roberta-wwm-ext-large",
378
+ # "Symbol": symbols,
379
+ "AddBlank": False,
380
+ }
381
+
382
+ MoeVSConfJson = json.dumps(MoeVSConf)
383
+ with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
384
+ json.dump(MoeVSConf, MoeVsConfFile, indent=4)
385
+
386
+
387
+ if __name__ == "__main__":
388
+ try:
389
+ os.mkdir("onnx")
390
+ except:
391
+ pass
392
+
393
+ gpt_path = "GPT_weights/nahida-e25.ckpt"
394
+ vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
395
+ exp_path = "nahida"
396
+ export(vits_path, gpt_path, exp_path)
397
+
398
+ # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
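A hedged post-export sanity check with onnxruntime; the path follows the f-strings in export() and assumes the "nahida" example above was actually exported.

import onnxruntime as ort

sess = ort.InferenceSession("onnx/nahida/nahida_vits.onnx", providers=["CPUExecutionProvider"])
print([inp.name for inp in sess.get_inputs()])  # expected: ['text_seq', 'pred_semantic', 'ref_audio']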
GPT_SoVITS/prepare_datasets/1-get-text.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+
5
+ inp_text = os.environ.get("inp_text")
6
+ inp_wav_dir = os.environ.get("inp_wav_dir")
7
+ exp_name = os.environ.get("exp_name")
8
+ i_part = os.environ.get("i_part")
9
+ all_parts = os.environ.get("all_parts")
10
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
11
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
12
+ opt_dir = os.environ.get("opt_dir")
13
+ bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
14
+ import torch
15
+
16
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
17
+ version = os.environ.get("version", None)
18
+ import traceback
19
+ import os.path
20
+ from text.cleaner import clean_text
21
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
22
+ from tools.my_utils import clean_path
23
+
24
+ # inp_text=sys.argv[1]
25
+ # inp_wav_dir=sys.argv[2]
26
+ # exp_name=sys.argv[3]
27
+ # i_part=sys.argv[4]
28
+ # all_parts=sys.argv[5]
29
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
30
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
31
+ # bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
32
+
33
+ from time import time as ttime
34
+ import shutil
35
+
36
+
37
+ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
38
+ dir = os.path.dirname(path)
39
+ name = os.path.basename(path)
40
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
41
+ tmp_path = "%s%s.pth" % (ttime(), i_part)
42
+ torch.save(fea, tmp_path)
43
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
44
+
45
+
46
+ txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
47
+ if os.path.exists(txt_path) == False:
48
+ bert_dir = "%s/3-bert" % (opt_dir)
49
+ os.makedirs(opt_dir, exist_ok=True)
50
+ os.makedirs(bert_dir, exist_ok=True)
51
+ if torch.cuda.is_available():
52
+ device = "cuda:0"
53
+ # elif torch.backends.mps.is_available():
54
+ # device = "mps"
55
+ else:
56
+ device = "cpu"
57
+ if os.path.exists(bert_pretrained_dir):
58
+ ...
59
+ else:
60
+ raise FileNotFoundError(bert_pretrained_dir)
61
+ tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
62
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
63
+ if is_half == True:
64
+ bert_model = bert_model.half().to(device)
65
+ else:
66
+ bert_model = bert_model.to(device)
67
+
68
+ def get_bert_feature(text, word2ph):
69
+ with torch.no_grad():
70
+ inputs = tokenizer(text, return_tensors="pt")
71
+ for i in inputs:
72
+ inputs[i] = inputs[i].to(device)
73
+ res = bert_model(**inputs, output_hidden_states=True)
74
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
75
+
76
+ assert len(word2ph) == len(text)
77
+ phone_level_feature = []
78
+ for i in range(len(word2ph)):
79
+ repeat_feature = res[i].repeat(word2ph[i], 1)
80
+ phone_level_feature.append(repeat_feature)
81
+
82
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
83
+
84
+ return phone_level_feature.T
85
+
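A toy illustration of the word2ph expansion performed above, with random features and made-up phone counts.

import torch

res = torch.randn(3, 1024)  # per-character BERT features for a 3-character text
word2ph = [2, 1, 3]         # phones per character
phone_level = torch.cat([res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0)
print(phone_level.shape)    # torch.Size([6, 1024]); the function returns the transpose, (1024, 6)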
86
+ def process(data, res):
87
+ for name, text, lan in data:
88
+ try:
89
+ name = clean_path(name)
90
+ name = os.path.basename(name)
91
+ print(name)
92
+ phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
93
+ path_bert = "%s/%s.pt" % (bert_dir, name)
94
+ if os.path.exists(path_bert) == False and lan == "zh":
95
+ bert_feature = get_bert_feature(norm_text, word2ph)
96
+ assert bert_feature.shape[-1] == len(phones)
97
+ # torch.save(bert_feature, path_bert)
98
+ my_save(bert_feature, path_bert)
99
+ phones = " ".join(phones)
100
+ # res.append([name,phones])
101
+ res.append([name, phones, word2ph, norm_text])
102
+ except:
103
+ print(name, text, traceback.format_exc())
104
+
105
+ todo = []
106
+ res = []
107
+ with open(inp_text, "r", encoding="utf8") as f:
108
+ lines = f.read().strip("\n").split("\n")
109
+
110
+ language_v1_to_language_v2 = {
111
+ "ZH": "zh",
112
+ "zh": "zh",
113
+ "JP": "ja",
114
+ "jp": "ja",
115
+ "JA": "ja",
116
+ "ja": "ja",
117
+ "EN": "en",
118
+ "en": "en",
119
+ "En": "en",
120
+ "KO": "ko",
121
+ "Ko": "ko",
122
+ "ko": "ko",
123
+ "yue": "yue",
124
+ "YUE": "yue",
125
+ "Yue": "yue",
126
+ }
127
+ for line in lines[int(i_part) :: int(all_parts)]:
128
+ try:
129
+ wav_name, spk_name, language, text = line.split("|")
130
+ # todo.append([name,text,"zh"])
131
+ if language in language_v1_to_language_v2.keys():
132
+ todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
133
+ else:
134
+ print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
135
+ except:
136
+ print(line, traceback.format_exc())
137
+
138
+ process(todo, res)
139
+ opt = []
140
+ for name, phones, word2ph, norm_text in res:
141
+ opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
142
+ with open(txt_path, "w", encoding="utf8") as f:
143
+ f.write("\n".join(opt) + "\n")
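A sketch of one row of the 2-name2text-{part}.txt file written above; all values are made up.

row = "vocal_001.wav\tn i2 h ao3\t[2, 2]\t你好"  # name \t phones \t word2ph \t normalized text
name, phones, word2ph, norm_text = row.split("\t")
print(name, phones.split(" "), word2ph, norm_text)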
GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import sys
4
+ import os
5
+
6
+ inp_text = os.environ.get("inp_text")
7
+ inp_wav_dir = os.environ.get("inp_wav_dir")
8
+ exp_name = os.environ.get("exp_name")
9
+ i_part = os.environ.get("i_part")
10
+ all_parts = os.environ.get("all_parts")
11
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
13
+ from feature_extractor import cnhubert
14
+
15
+ opt_dir = os.environ.get("opt_dir")
16
+ cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
17
+ import torch
18
+
19
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
20
+
21
+ import traceback
22
+ import numpy as np
23
+ from scipy.io import wavfile
24
+ import librosa
25
+
26
+ now_dir = os.getcwd()
27
+ sys.path.append(now_dir)
28
+ from tools.my_utils import load_audio, clean_path
29
+
30
+ # from config import cnhubert_base_path
31
+ # cnhubert.cnhubert_base_path=cnhubert_base_path
32
+ # inp_text=sys.argv[1]
33
+ # inp_wav_dir=sys.argv[2]
34
+ # exp_name=sys.argv[3]
35
+ # i_part=sys.argv[4]
36
+ # all_parts=sys.argv[5]
37
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
38
+ # cnhubert.cnhubert_base_path=sys.argv[7]
39
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
40
+
41
+ from time import time as ttime
42
+ import shutil
43
+
44
+
45
+ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
46
+ dir = os.path.dirname(path)
47
+ name = os.path.basename(path)
48
+ # tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
49
+ tmp_path = "%s%s.pth" % (ttime(), i_part)
50
+ torch.save(fea, tmp_path)
51
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
52
+
53
+
54
+ hubert_dir = "%s/4-cnhubert" % (opt_dir)
55
+ wav32dir = "%s/5-wav32k" % (opt_dir)
56
+ os.makedirs(opt_dir, exist_ok=True)
57
+ os.makedirs(hubert_dir, exist_ok=True)
58
+ os.makedirs(wav32dir, exist_ok=True)
59
+
60
+ maxx = 0.95
61
+ alpha = 0.5
62
+ if torch.cuda.is_available():
63
+ device = "cuda:0"
64
+ # elif torch.backends.mps.is_available():
65
+ # device = "mps"
66
+ else:
67
+ device = "cpu"
68
+ model = cnhubert.get_model()
69
+ # is_half=False
70
+ if is_half == True:
71
+ model = model.half().to(device)
72
+ else:
73
+ model = model.to(device)
74
+
75
+ nan_fails = []
76
+
77
+
78
+ def name2go(wav_name, wav_path):
79
+ hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
80
+ if os.path.exists(hubert_path):
81
+ return
82
+ tmp_audio = load_audio(wav_path, 32000)
83
+ tmp_max = np.abs(tmp_audio).max()
84
+ if tmp_max > 2.2:
85
+ print("%s-filtered,%s" % (wav_name, tmp_max))
86
+ return
87
+ tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
88
+ tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
89
+ tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)  # the resampling itself is not the issue
90
+ tensor_wav16 = torch.from_numpy(tmp_audio)
91
+ if is_half == True:
92
+ tensor_wav16 = tensor_wav16.half().to(device)
93
+ else:
94
+ tensor_wav16 = tensor_wav16.to(device)
95
+ ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
96
+ if np.isnan(ssl.detach().numpy()).sum() != 0:
97
+ nan_fails.append((wav_name, wav_path))
98
+ print("nan filtered:%s" % wav_name)
99
+ return
100
+ wavfile.write(
101
+ "%s/%s" % (wav32dir, wav_name),
102
+ 32000,
103
+ tmp_audio32.astype("int16"),
104
+ )
105
+ my_save(ssl, hubert_path)
106
+
107
+
108
+ with open(inp_text, "r", encoding="utf8") as f:
109
+ lines = f.read().strip("\n").split("\n")
110
+
111
+ for line in lines[int(i_part) :: int(all_parts)]:
112
+ try:
113
+ # wav_name,text=line.split("\t")
114
+ wav_name, spk_name, language, text = line.split("|")
115
+ wav_name = clean_path(wav_name)
116
+ if inp_wav_dir != "" and inp_wav_dir != None:
117
+ wav_name = os.path.basename(wav_name)
118
+ wav_path = "%s/%s" % (inp_wav_dir, wav_name)
119
+
120
+ else:
121
+ wav_path = wav_name
122
+ wav_name = os.path.basename(wav_name)
123
+ name2go(wav_name, wav_path)
124
+ except:
125
+ print(line, traceback.format_exc())
126
+
127
+ if len(nan_fails) > 0 and is_half == True:
128
+ is_half = False
129
+ model = model.float()
130
+ for wav in nan_fails:
131
+ try:
132
+ name2go(wav[0], wav[1])
133
+ except:
134
+ print(wav[0], traceback.format_exc())
GPT_SoVITS/prepare_datasets/3-get-semantic.py ADDED
@@ -0,0 +1,118 @@
+ import os
+
+ inp_text = os.environ.get("inp_text")
+ exp_name = os.environ.get("exp_name")
+ i_part = os.environ.get("i_part")
+ all_parts = os.environ.get("all_parts")
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
+     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+ opt_dir = os.environ.get("opt_dir")
+ pretrained_s2G = os.environ.get("pretrained_s2G")
+ s2config_path = os.environ.get("s2config_path")
+
+ if os.path.exists(pretrained_s2G):
+     ...
+ else:
+     raise FileNotFoundError(pretrained_s2G)
+ # version=os.environ.get("version","v2")
+ size = os.path.getsize(pretrained_s2G)
+ if size < 82978 * 1024:
+     version = "v1"
+ elif size < 100 * 1024 * 1024:
+     version = "v2"
+ elif size < 103520 * 1024:
+     version = "v1"
+ elif size < 700 * 1024 * 1024:
+     version = "v2"
+ else:
+     version = "v3"
+ import torch
+
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+ import traceback
+ import sys
+
+ now_dir = os.getcwd()
+ sys.path.append(now_dir)
+ import logging
+ import utils
+
+ if version != "v3":
+     from module.models import SynthesizerTrn
+ else:
+     from module.models import SynthesizerTrnV3 as SynthesizerTrn
+ from tools.my_utils import clean_path
+
+ logging.getLogger("numba").setLevel(logging.WARNING)
+ # from config import pretrained_s2G
+
+ # inp_text=sys.argv[1]
+ # exp_name=sys.argv[2]
+ # i_part=sys.argv[3]
+ # all_parts=sys.argv[4]
+ # os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[5]
+ # opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
+
+
+ hubert_dir = "%s/4-cnhubert" % (opt_dir)
+ semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
+ if os.path.exists(semantic_path) == False:
+     os.makedirs(opt_dir, exist_ok=True)
+
+     if torch.cuda.is_available():
+         device = "cuda"
+     # elif torch.backends.mps.is_available():
+     #     device = "mps"
+     else:
+         device = "cpu"
+     hps = utils.get_hparams_from_file(s2config_path)
+     vq_model = SynthesizerTrn(
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         version=version,
+         **hps.model,
+     )
+     if is_half == True:
+         vq_model = vq_model.half().to(device)
+     else:
+         vq_model = vq_model.to(device)
+     vq_model.eval()
+     # utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, None, True)
+     # utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
+     print(
+         vq_model.load_state_dict(
+             torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
+         )
+     )
+
+     def name2go(wav_name, lines):
+         hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
+         if os.path.exists(hubert_path) == False:
+             return
+         ssl_content = torch.load(hubert_path, map_location="cpu")
+         if is_half == True:
+             ssl_content = ssl_content.half().to(device)
+         else:
+             ssl_content = ssl_content.to(device)
+         codes = vq_model.extract_latent(ssl_content)
+         semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()])
+         lines.append("%s\t%s" % (wav_name, semantic))
+
+     with open(inp_text, "r", encoding="utf8") as f:
+         lines = f.read().strip("\n").split("\n")
+
+     lines1 = []
+     for line in lines[int(i_part) :: int(all_parts)]:
+         # print(line)
+         try:
+             # wav_name,text=line.split("\t")
+             wav_name, spk_name, language, text = line.split("|")
+             wav_name = clean_path(wav_name)
+             wav_name = os.path.basename(wav_name)
+             # name2go(name,lines1)
+             name2go(wav_name, lines1)
+         except:
+             print(line, traceback.format_exc())
+     with open(semantic_path, "w", encoding="utf8") as f:
+         f.write("\n".join(lines1))
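3-get-semantic.py shards the filelist with `lines[int(i_part) :: int(all_parts)]` and writes one `6-name2semantic-{i_part}.tsv` per worker. The merge into a single TSV happens elsewhere in the pipeline; the sketch below only illustrates how such shard outputs could be concatenated (the output filename and the absence of a header line are assumptions):

    import os

    def merge_semantic_parts(opt_dir: str, all_parts: int) -> str:
        # Concatenate the per-shard TSVs written by 3-get-semantic.py into one file.
        out_path = os.path.join(opt_dir, "6-name2semantic.tsv")
        rows = []
        for i_part in range(all_parts):
            part = os.path.join(opt_dir, "6-name2semantic-%s.tsv" % i_part)
            if os.path.exists(part):
                with open(part, "r", encoding="utf8") as f:
                    rows.extend(line for line in f.read().splitlines() if line)
        with open(out_path, "w", encoding="utf8") as f:
            f.write("\n".join(rows))
        return out_path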
GPT_SoVITS/pretrained_models/.gitignore ADDED
@@ -0,0 +1,2 @@
+ *
+ !.gitignore
GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1
+ size 155093966
GPT_SoVITS/pretrained_models/s1v3.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:87133414860ea14ff6620c483a3db5ed07b44be42e2c3fcdad65523a729a745a
+ size 155284856
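The two .ckpt entries above are Git LFS pointer stubs (version / oid / size lines), not the weights themselves; the real files are fetched by LFS. A small sketch of reading such a pointer (the helper name is hypothetical):

    def parse_lfs_pointer(path: str) -> dict:
        # A Git LFS pointer is a short text file of "key value" lines.
        fields = {}
        with open(path, "r", encoding="utf8") as f:
            for line in f:
                key, _, value = line.strip().partition(" ")
                if key:
                    fields[key] = value
        fields["size"] = int(fields.get("size", 0))
        return fields

For s1v3.ckpt the parsed `size` would be 155284856, matching the pointer above.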
GPT_SoVITS/process_ckpt.py ADDED
@@ -0,0 +1,124 @@
+ import traceback
+ from collections import OrderedDict
+ from time import time as ttime
+ import shutil
+ import os
+ import torch
+ from tools.i18n.i18n import I18nAuto
+
+ i18n = I18nAuto()
+
+
+ def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
+     dir = os.path.dirname(path)
+     name = os.path.basename(path)
+     tmp_path = "%s.pth" % (ttime())
+     torch.save(fea, tmp_path)
+     shutil.move(tmp_path, "%s/%s" % (dir, name))
+
+
+ """
+ 00:v1
+ 01:v2
+ 02:v3
+ 03:v3lora
+ 04:v4lora
+
+ """
+ from io import BytesIO
+
+
+ def my_save2(fea, path, cfm_version):
+     bio = BytesIO()
+     torch.save(fea, bio)
+     bio.seek(0)
+     data = bio.getvalue()
+     byte = b"03" if cfm_version == "v3" else b"04"
+     data = byte + data[2:]
+     with open(path, "wb") as f:
+         f.write(data)
+
+
+ def savee(ckpt, name, epoch, steps, hps, cfm_version=None, lora_rank=None):
+     try:
+         opt = OrderedDict()
+         opt["weight"] = {}
+         for key in ckpt.keys():
+             if "enc_q" in key:
+                 continue
+             opt["weight"][key] = ckpt[key].half()
+         opt["config"] = hps
+         opt["info"] = "%sepoch_%siteration" % (epoch, steps)
+         if lora_rank:
+             opt["lora_rank"] = lora_rank
+             my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), cfm_version)
+         else:
+             my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+         return "Success."
+     except:
+         return traceback.format_exc()
+
+
+ head2version = {
+     b"00": ["v1", "v1", False],
+     b"01": ["v2", "v2", False],
+     b"02": ["v2", "v3", False],
+     b"03": ["v2", "v3", True],
+     b"04": ["v2", "v4", True],
+ }
+ hash_pretrained_dict = {
+     "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False],  # s2G488k.pth#sovits_v1_pretrained
+     "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False],  # s2Gv3.pth#sovits_v3_pretrained
+     "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False],  # s2G2333K.pth#sovits_v2_pretrained
+     "4f26b9476d0c5033e04162c486074374": ["v2", "v4", False],  # s2Gv4.pth#sovits_v4_pretrained
+ }
+ import hashlib
+
+
+ def get_hash_from_file(sovits_path):
+     with open(sovits_path, "rb") as f:
+         data = f.read(8192)
+     hash_md5 = hashlib.md5()
+     hash_md5.update(data)
+     return hash_md5.hexdigest()
+
+
+ def get_sovits_version_from_path_fast(sovits_path):
+     ###1-if it is pretrained sovits models, by hash
+     hash = get_hash_from_file(sovits_path)
+     if hash in hash_pretrained_dict:
+         return hash_pretrained_dict[hash]
+     ###2-new weights, by head
+     with open(sovits_path, "rb") as f:
+         version = f.read(2)
+     if version != b"PK":
+         return head2version[version]
+     ###3-old weights, by file size
+     if_lora_v3 = False
+     size = os.path.getsize(sovits_path)
+     """
+     v1weights:about 82942KB
+     half thr:82978KB
+     v2weights:about 83014KB
+     v3weights:about 750MB
+     """
+     if size < 82978 * 1024:
+         model_version = version = "v1"
+     elif size < 700 * 1024 * 1024:
+         model_version = version = "v2"
+     else:
+         version = "v2"
+         model_version = "v3"
+     return version, model_version, if_lora_v3
+
+
+ def load_sovits_new(sovits_path):
+     f = open(sovits_path, "rb")
+     meta = f.read(2)
+     if meta != "PK":
+         data = b"PK" + f.read()
+         bio = BytesIO()
+         bio.write(data)
+         bio.seek(0)
+         return torch.load(bio, map_location="cpu", weights_only=False)
+     return torch.load(sovits_path, map_location="cpu", weights_only=False)
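process_ckpt.py packs the model version into the first two bytes of the saved file: a normal torch zip checkpoint starts with `PK`, so `my_save2` overwrites those bytes with a `b"03"`/`b"04"` head that `head2version` later decodes, and `load_sovits_new` restores `PK` before handing the data to `torch.load`. A usage sketch (the weight path is hypothetical; the import assumes `GPT_SoVITS/` is on `sys.path`, as in s1_train.py below):

    from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new

    sovits_path = "SoVITS_weights_v3/my_model_e8_s800.pth"  # hypothetical fine-tuned weight
    lang_version, model_version, is_lora = get_sovits_version_from_path_fast(sovits_path)
    print(lang_version, model_version, is_lora)  # e.g. ('v2', 'v3', True)

    ckpt = load_sovits_new(sovits_path)  # dict with "weight", "config", "info" (see savee above)
    print(ckpt["info"])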
GPT_SoVITS/s1_train.py ADDED
@@ -0,0 +1,171 @@
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
+ import os
+
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
+     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+ import argparse
+ import logging
+ import platform
+ from pathlib import Path
+
+ import torch
+ from AR.data.data_module import Text2SemanticDataModule
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+ from AR.utils.io import load_yaml_config
+ from pytorch_lightning import Trainer, seed_everything
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
+ from pytorch_lightning.strategies import DDPStrategy
+
+ logging.getLogger("numba").setLevel(logging.WARNING)
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
+ torch.set_float32_matmul_precision("high")
+ from collections import OrderedDict
+
+ from AR.utils import get_newest_ckpt
+ from process_ckpt import my_save
+
+
+ class my_model_ckpt(ModelCheckpoint):
+     def __init__(
+         self,
+         config,
+         if_save_latest,
+         if_save_every_weights,
+         half_weights_save_dir,
+         exp_name,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.if_save_latest = if_save_latest
+         self.if_save_every_weights = if_save_every_weights
+         self.half_weights_save_dir = half_weights_save_dir
+         self.exp_name = exp_name
+         self.config = config
+
+     def on_train_epoch_end(self, trainer, pl_module):
+         # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
+         if self._should_save_on_train_epoch_end(trainer):
+             monitor_candidates = self._monitor_candidates(trainer)
+             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
+                 if (
+                     self.if_save_latest == True
+                 ):  #### if only the latest ckpt is kept, clean up all previous ckpts after saving the next one
+                     to_clean = list(os.listdir(self.dirpath))
+                 self._save_topk_checkpoint(trainer, monitor_candidates)
+                 if self.if_save_latest == True:
+                     for name in to_clean:
+                         try:
+                             os.remove("%s/%s" % (self.dirpath, name))
+                         except:
+                             pass
+                 if self.if_save_every_weights == True:
+                     to_save_od = OrderedDict()
+                     to_save_od["weight"] = OrderedDict()
+                     dictt = trainer.strategy._lightning_module.state_dict()
+                     for key in dictt:
+                         to_save_od["weight"][key] = dictt[key].half()
+                     to_save_od["config"] = self.config
+                     to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
+                     # torch.save(
+                     # print(os.environ)
+                     if os.environ.get("LOCAL_RANK", "0") == "0":
+                         my_save(
+                             to_save_od,
+                             "%s/%s-e%s.ckpt"
+                             % (
+                                 self.half_weights_save_dir,
+                                 self.exp_name,
+                                 trainer.current_epoch + 1,
+                             ),
+                         )
+             self._save_last_checkpoint(trainer, monitor_candidates)
+
+
+ def main(args):
+     config = load_yaml_config(args.config_file)
+
+     output_dir = Path(config["output_dir"])
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     ckpt_dir = output_dir / "ckpt"
+     ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+     seed_everything(config["train"]["seed"], workers=True)
+     ckpt_callback: ModelCheckpoint = my_model_ckpt(
+         config=config,
+         if_save_latest=config["train"]["if_save_latest"],
+         if_save_every_weights=config["train"]["if_save_every_weights"],
+         half_weights_save_dir=config["train"]["half_weights_save_dir"],
+         exp_name=config["train"]["exp_name"],
+         save_top_k=-1,
+         monitor="top_3_acc",
+         mode="max",
+         save_on_train_epoch_end=True,
+         every_n_epochs=config["train"]["save_every_n_epoch"],
+         dirpath=ckpt_dir,
+     )
+     logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
+     os.environ["MASTER_ADDR"] = "localhost"
+     os.environ["USE_LIBUV"] = "0"
+     trainer: Trainer = Trainer(
+         max_epochs=config["train"]["epochs"],
+         accelerator="gpu" if torch.cuda.is_available() else "cpu",
+         # val_check_interval=9999999999999999999999,  ### no validation
+         # check_val_every_n_epoch=None,
+         limit_val_batches=0,
+         devices=-1 if torch.cuda.is_available() else 1,
+         benchmark=False,
+         fast_dev_run=False,
+         strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
+         if torch.cuda.is_available()
+         else "auto",
+         precision=config["train"]["precision"],
+         logger=logger,
+         num_sanity_val_steps=0,
+         callbacks=[ckpt_callback],
+         use_distributed_sampler=False,  # a very simple change, but it fixes the mismatched step counts seen with the custom bucket_sampler
+     )
+
+     model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)
+
+     data_module: Text2SemanticDataModule = Text2SemanticDataModule(
+         config,
+         train_semantic_path=config["train_semantic_path"],
+         train_phoneme_path=config["train_phoneme_path"],
+         # dev_semantic_path=args.dev_semantic_path,
+         # dev_phoneme_path=args.dev_phoneme_path
+     )
+
+     try:
+         # match the numeric part of the checkpoint filenames with a regex and sort numerically
+         newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
+         ckpt_path = ckpt_dir / newest_ckpt_name
+     except Exception:
+         ckpt_path = None
+     print("ckpt_path:", ckpt_path)
+     trainer.fit(model, data_module, ckpt_path=ckpt_path)
+
+
+ # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-c",
+         "--config_file",
+         type=str,
+         default="configs/s1longer.yaml",
+         help="path of config file",
+     )
+     # args for dataset
+     # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
+     # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
+
+     # parser.add_argument('--dev_semantic_path', type=str, default='dump_mix/semantic_dev.tsv')
+     # parser.add_argument('--dev_phoneme_path', type=str, default='dump_mix/phoneme_dev.npy')
+     # parser.add_argument('--output_dir',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/logs_s1',help='directory to save the results')
+     # parser.add_argument('--output_dir',type=str,default='/liujing04/gpt_logs/s1/xuangou_ft',help='directory to save the results')
+
+     args = parser.parse_args()
+     logging.info(str(args))
+     main(args)
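s1_train.py resumes from the newest checkpoint in `ckpt_dir` via `get_newest_ckpt`, which (per the comment above) sorts filenames by the numeric part extracted with a regex. An illustrative stand-in for that helper, assuming Lightning's default `epoch=N-step=M.ckpt` naming:

    import re

    def newest_ckpt_name(names):
        # Pick the checkpoint with the largest epoch number in its filename.
        def epoch_of(name):
            m = re.search(r"epoch=(\d+)", name)
            return int(m.group(1)) if m else -1
        ckpts = [n for n in names if n.endswith(".ckpt")]
        return max(ckpts, key=epoch_of) if ckpts else None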
GPT_SoVITS/s2_train.py ADDED
@@ -0,0 +1,680 @@
+ import warnings
+
+ warnings.filterwarnings("ignore")
+ import os
+
+ import utils
+
+ hps = utils.get_hparams(stage=2)
+ os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+ import logging
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from torch.cuda.amp import GradScaler, autocast
+ from torch.nn import functional as F
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
+
+ logging.getLogger("matplotlib").setLevel(logging.INFO)
+ logging.getLogger("h5py").setLevel(logging.INFO)
+ logging.getLogger("numba").setLevel(logging.INFO)
+ from random import randint
+
+ from module import commons
+ from module.data_utils import (
+     DistributedBucketSampler,
+     TextAudioSpeakerCollate,
+     TextAudioSpeakerLoader,
+ )
+ from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
+ from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
+ from module.models import (
+     MultiPeriodDiscriminator,
+     SynthesizerTrn,
+ )
+ from process_ckpt import savee
+
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = False
+ ### fp32 is already fast on the A100 anyway, so let's try tf32
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so), with no effect on the results
+ # from config import pretrained_s2G,pretrained_s2D
+ global_step = 0
+
+ device = "cpu"  # for devices other than cuda; mps will be added once it is better optimized
+
+
+ def main():
+     if torch.cuda.is_available():
+         n_gpus = torch.cuda.device_count()
+     else:
+         n_gpus = 1
+     os.environ["MASTER_ADDR"] = "localhost"
+     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
+
+     mp.spawn(
+         run,
+         nprocs=n_gpus,
+         args=(
+             n_gpus,
+             hps,
+         ),
+     )
+
+
+ def run(rank, n_gpus, hps):
+     global global_step
+     if rank == 0:
+         logger = utils.get_logger(hps.data.exp_dir)
+         logger.info(hps)
+         # utils.check_git_hash(hps.s2_ckpt_dir)
+         writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
+         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
+
+     dist.init_process_group(
+         backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+         init_method="env://?use_libuv=False",
+         world_size=n_gpus,
+         rank=rank,
+     )
+     torch.manual_seed(hps.train.seed)
+     if torch.cuda.is_available():
+         torch.cuda.set_device(rank)
+
+     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
+     train_sampler = DistributedBucketSampler(
+         train_dataset,
+         hps.train.batch_size,
+         [
+             32,
+             300,
+             400,
+             500,
+             600,
+             700,
+             800,
+             900,
+             1000,
+             1100,
+             1200,
+             1300,
+             1400,
+             1500,
+             1600,
+             1700,
+             1800,
+             1900,
+         ],
+         num_replicas=n_gpus,
+         rank=rank,
+         shuffle=True,
+     )
+     collate_fn = TextAudioSpeakerCollate()
+     train_loader = DataLoader(
+         train_dataset,
+         num_workers=6,
+         shuffle=False,
+         pin_memory=True,
+         collate_fn=collate_fn,
+         batch_sampler=train_sampler,
+         persistent_workers=True,
+         prefetch_factor=4,
+     )
+     # if rank == 0:
+     #     eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
+     #     eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
+     #                              batch_size=1, pin_memory=True,
+     #                              drop_last=False, collate_fn=collate_fn)
+
+     net_g = (
+         SynthesizerTrn(
+             hps.data.filter_length // 2 + 1,
+             hps.train.segment_size // hps.data.hop_length,
+             n_speakers=hps.data.n_speakers,
+             **hps.model,
+         ).cuda(rank)
+         if torch.cuda.is_available()
+         else SynthesizerTrn(
+             hps.data.filter_length // 2 + 1,
+             hps.train.segment_size // hps.data.hop_length,
+             n_speakers=hps.data.n_speakers,
+             **hps.model,
+         ).to(device)
+     )
+
+     net_d = (
+         MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+         if torch.cuda.is_available()
+         else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
+     )
+     for name, param in net_g.named_parameters():
+         if not param.requires_grad:
+             print(name, "not requires_grad")
+
+     te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
+     et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
+     mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
+     base_params = filter(
+         lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
+         net_g.parameters(),
+     )
+
+     # te_p=net_g.enc_p.text_embedding.parameters()
+     # et_p=net_g.enc_p.encoder_text.parameters()
+     # mrte_p=net_g.enc_p.mrte.parameters()
+
+     optim_g = torch.optim.AdamW(
+         # filter(lambda p: p.requires_grad, net_g.parameters()),  ### by default every layer would share the same lr
+         [
+             {"params": base_params, "lr": hps.train.learning_rate},
+             {
+                 "params": net_g.enc_p.text_embedding.parameters(),
+                 "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
+             },
+             {
+                 "params": net_g.enc_p.encoder_text.parameters(),
+                 "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
+             },
+             {
+                 "params": net_g.enc_p.mrte.parameters(),
+                 "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
+             },
+         ],
+         hps.train.learning_rate,
+         betas=hps.train.betas,
+         eps=hps.train.eps,
+     )
+     optim_d = torch.optim.AdamW(
+         net_d.parameters(),
+         hps.train.learning_rate,
+         betas=hps.train.betas,
+         eps=hps.train.eps,
+     )
+     if torch.cuda.is_available():
+         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+         net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+     else:
+         net_g = net_g.to(device)
+         net_d = net_d.to(device)
+
+     try:  # auto-resume if checkpoints can be loaded
+         _, _, _, epoch_str = utils.load_checkpoint(
+             utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
+             net_d,
+             optim_d,
+         )  # loading D usually goes fine
+         if rank == 0:
+             logger.info("loaded D")
+         # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
+         _, _, _, epoch_str = utils.load_checkpoint(
+             utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
+             net_g,
+             optim_g,
+         )
+         epoch_str += 1
+         global_step = (epoch_str - 1) * len(train_loader)
+         # epoch_str = 1
+         # global_step = 0
+     except:  # if nothing can be loaded on the first run, load the pretrained weights
+         # traceback.print_exc()
+         epoch_str = 1
+         global_step = 0
+         if (
+             hps.train.pretrained_s2G != ""
+             and hps.train.pretrained_s2G != None
+             and os.path.exists(hps.train.pretrained_s2G)
+         ):
+             if rank == 0:
+                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
+             print(
+                 "loaded pretrained %s" % hps.train.pretrained_s2G,
+                 net_g.module.load_state_dict(
+                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                     strict=False,
+                 )
+                 if torch.cuda.is_available()
+                 else net_g.load_state_dict(
+                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                     strict=False,
+                 ),
+             )  ## testing: do not load the optimizer
+         if (
+             hps.train.pretrained_s2D != ""
+             and hps.train.pretrained_s2D != None
+             and os.path.exists(hps.train.pretrained_s2D)
+         ):
+             if rank == 0:
+                 logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
+             print(
+                 "loaded pretrained %s" % hps.train.pretrained_s2D,
+                 net_d.module.load_state_dict(
+                     torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
+                 )
+                 if torch.cuda.is_available()
+                 else net_d.load_state_dict(
+                     torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
+                 ),
+             )
+
+     # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
+     # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
+
+     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
+         optim_g,
+         gamma=hps.train.lr_decay,
+         last_epoch=-1,
+     )
+     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
+         optim_d,
+         gamma=hps.train.lr_decay,
+         last_epoch=-1,
+     )
+     for _ in range(epoch_str):
+         scheduler_g.step()
+         scheduler_d.step()
+
+     scaler = GradScaler(enabled=hps.train.fp16_run)
+
+     print("start training from epoch %s" % epoch_str)
+     for epoch in range(epoch_str, hps.train.epochs + 1):
+         if rank == 0:
+             train_and_evaluate(
+                 rank,
+                 epoch,
+                 hps,
+                 [net_g, net_d],
+                 [optim_g, optim_d],
+                 [scheduler_g, scheduler_d],
+                 scaler,
+                 # [train_loader, eval_loader], logger, [writer, writer_eval])
+                 [train_loader, None],
+                 logger,
+                 [writer, writer_eval],
+             )
+         else:
+             train_and_evaluate(
+                 rank,
+                 epoch,
+                 hps,
+                 [net_g, net_d],
+                 [optim_g, optim_d],
+                 [scheduler_g, scheduler_d],
+                 scaler,
+                 [train_loader, None],
+                 None,
+                 None,
+             )
+         scheduler_g.step()
+         scheduler_d.step()
+     print("training done")
+
+
+ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
+     net_g, net_d = nets
+     optim_g, optim_d = optims
+     # scheduler_g, scheduler_d = schedulers
+     train_loader, eval_loader = loaders
+     if writers is not None:
+         writer, writer_eval = writers
+
+     train_loader.batch_sampler.set_epoch(epoch)
+     global global_step
+
+     net_g.train()
+     net_d.train()
+     for batch_idx, (
+         ssl,
+         ssl_lengths,
+         spec,
+         spec_lengths,
+         y,
+         y_lengths,
+         text,
+         text_lengths,
+     ) in enumerate(tqdm(train_loader)):
+         if torch.cuda.is_available():
+             spec, spec_lengths = (
+                 spec.cuda(
+                     rank,
+                     non_blocking=True,
+                 ),
+                 spec_lengths.cuda(
+                     rank,
+                     non_blocking=True,
+                 ),
+             )
+             y, y_lengths = (
+                 y.cuda(
+                     rank,
+                     non_blocking=True,
+                 ),
+                 y_lengths.cuda(
+                     rank,
+                     non_blocking=True,
+                 ),
+             )
+             ssl = ssl.cuda(rank, non_blocking=True)
+             ssl.requires_grad = False
+             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+             text, text_lengths = (
+                 text.cuda(
+                     rank,
+                     non_blocking=True,
+                 ),
+                 text_lengths.cuda(
+                     rank,
+                     non_blocking=True,
+                 ),
+             )
+         else:
+             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
+             y, y_lengths = y.to(device), y_lengths.to(device)
+             ssl = ssl.to(device)
+             ssl.requires_grad = False
+             # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+             text, text_lengths = text.to(device), text_lengths.to(device)
+
+         with autocast(enabled=hps.train.fp16_run):
+             (
+                 y_hat,
+                 kl_ssl,
+                 ids_slice,
+                 x_mask,
+                 z_mask,
+                 (z, z_p, m_p, logs_p, m_q, logs_q),
+                 stats_ssl,
+             ) = net_g(ssl, spec, spec_lengths, text, text_lengths)
+
+             mel = spec_to_mel_torch(
+                 spec,
+                 hps.data.filter_length,
+                 hps.data.n_mel_channels,
+                 hps.data.sampling_rate,
+                 hps.data.mel_fmin,
+                 hps.data.mel_fmax,
+             )
+             y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
+             y_hat_mel = mel_spectrogram_torch(
+                 y_hat.squeeze(1),
+                 hps.data.filter_length,
+                 hps.data.n_mel_channels,
+                 hps.data.sampling_rate,
+                 hps.data.hop_length,
+                 hps.data.win_length,
+                 hps.data.mel_fmin,
+                 hps.data.mel_fmax,
+             )
+
+             y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size)  # slice
+
+             # Discriminator
+             y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
+             with autocast(enabled=False):
+                 loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
+                     y_d_hat_r,
+                     y_d_hat_g,
+                 )
+                 loss_disc_all = loss_disc
+         optim_d.zero_grad()
+         scaler.scale(loss_disc_all).backward()
+         scaler.unscale_(optim_d)
+         grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
+         scaler.step(optim_d)
+
+         with autocast(enabled=hps.train.fp16_run):
+             # Generator
+             y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
+             with autocast(enabled=False):
+                 loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
+                 loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
+
+                 loss_fm = feature_loss(fmap_r, fmap_g)
+                 loss_gen, losses_gen = generator_loss(y_d_hat_g)
+                 loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
+
+         optim_g.zero_grad()
+         scaler.scale(loss_gen_all).backward()
+         scaler.unscale_(optim_g)
+         grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
+         scaler.step(optim_g)
+         scaler.update()
+
+         if rank == 0:
+             if global_step % hps.train.log_interval == 0:
+                 lr = optim_g.param_groups[0]["lr"]
+                 losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
+                 logger.info(
+                     "Train Epoch: {} [{:.0f}%]".format(
+                         epoch,
+                         100.0 * batch_idx / len(train_loader),
+                     )
+                 )
+                 logger.info([x.item() for x in losses] + [global_step, lr])
+
+                 scalar_dict = {
+                     "loss/g/total": loss_gen_all,
+                     "loss/d/total": loss_disc_all,
+                     "learning_rate": lr,
+                     "grad_norm_d": grad_norm_d,
+                     "grad_norm_g": grad_norm_g,
+                 }
+                 scalar_dict.update(
+                     {
+                         "loss/g/fm": loss_fm,
+                         "loss/g/mel": loss_mel,
+                         "loss/g/kl_ssl": kl_ssl,
+                         "loss/g/kl": loss_kl,
+                     }
+                 )
+
+                 # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
+                 # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
+                 # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
+                 image_dict = None
+                 try:  ### Some people installed the wrong version of matplotlib.
+                     image_dict = {
+                         "slice/mel_org": utils.plot_spectrogram_to_numpy(
+                             y_mel[0].data.cpu().numpy(),
+                         ),
+                         "slice/mel_gen": utils.plot_spectrogram_to_numpy(
+                             y_hat_mel[0].data.cpu().numpy(),
+                         ),
+                         "all/mel": utils.plot_spectrogram_to_numpy(
+                             mel[0].data.cpu().numpy(),
+                         ),
+                         "all/stats_ssl": utils.plot_spectrogram_to_numpy(
+                             stats_ssl[0].data.cpu().numpy(),
+                         ),
+                     }
+                 except:
+                     pass
+                 if image_dict:
+                     utils.summarize(
+                         writer=writer,
+                         global_step=global_step,
+                         images=image_dict,
+                         scalars=scalar_dict,
+                     )
+                 else:
+                     utils.summarize(
+                         writer=writer,
+                         global_step=global_step,
+                         scalars=scalar_dict,
+                     )
+         global_step += 1
+     if epoch % hps.train.save_every_epoch == 0 and rank == 0:
+         if hps.train.if_save_latest == 0:
+             utils.save_checkpoint(
+                 net_g,
+                 optim_g,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(
+                     "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                     "G_{}.pth".format(global_step),
+                 ),
+             )
+             utils.save_checkpoint(
+                 net_d,
+                 optim_d,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(
+                     "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                     "D_{}.pth".format(global_step),
+                 ),
+             )
+         else:
+             utils.save_checkpoint(
+                 net_g,
+                 optim_g,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(
+                     "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                     "G_{}.pth".format(233333333333),
+                 ),
+             )
+             utils.save_checkpoint(
+                 net_d,
+                 optim_d,
+                 hps.train.learning_rate,
+                 epoch,
+                 os.path.join(
+                     "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
+                     "D_{}.pth".format(233333333333),
+                 ),
+             )
+         if rank == 0 and hps.train.if_save_every_weights == True:
+             if hasattr(net_g, "module"):
+                 ckpt = net_g.module.state_dict()
+             else:
+                 ckpt = net_g.state_dict()
+             logger.info(
+                 "saving ckpt %s_e%s:%s"
+                 % (
+                     hps.name,
+                     epoch,
+                     savee(
+                         ckpt,
+                         hps.name + "_e%s_s%s" % (epoch, global_step),
+                         epoch,
+                         global_step,
+                         hps,
+                     ),
+                 )
+             )
+
+     if rank == 0:
+         logger.info("====> Epoch: {}".format(epoch))
+
+
+ def evaluate(hps, generator, eval_loader, writer_eval):
+     generator.eval()
+     image_dict = {}
+     audio_dict = {}
+     print("Evaluating ...")
+     with torch.no_grad():
+         for batch_idx, (
+             ssl,
+             ssl_lengths,
+             spec,
+             spec_lengths,
+             y,
+             y_lengths,
+             text,
+             text_lengths,
+         ) in enumerate(eval_loader):
+             print(111)
+             if torch.cuda.is_available():
+                 spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
+                 y, y_lengths = y.cuda(), y_lengths.cuda()
+                 ssl = ssl.cuda()
+                 text, text_lengths = text.cuda(), text_lengths.cuda()
+             else:
+                 spec, spec_lengths = spec.to(device), spec_lengths.to(device)
+                 y, y_lengths = y.to(device), y_lengths.to(device)
+                 ssl = ssl.to(device)
+                 text, text_lengths = text.to(device), text_lengths.to(device)
+             for test in [0, 1]:
+                 y_hat, mask, *_ = (
+                     generator.module.infer(
+                         ssl,
+                         spec,
+                         spec_lengths,
+                         text,
+                         text_lengths,
+                         test=test,
+                     )
+                     if torch.cuda.is_available()
+                     else generator.infer(
+                         ssl,
+                         spec,
+                         spec_lengths,
+                         text,
+                         text_lengths,
+                         test=test,
+                     )
+                 )
+                 y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
+
+                 mel = spec_to_mel_torch(
+                     spec,
+                     hps.data.filter_length,
+                     hps.data.n_mel_channels,
+                     hps.data.sampling_rate,
+                     hps.data.mel_fmin,
+                     hps.data.mel_fmax,
+                 )
+                 y_hat_mel = mel_spectrogram_torch(
+                     y_hat.squeeze(1).float(),
+                     hps.data.filter_length,
+                     hps.data.n_mel_channels,
+                     hps.data.sampling_rate,
+                     hps.data.hop_length,
+                     hps.data.win_length,
+                     hps.data.mel_fmin,
+                     hps.data.mel_fmax,
+                 )
+                 image_dict.update(
+                     {
+                         f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
+                             y_hat_mel[0].cpu().numpy(),
+                         ),
+                     }
+                 )
+                 audio_dict.update(
+                     {
+                         f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
+                     },
+                 )
+                 image_dict.update(
+                     {
+                         f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
+                     },
+                 )
+                 audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
+
+             # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
+             # audio_dict.update({
+             #     f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
+             # })
+
+     utils.summarize(
+         writer=writer_eval,
+         global_step=global_step,
+         images=image_dict,
+         audios=audio_dict,
+         audio_sampling_rate=hps.data.sampling_rate,
+     )
+     generator.train()
+
+
+ if __name__ == "__main__":
+     main()
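The core of `train_and_evaluate` above is a two-optimizer AMP loop: the discriminator is updated on detached generator output, then the generator, with one `GradScaler` that is unscaled before gradient clipping and updated once per iteration. A condensed sketch of that pattern with toy placeholders (`d_loss_fn`/`g_loss_fn` and the flat `batch` argument are stand-ins, not the VITS interfaces; fp16 assumes a CUDA device):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def gan_amp_step(net_g, net_d, optim_g, optim_d, scaler: GradScaler,
                     batch, d_loss_fn, g_loss_fn, fp16=True):
        # Discriminator step on detached fake output.
        with autocast(enabled=fp16):
            fake = net_g(batch)
            loss_d = d_loss_fn(net_d(batch), net_d(fake.detach()))
        optim_d.zero_grad()
        scaler.scale(loss_d).backward()
        scaler.unscale_(optim_d)  # unscale so clipping sees the true gradient values
        torch.nn.utils.clip_grad_norm_(net_d.parameters(), 5.0)
        scaler.step(optim_d)

        # Generator step through the (non-detached) fake output.
        with autocast(enabled=fp16):
            loss_g = g_loss_fn(net_d(fake), fake, batch)
        optim_g.zero_grad()
        scaler.scale(loss_g).backward()
        scaler.unscale_(optim_g)
        torch.nn.utils.clip_grad_norm_(net_g.parameters(), 5.0)
        scaler.step(optim_g)
        scaler.update()  # a single update per iteration, after both optimizer steps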