kevinwang676 committed (verified)
Commit 559ee5e · Parent(s): 2154d8e

Add files using upload-large-folder tool

Files changed (50) — this view is limited to 50 files because the commit contains too many changes.
  1. .dockerignore +8 -0
  2. .gitignore +200 -0
  3. Docker/damo.sha256 +3 -0
  4. Docker/download.py +8 -0
  5. Docker/download.sh +11 -0
  6. Docker/links.sha256 +12 -0
  7. Docker/links.txt +34 -0
  8. Dockerfile +42 -0
  9. GPT_SoVITS/BigVGAN/LICENSE +21 -0
  10. GPT_SoVITS/BigVGAN/README.md +266 -0
  11. GPT_SoVITS/BigVGAN/activations.py +122 -0
  12. GPT_SoVITS/BigVGAN/bigvgan.py +461 -0
  13. GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json +45 -0
  14. GPT_SoVITS/BigVGAN/configs/bigvgan_24khz_100band.json +45 -0
  15. GPT_SoVITS/BigVGAN/configs/bigvgan_base_22khz_80band.json +45 -0
  16. GPT_SoVITS/BigVGAN/configs/bigvgan_base_24khz_100band.json +45 -0
  17. GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json +61 -0
  18. GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json +61 -0
  19. GPT_SoVITS/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json +61 -0
  20. GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json +61 -0
  21. GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json +61 -0
  22. GPT_SoVITS/BigVGAN/discriminators.py +625 -0
  23. GPT_SoVITS/BigVGAN/env.py +18 -0
  24. GPT_SoVITS/BigVGAN/inference.py +85 -0
  25. GPT_SoVITS/BigVGAN/inference_e2e.py +100 -0
  26. GPT_SoVITS/BigVGAN/loss.py +238 -0
  27. GPT_SoVITS/BigVGAN/meldataset.py +370 -0
  28. GPT_SoVITS/BigVGAN/nv-modelcard++/.gitkeep +1 -0
  29. GPT_SoVITS/BigVGAN/nv-modelcard++/bias.md +4 -0
  30. GPT_SoVITS/BigVGAN/nv-modelcard++/explainability.md +13 -0
  31. GPT_SoVITS/BigVGAN/nv-modelcard++/overview.md +126 -0
  32. GPT_SoVITS/BigVGAN/nv-modelcard++/privacy.md +14 -0
  33. GPT_SoVITS/BigVGAN/nv-modelcard++/safety.md +6 -0
  34. GPT_SoVITS/BigVGAN/requirements.txt +13 -0
  35. GPT_SoVITS/BigVGAN/train.py +716 -0
  36. GPT_SoVITS/BigVGAN/utils0.py +99 -0
  37. GPT_SoVITS/download.py +13 -0
  38. GPT_SoVITS/export_torch_script.py +861 -0
  39. GPT_SoVITS/export_torch_script_v3.py +1035 -0
  40. GPT_SoVITS/inference_cli.py +86 -0
  41. GPT_SoVITS/inference_gui.py +316 -0
  42. GPT_SoVITS/inference_webui.py +1280 -0
  43. GPT_SoVITS/inference_webui_fast.py +540 -0
  44. GPT_SoVITS/onnx_export.py +398 -0
  45. GPT_SoVITS/process_ckpt.py +124 -0
  46. GPT_SoVITS/s1_train.py +171 -0
  47. GPT_SoVITS/s2_train.py +680 -0
  48. GPT_SoVITS/s2_train_v3.py +467 -0
  49. GPT_SoVITS/s2_train_v3_lora.py +379 -0
  50. GPT_SoVITS/utils.py +361 -0
.dockerignore ADDED
@@ -0,0 +1,8 @@
+docs
+logs
+output
+reference
+SoVITS_weights
+GPT_weights
+TEMP
+.git
.gitignore ADDED
@@ -0,0 +1,200 @@
+.DS_Store
+.vscode
+__pycache__
+*.pyc
+env
+runtime
+.idea
+output
+logs
+reference
+GPT_weights
+SoVITS_weights
+GPT_weights_v2
+SoVITS_weights_v2
+GPT_weights_v3
+SoVITS_weights_v3
+TEMP
+weight.json
+ffmpeg*
+ffprobe*
+cfg.json
+speakers.json
+ref_audios
+tools/AP_BWE_main/24kto48k/*
+!tools/AP_BWE_main/24kto48k/readme.txt
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
Docker/damo.sha256 ADDED
@@ -0,0 +1,3 @@
+5bba782a5e9196166233b9ab12ba04cadff9ef9212b4ff6153ed9290ff679025  /workspace/tools/damo_asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pb
+b3be75be477f0780277f3bae0fe489f48718f585f3a6e45d7dd1fbb1a4255fc5  /workspace/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pb
+a5818bb9d933805a916eebe41eb41648f7f9caad30b4bd59d56f3ca135421916  /workspace/tools/damo_asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pb
Docker/download.py ADDED
@@ -0,0 +1,8 @@
+# Download moda ASR related models
+from modelscope import snapshot_download
+
+model_dir = snapshot_download(
+    "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
+)
+model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
+model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")
Docker/download.sh ADDED
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+set -Eeuo pipefail
+
+echo "Downloading models..."
+
+aria2c --disable-ipv6 --input-file /workspace/Docker/links.txt --dir /workspace --continue
+
+echo "Checking SHA256..."
+
+parallel --will-cite -a /workspace/Docker/links.sha256 "echo -n {} | sha256sum -c"
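Note (editorial, not part of the commit): the check above relies on GNU parallel and sha256sum being installed. As a rough illustration only, the same verification can be done with the Python standard library, assuming the checksum files keep the `<sha256> <path>` format shown in Docker/damo.sha256 and Docker/links.sha256:

```python
# Hypothetical helper, not included in the repository: verify files listed in a
# "<sha256> <path>" checksum file using hashlib from the standard library.
import hashlib
from pathlib import Path

def verify_checksums(checksum_file: str) -> bool:
    ok = True
    for line in Path(checksum_file).read_text().splitlines():
        if not line.strip():
            continue
        expected, path = line.split(maxsplit=1)
        digest = hashlib.sha256(Path(path.strip()).read_bytes()).hexdigest()
        if digest != expected:
            print(f"MISMATCH: {path}")
            ok = False
    return ok

# verify_checksums("/workspace/Docker/links.sha256")
```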
Docker/links.sha256 ADDED
@@ -0,0 +1,12 @@
+b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1  /workspace/GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
+fc579c1db3c1e21b721001cf99d7a584214280df19b002e200b630a34fa06eb8  /workspace/GPT_SoVITS/pretrained_models/s2D488k.pth
+020a014e1e01e550e510f2f61fae5e5f5b6aab40f15c22f1f12f724df507e835  /workspace/GPT_SoVITS/pretrained_models/s2G488k.pth
+24164f129c66499d1346e2aa55f183250c223161ec2770c0da3d3b08cf432d3c  /workspace/GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
+e53a693acc59ace251d143d068096ae0d7b79e4b1b503fa84c9dcf576448c1d8  /workspace/GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
+39796caa5db18d7f9382d8ac997ac967bfd85f7761014bb807d2543cc844ef05  /workspace/tools/uvr5/uvr5_weights/HP2_all_vocals.pth
+45e6b65199e781b4a6542002699be9f19cd3d1cb7d1558bc2bfbcd84674dfe28  /workspace/tools/uvr5/uvr5_weights/HP3_all_vocals.pth
+5908891829634926119720241e8573d97cbeb8277110a7512bdb0bd7563258ee  /workspace/tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
+8c8fd1582f9aabc363e47af62ddb88df6cae7e064cae75bbf041a067a5e0aee2  /workspace/tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
+01376dd2a571bf3cb9cced680732726d2d732609d09216a610b0d110f133febe  /workspace/tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
+56aba59db3bcdd14a14464e62f3129698ecdea62eee0f003b9360923eb3ac79e  /workspace/tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
+233bb5c6aaa365e568659a0a81211746fa881f8f47f82d9e864fce1f7692db80  /workspace/tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
Docker/links.txt ADDED
@@ -0,0 +1,34 @@
+# GPT-SoVITS models
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s1bert25hz-2kh-longer-epoch%3D68e-step%3D50232.ckpt
+  out=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2D488k.pth
+  out=GPT_SoVITS/pretrained_models/s2D488k.pth
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2G488k.pth
+  out=GPT_SoVITS/pretrained_models/s2G488k.pth
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/config.json
+  out=GPT_SoVITS/pretrained_models/chinese-hubert-base/config.json
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/preprocessor_config.json
+  out=GPT_SoVITS/pretrained_models/chinese-hubert-base/preprocessor_config.json
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/pytorch_model.bin
+  out=GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/config.json
+  out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/config.json
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/pytorch_model.bin
+  out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
+https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/tokenizer.json
+  out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json
+# UVR5
+https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2_all_vocals.pth
+  out=tools/uvr5/uvr5_weights/HP2_all_vocals.pth
+https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP3_all_vocals.pth
+  out=tools/uvr5/uvr5_weights/HP3_all_vocals.pth
+https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5_only_main_vocal.pth
+  out=tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
+https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoAggressive.pth
+  out=tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
+https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoDeReverb.pth
+  out=tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
+https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoNormal.pth
+  out=tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
+https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
+  out=tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
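Note (editorial, not part of the commit): in the aria2c input-file format used above, each URL line is followed by an `out=` option giving the destination path relative to `--dir`. A minimal Python sketch of the same URL/destination pairing, assuming the file format shown above:

```python
# Illustrative only: pair each URL with its "out=" destination from an
# aria2c-style input file such as Docker/links.txt.
from pathlib import Path
import urllib.request

def parse_links(path: str = "Docker/links.txt"):
    url = None
    for raw in Path(path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("out="):
            yield url, line[len("out="):]
        else:
            url = line

# for url, dest in parse_links():
#     Path(dest).parent.mkdir(parents=True, exist_ok=True)
#     urllib.request.urlretrieve(url, dest)  # aria2c is faster; this is only an illustration
```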
Dockerfile ADDED
@@ -0,0 +1,42 @@
+# Base CUDA image
+FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
+
+LABEL maintainer="[email protected]"
+LABEL version="dev-20240209"
+LABEL description="Docker image for GPT-SoVITS"
+
+
+# Install 3rd party apps
+ENV DEBIAN_FRONTEND=noninteractive
+ENV TZ=Etc/UTC
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
+    git lfs install && \
+    rm -rf /var/lib/apt/lists/*
+
+# Copy only requirements.txt initially to leverage Docker cache
+WORKDIR /workspace
+COPY requirements.txt /workspace/
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Define a build-time argument for image type
+ARG IMAGE_TYPE=full
+
+# Conditional logic based on the IMAGE_TYPE argument
+# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
+COPY ./Docker /workspace/Docker
+# The "elite" image variant does not bundle the extra models
+RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
+    chmod +x /workspace/Docker/download.sh && \
+    /workspace/Docker/download.sh && \
+    python /workspace/Docker/download.py && \
+    python -m nltk.downloader averaged_perceptron_tagger cmudict; \
+    fi
+
+
+# Copy the rest of the application
+COPY . /workspace
+
+EXPOSE 9871 9872 9873 9874 9880
+
+CMD ["python", "webui.py"]
GPT_SoVITS/BigVGAN/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 NVIDIA CORPORATION.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
GPT_SoVITS/BigVGAN/README.md ADDED
@@ -0,0 +1,266 @@
+## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
+
+#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
+
+[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
+
+<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
+
+## News
+- **Sep 2024 (v2.4):**
+  - We have updated the pretrained checkpoints trained for 5M steps. This is the final release of the BigVGAN-v2 checkpoints.
+
+- **Jul 2024 (v2.3):**
+  - General refactor and code improvements for improved readability.
+  - Fully fused CUDA kernel of anti-aliased activation (upsampling + activation + downsampling) with inference speed benchmark.
+
+- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
+
+- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
+
+- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
+  - Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
+  - Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
+  - Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
+  - We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
+
+## Installation
+
+The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
+
+```shell
+conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
+conda activate bigvgan
+```
+
+Clone the repository and install dependencies:
+
+```shell
+git clone https://github.com/NVIDIA/BigVGAN
+cd BigVGAN
+pip install -r requirements.txt
+```
+
+## Inference Quickstart using 🤗 Hugging Face Hub
+
+The example below describes how to use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute a mel spectrogram from an input waveform, and generate a synthesized waveform using the mel spectrogram as the model's input.
+
+```python
+device = 'cuda'
+
+import torch
+import bigvgan
+import librosa
+from meldataset import get_mel_spectrogram
+
+# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
+model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
+
+# remove weight norm in the model and set to eval mode
+model.remove_weight_norm()
+model = model.eval().to(device)
+
+# load wav file and compute mel spectrogram
+wav_path = '/path/to/your/audio.wav'
+wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
+wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
+
+# compute mel spectrogram from the ground truth audio
+mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
+
+# generate waveform from mel
+with torch.inference_mode():
+    wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
+wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
+
+# you can convert the generated waveform to 16 bit linear PCM
+wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
+```
+
+## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
+
+You can run a local gradio demo using the commands below:
+
+```python
+pip install -r demo/requirements.txt
+python demo/app.py
+```
+
+## Training
+
+Create a symbolic link to the root of the dataset. The codebase uses filelists with paths relative to the dataset. Below are example commands for the LibriTTS dataset:
+
+```shell
+cd filelists/LibriTTS && \
+ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
+ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
+ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
+ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
+ln -s /path/to/your/LibriTTS/dev-other dev-other && \
+ln -s /path/to/your/LibriTTS/test-clean test-clean && \
+ln -s /path/to/your/LibriTTS/test-other test-other && \
+cd ../..
+```
+
+Train the BigVGAN model. Below is an example command for training BigVGAN-v2 using the LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
+
+```shell
+python train.py \
+    --config configs/bigvgan_v2_24khz_100band_256x.json \
+    --input_wavs_dir filelists/LibriTTS \
+    --input_training_file filelists/LibriTTS/train-full.txt \
+    --input_validation_file filelists/LibriTTS/val-full.txt \
+    --list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
+    --list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
+    --checkpoint_path exp/bigvgan_v2_24khz_100band_256x
+```
+
+## Synthesis
+
+Synthesize from the BigVGAN model. Below is an example command for generating audio from the model.
+It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
+
+```shell
+python inference.py \
+    --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
+    --input_wavs_dir /path/to/your/input_wav \
+    --output_dir /path/to/your/output_wav
+```
+
+`inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
+It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
+
+Make sure that the STFT hyperparameters for the mel spectrogram are the same as the model's, which are defined in `config.json` of the corresponding model.
+
+```shell
+python inference_e2e.py \
+    --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
+    --input_mels_dir /path/to/your/input_mel \
+    --output_dir /path/to/your/output_wav
+```
+
+## Using Custom CUDA Kernel for Synthesis
+
+You can apply the fast CUDA inference kernel by using the `use_cuda_kernel` parameter when instantiating BigVGAN:
+
+```python
+generator = BigVGAN(h, use_cuda_kernel=True)
+```
+
+You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
+
+When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
+
+Please make sure that both are installed on your system and that the `nvcc` on your system matches the version your PyTorch build is using.
+
+We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See the example command below and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
+
+```python
+python tests/test_cuda_vs_torch_model.py \
+    --checkpoint_file /path/to/your/bigvgan_generator.pt
+```
+
+```shell
+loading plain Pytorch BigVGAN
+...
+loading CUDA kernel BigVGAN with auto-build
+Detected CUDA files, patching ldflags
+Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
+Building extension module anti_alias_activation_cuda...
+...
+Loading extension module anti_alias_activation_cuda...
+...
+Loading '/path/to/your/bigvgan_generator.pt'
+...
+[Success] test CUDA fused vs. plain torch BigVGAN inference
+ > mean_difference=0.0007238413265440613
+...
+```
+
+If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if the `nvcc` installed on your system is compatible with your PyTorch version.
+
+## Pretrained Models
+
+We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
+One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
+
+| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
+| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
+| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
+| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
+| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
+
+The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on the LibriTTS dataset.
+We also provide 22kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications.
+Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality.
+
+You can fine-tune the models by:
+
+1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
+2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
+
+## Training Details of BigVGAN-v2
+
+Compared to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
+
+Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models by adjusting `batch_size` depending on your GPUs.
+
+When training BigVGAN-v2 from scratch with a small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such a case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and then increasing it back to the default `500`.
+
+## Evaluation Results of BigVGAN-v2
+
+Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements in the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
+
+| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
+| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
+| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
+| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
+
+## Speed Benchmark
+
+Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
+
+| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
+|:---:|:---:|:---:|:---:|:---:|:---:|
+| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
+| | | True | 3916.5 | 163.2x | 1.3 |
+| | 2048 | False | 1899.6 | 79.2x | 1.7 |
+| | | True | 5330.1 | 222.1x | 1.7 |
+| | 16384 | False | 1973.8 | 82.2x | 5.0 |
+| | | True | 5761.7 | 240.1x | 4.4 |
+| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
+| | | True | 1598.1 | 66.6x | 1.3 |
+| | 2048 | False | 929.9 | 38.7x | 1.7 |
+| | | True | 1971.3 | 82.1x | 1.6 |
+| | 16384 | False | 943.4 | 39.3x | 5.0 |
+| | | True | 2026.5 | 84.4x | 3.9 |
+| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
+| | | True | 811.3 | 33.8x | 1.3 |
+| | 2048 | False | 576.5 | 24.0x | 1.7 |
+| | | True | 1023.0 | 42.6x | 1.5 |
+| | 16384 | False | 589.4 | 24.6x | 5.0 |
+| | | True | 1068.1 | 44.5x | 3.2 |
+
+## Acknowledgements
+
+We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
+
+## References
+
+- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
+- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
+- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
+- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
+- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
+- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
+- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
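Note (editorial, not part of the upstream README): the quickstart above leaves `wav_gen_int16` as int16 PCM samples in memory. A minimal sketch for writing them to disk with the Python standard library, assuming the quickstart's variables (`model`, `wav_gen_int16`) are in scope:

```python
import wave

# Write the generated int16 mono waveform (shape [1, T_time]) to a .wav file.
with wave.open("generated.wav", "wb") as f:
    f.setnchannels(1)                      # BigVGAN outputs mono audio
    f.setsampwidth(2)                      # int16 -> 2 bytes per sample
    f.setframerate(model.h.sampling_rate)  # e.g. 24000 for the 24 kHz checkpoint
    f.writeframes(wav_gen_int16.tobytes())
```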
GPT_SoVITS/BigVGAN/activations.py ADDED
@@ -0,0 +1,122 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+    """
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            alpha will be trained along with the rest of your model.
+        """
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # Initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # Log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # Linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake ∶= x + 1/a * sin^2 (xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snakebeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # Initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # Log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+            self.beta = Parameter(torch.zeros(in_features) * alpha)
+        else:  # Linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+            self.beta = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta ∶= x + 1/b * sin^2 (xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
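Note (editorial): a minimal usage sketch for the activations above, assuming the module is importable as `GPT_SoVITS.BigVGAN.activations` given this commit's layout; the channel dimension of the (B, C, T) input must match `in_features`:

```python
import torch
from GPT_SoVITS.BigVGAN.activations import SnakeBeta  # import path assumed from this repo layout

act = SnakeBeta(in_features=64, alpha_logscale=True)  # log-scale params, as in the "snakebeta" configs
x = torch.randn(2, 64, 100)  # (B, C, T)
y = act(x)                   # same shape as x: x + 1/b * sin^2(a * x), applied per channel
print(y.shape)               # torch.Size([2, 64, 100])
```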
GPT_SoVITS/BigVGAN/bigvgan.py ADDED
@@ -0,0 +1,461 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from typing import Optional, Union, Dict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import Conv1d, ConvTranspose1d
15
+ from torch.nn.utils import weight_norm, remove_weight_norm
16
+
17
+ from . import activations
18
+ from .utils0 import init_weights, get_padding
19
+ from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
20
+ from .env import AttrDict
21
+
22
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
23
+
24
+
25
+ def load_hparams_from_json(path) -> AttrDict:
26
+ with open(path) as f:
27
+ data = f.read()
28
+ return AttrDict(json.loads(data))
29
+
30
+
31
+ class AMPBlock1(torch.nn.Module):
32
+ """
33
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
34
+ AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
35
+
36
+ Args:
37
+ h (AttrDict): Hyperparameters.
38
+ channels (int): Number of convolution channels.
39
+ kernel_size (int): Size of the convolution kernel. Default is 3.
40
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
41
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ h: AttrDict,
47
+ channels: int,
48
+ kernel_size: int = 3,
49
+ dilation: tuple = (1, 3, 5),
50
+ activation: str = None,
51
+ ):
52
+ super().__init__()
53
+
54
+ self.h = h
55
+
56
+ self.convs1 = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ stride=1,
64
+ dilation=d,
65
+ padding=get_padding(kernel_size, d),
66
+ )
67
+ )
68
+ for d in dilation
69
+ ]
70
+ )
71
+ self.convs1.apply(init_weights)
72
+
73
+ self.convs2 = nn.ModuleList(
74
+ [
75
+ weight_norm(
76
+ Conv1d(
77
+ channels,
78
+ channels,
79
+ kernel_size,
80
+ stride=1,
81
+ dilation=1,
82
+ padding=get_padding(kernel_size, 1),
83
+ )
84
+ )
85
+ for _ in range(len(dilation))
86
+ ]
87
+ )
88
+ self.convs2.apply(init_weights)
89
+
90
+ self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
91
+
92
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
93
+ if self.h.get("use_cuda_kernel", False):
94
+ from .alias_free_activation.cuda.activation1d import (
95
+ Activation1d as CudaActivation1d,
96
+ )
97
+
98
+ Activation1d = CudaActivation1d
99
+ else:
100
+ Activation1d = TorchActivation1d
101
+
102
+ # Activation functions
103
+ if activation == "snake":
104
+ self.activations = nn.ModuleList(
105
+ [
106
+ Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
107
+ for _ in range(self.num_layers)
108
+ ]
109
+ )
110
+ elif activation == "snakebeta":
111
+ self.activations = nn.ModuleList(
112
+ [
113
+ Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
114
+ for _ in range(self.num_layers)
115
+ ]
116
+ )
117
+ else:
118
+ raise NotImplementedError(
119
+ "activation incorrectly specified. check the config file and look for 'activation'."
120
+ )
121
+
122
+ def forward(self, x):
123
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
124
+ for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
125
+ xt = a1(x)
126
+ xt = c1(xt)
127
+ xt = a2(xt)
128
+ xt = c2(xt)
129
+ x = xt + x
130
+
131
+ return x
132
+
133
+ def remove_weight_norm(self):
134
+ for l in self.convs1:
135
+ remove_weight_norm(l)
136
+ for l in self.convs2:
137
+ remove_weight_norm(l)
138
+
139
+
140
+ class AMPBlock2(torch.nn.Module):
141
+ """
142
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
143
+ Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
144
+
145
+ Args:
146
+ h (AttrDict): Hyperparameters.
147
+ channels (int): Number of convolution channels.
148
+ kernel_size (int): Size of the convolution kernel. Default is 3.
149
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
150
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ h: AttrDict,
156
+ channels: int,
157
+ kernel_size: int = 3,
158
+ dilation: tuple = (1, 3, 5),
159
+ activation: str = None,
160
+ ):
161
+ super().__init__()
162
+
163
+ self.h = h
164
+
165
+ self.convs = nn.ModuleList(
166
+ [
167
+ weight_norm(
168
+ Conv1d(
169
+ channels,
170
+ channels,
171
+ kernel_size,
172
+ stride=1,
173
+ dilation=d,
174
+ padding=get_padding(kernel_size, d),
175
+ )
176
+ )
177
+ for d in dilation
178
+ ]
179
+ )
180
+ self.convs.apply(init_weights)
181
+
182
+ self.num_layers = len(self.convs) # Total number of conv layers
183
+
184
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
185
+ if self.h.get("use_cuda_kernel", False):
186
+ from .alias_free_activation.cuda.activation1d import (
187
+ Activation1d as CudaActivation1d,
188
+ )
189
+
190
+ Activation1d = CudaActivation1d
191
+ else:
192
+ Activation1d = TorchActivation1d
193
+
194
+ # Activation functions
195
+ if activation == "snake":
196
+ self.activations = nn.ModuleList(
197
+ [
198
+ Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
199
+ for _ in range(self.num_layers)
200
+ ]
201
+ )
202
+ elif activation == "snakebeta":
203
+ self.activations = nn.ModuleList(
204
+ [
205
+ Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
206
+ for _ in range(self.num_layers)
207
+ ]
208
+ )
209
+ else:
210
+ raise NotImplementedError(
211
+ "activation incorrectly specified. check the config file and look for 'activation'."
212
+ )
213
+
214
+ def forward(self, x):
215
+ for c, a in zip(self.convs, self.activations):
216
+ xt = a(x)
217
+ xt = c(xt)
218
+ x = xt + x
219
+ return x
220
+
221
+ def remove_weight_norm(self):
222
+ for l in self.convs:
223
+ remove_weight_norm(l)
224
+
225
+
226
+ class BigVGAN(
227
+ torch.nn.Module,
228
+ PyTorchModelHubMixin,
229
+ # library_name="bigvgan",
230
+ # repo_url="https://github.com/NVIDIA/BigVGAN",
231
+ # docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
232
+ # pipeline_tag="audio-to-audio",
233
+ # license="mit",
234
+ # tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
235
+ ):
236
+ """
237
+ BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
238
+ New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
239
+
240
+ Args:
241
+ h (AttrDict): Hyperparameters.
242
+ use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
243
+
244
+ Note:
245
+ - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
246
+ - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
247
+ """
248
+
249
+ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
250
+ super().__init__()
251
+ self.h = h
252
+ self.h["use_cuda_kernel"] = use_cuda_kernel
253
+
254
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
255
+ if self.h.get("use_cuda_kernel", False):
256
+ from .alias_free_activation.cuda.activation1d import (
257
+ Activation1d as CudaActivation1d,
258
+ )
259
+
260
+ Activation1d = CudaActivation1d
261
+ else:
262
+ Activation1d = TorchActivation1d
263
+
264
+ self.num_kernels = len(h.resblock_kernel_sizes)
265
+ self.num_upsamples = len(h.upsample_rates)
266
+
267
+ # Pre-conv
268
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
269
+
270
+ # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
271
+ if h.resblock == "1":
272
+ resblock_class = AMPBlock1
273
+ elif h.resblock == "2":
274
+ resblock_class = AMPBlock2
275
+ else:
276
+ raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
277
+
278
+ # Transposed conv-based upsamplers. does not apply anti-aliasing
279
+ self.ups = nn.ModuleList()
280
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
281
+ self.ups.append(
282
+ nn.ModuleList(
283
+ [
284
+ weight_norm(
285
+ ConvTranspose1d(
286
+ h.upsample_initial_channel // (2**i),
287
+ h.upsample_initial_channel // (2 ** (i + 1)),
288
+ k,
289
+ u,
290
+ padding=(k - u) // 2,
291
+ )
292
+ )
293
+ ]
294
+ )
295
+ )
296
+
297
+ # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
298
+ self.resblocks = nn.ModuleList()
299
+ for i in range(len(self.ups)):
300
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
301
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
302
+ self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
303
+
304
+ # Post-conv
305
+ activation_post = (
306
+ activations.Snake(ch, alpha_logscale=h.snake_logscale)
307
+ if h.activation == "snake"
308
+ else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
309
+ )
310
+ if activation_post is None:
311
+ raise NotImplementedError(
312
+ "activation incorrectly specified. check the config file and look for 'activation'."
313
+ )
314
+
315
+ self.activation_post = Activation1d(activation=activation_post)
316
+
317
+ # Whether to use bias for the final conv_post. Default to True for backward compatibility
318
+ self.use_bias_at_final = h.get("use_bias_at_final", True)
319
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
320
+
321
+ # Weight initialization
322
+ for i in range(len(self.ups)):
323
+ self.ups[i].apply(init_weights)
324
+ self.conv_post.apply(init_weights)
325
+
326
+ # Final tanh activation. Defaults to True for backward compatibility
327
+ self.use_tanh_at_final = h.get("use_tanh_at_final", True)
328
+
329
+ def forward(self, x):
330
+ # Pre-conv
331
+ x = self.conv_pre(x)
332
+
333
+ for i in range(self.num_upsamples):
334
+ # Upsampling
335
+ for i_up in range(len(self.ups[i])):
336
+ x = self.ups[i][i_up](x)
337
+ # AMP blocks
338
+ xs = None
339
+ for j in range(self.num_kernels):
340
+ if xs is None:
341
+ xs = self.resblocks[i * self.num_kernels + j](x)
342
+ else:
343
+ xs += self.resblocks[i * self.num_kernels + j](x)
344
+ x = xs / self.num_kernels
345
+
346
+ # Post-conv
347
+ x = self.activation_post(x)
348
+ x = self.conv_post(x)
349
+ # Final tanh activation
350
+ if self.use_tanh_at_final:
351
+ x = torch.tanh(x)
352
+ else:
353
+ x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
354
+
355
+ return x
356
+
357
+ def remove_weight_norm(self):
358
+ try:
359
+ # print("Removing weight norm...")
360
+ for l in self.ups:
361
+ for l_i in l:
362
+ remove_weight_norm(l_i)
363
+ for l in self.resblocks:
364
+ l.remove_weight_norm()
365
+ remove_weight_norm(self.conv_pre)
366
+ remove_weight_norm(self.conv_post)
367
+ except ValueError:
368
+ print("[INFO] Model already removed weight norm. Skipping!")
369
+ pass
370
+
371
+ # Additional methods for huggingface_hub support
372
+ def _save_pretrained(self, save_directory: Path) -> None:
373
+ """Save weights and config.json from a Pytorch model to a local directory."""
374
+
375
+ model_path = save_directory / "bigvgan_generator.pt"
376
+ torch.save({"generator": self.state_dict()}, model_path)
377
+
378
+ config_path = save_directory / "config.json"
379
+ with open(config_path, "w") as config_file:
380
+ json.dump(self.h, config_file, indent=4)
381
+
382
+ @classmethod
383
+ def _from_pretrained(
384
+ cls,
385
+ *,
386
+ model_id: str,
387
+ revision: str,
388
+ cache_dir: str,
389
+ force_download: bool,
390
+ proxies: Optional[Dict],
391
+ resume_download: bool,
392
+ local_files_only: bool,
393
+ token: Union[str, bool, None],
394
+ map_location: str = "cpu", # Additional argument
395
+ strict: bool = False, # Additional argument
396
+ use_cuda_kernel: bool = False,
397
+ **model_kwargs,
398
+ ):
399
+ """Load Pytorch pretrained weights and return the loaded model."""
400
+
401
+ # Download and load hyperparameters (h) used by BigVGAN
402
+ if os.path.isdir(model_id):
403
+ # print("Loading config.json from local directory")
404
+ config_file = os.path.join(model_id, "config.json")
405
+ else:
406
+ config_file = hf_hub_download(
407
+ repo_id=model_id,
408
+ filename="config.json",
409
+ revision=revision,
410
+ cache_dir=cache_dir,
411
+ force_download=force_download,
412
+ proxies=proxies,
413
+ resume_download=resume_download,
414
+ token=token,
415
+ local_files_only=local_files_only,
416
+ )
417
+ h = load_hparams_from_json(config_file)
418
+
419
+ # instantiate BigVGAN using h
420
+ if use_cuda_kernel:
421
+ print(
422
+ "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
423
+ )
424
+ print(
425
+ "[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
426
+ )
427
+ print(
428
+ "[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
429
+ )
430
+ model = cls(h, use_cuda_kernel=use_cuda_kernel)
431
+
432
+ # Download and load pretrained generator weight
433
+ if os.path.isdir(model_id):
434
+ # print("Loading weights from local directory")
435
+ model_file = os.path.join(model_id, "bigvgan_generator.pt")
436
+ else:
437
+ # print(f"Loading weights from {model_id}")
438
+ model_file = hf_hub_download(
439
+ repo_id=model_id,
440
+ filename="bigvgan_generator.pt",
441
+ revision=revision,
442
+ cache_dir=cache_dir,
443
+ force_download=force_download,
444
+ proxies=proxies,
445
+ resume_download=resume_download,
446
+ token=token,
447
+ local_files_only=local_files_only,
448
+ )
449
+
450
+ checkpoint_dict = torch.load(model_file, map_location=map_location)
451
+
452
+ try:
453
+ model.load_state_dict(checkpoint_dict["generator"])
454
+ except RuntimeError:
455
+ print(
456
+ "[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
457
+ )
458
+ model.remove_weight_norm()
459
+ model.load_state_dict(checkpoint_dict["generator"])
460
+
461
+ return model
GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json ADDED
@@ -0,0 +1,45 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 32,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [4,4,2,2,2,2],
+    "upsample_kernel_sizes": [8,8,4,4,4,4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "segment_size": 8192,
+    "num_mels": 80,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 22050,
+
+    "fmin": 0,
+    "fmax": 8000,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
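Note (editorial): in these configs the product of `upsample_rates` equals `hop_size`, i.e. each mel frame is upsampled into exactly `hop_size` waveform samples (4·4·2·2·2·2 = 256 here). A small sanity-check sketch, assuming the config path from this commit:

```python
import json
from math import prod

# Path assumed from this commit's layout.
with open("GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json") as f:
    h = json.load(f)

# The generator's total upsampling factor must match the STFT hop size,
# so that one mel frame expands into hop_size output samples.
assert prod(h["upsample_rates"]) == h["hop_size"]  # 4*4*2*2*2*2 == 256
```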
GPT_SoVITS/BigVGAN/configs/bigvgan_24khz_100band.json ADDED
@@ -0,0 +1,45 @@
+{
+    "resblock": "1",
+    "num_gpus": 0,
+    "batch_size": 32,
+    "learning_rate": 0.0001,
+    "adam_b1": 0.8,
+    "adam_b2": 0.99,
+    "lr_decay": 0.9999996,
+    "seed": 1234,
+
+    "upsample_rates": [4,4,2,2,2,2],
+    "upsample_kernel_sizes": [8,8,4,4,4,4],
+    "upsample_initial_channel": 1536,
+    "resblock_kernel_sizes": [3,7,11],
+    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+
+    "activation": "snakebeta",
+    "snake_logscale": true,
+
+    "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+    "mpd_reshapes": [2, 3, 5, 7, 11],
+    "use_spectral_norm": false,
+    "discriminator_channel_mult": 1,
+
+    "segment_size": 8192,
+    "num_mels": 100,
+    "num_freq": 1025,
+    "n_fft": 1024,
+    "hop_size": 256,
+    "win_size": 1024,
+
+    "sampling_rate": 24000,
+
+    "fmin": 0,
+    "fmax": 12000,
+    "fmax_for_loss": null,
+
+    "num_workers": 4,
+
+    "dist_config": {
+        "dist_backend": "nccl",
+        "dist_url": "tcp://localhost:54321",
+        "world_size": 1
+    }
+}
GPT_SoVITS/BigVGAN/configs/bigvgan_base_22khz_80band.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 32,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [8,8,2,2],
12
+ "upsample_kernel_sizes": [16,16,4,4],
13
+ "upsample_initial_channel": 512,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "activation": "snakebeta",
18
+ "snake_logscale": true,
19
+
20
+ "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
21
+ "mpd_reshapes": [2, 3, 5, 7, 11],
22
+ "use_spectral_norm": false,
23
+ "discriminator_channel_mult": 1,
24
+
25
+ "segment_size": 8192,
26
+ "num_mels": 80,
27
+ "num_freq": 1025,
28
+ "n_fft": 1024,
29
+ "hop_size": 256,
30
+ "win_size": 1024,
31
+
32
+ "sampling_rate": 22050,
33
+
34
+ "fmin": 0,
35
+ "fmax": 8000,
36
+ "fmax_for_loss": null,
37
+
38
+ "num_workers": 4,
39
+
40
+ "dist_config": {
41
+ "dist_backend": "nccl",
42
+ "dist_url": "tcp://localhost:54321",
43
+ "world_size": 1
44
+ }
45
+ }
GPT_SoVITS/BigVGAN/configs/bigvgan_base_24khz_100band.json ADDED
@@ -0,0 +1,45 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 32,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [8,8,2,2],
12
+ "upsample_kernel_sizes": [16,16,4,4],
13
+ "upsample_initial_channel": 512,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "activation": "snakebeta",
18
+ "snake_logscale": true,
19
+
20
+ "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
21
+ "mpd_reshapes": [2, 3, 5, 7, 11],
22
+ "use_spectral_norm": false,
23
+ "discriminator_channel_mult": 1,
24
+
25
+ "segment_size": 8192,
26
+ "num_mels": 100,
27
+ "num_freq": 1025,
28
+ "n_fft": 1024,
29
+ "hop_size": 256,
30
+ "win_size": 1024,
31
+
32
+ "sampling_rate": 24000,
33
+
34
+ "fmin": 0,
35
+ "fmax": 12000,
36
+ "fmax_for_loss": null,
37
+
38
+ "num_workers": 4,
39
+
40
+ "dist_config": {
41
+ "dist_backend": "nccl",
42
+ "dist_url": "tcp://localhost:54321",
43
+ "world_size": 1
44
+ }
45
+ }
GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 4,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [4,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [8,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "use_tanh_at_final": false,
18
+ "use_bias_at_final": false,
19
+
20
+ "activation": "snakebeta",
21
+ "snake_logscale": true,
22
+
23
+ "use_cqtd_instead_of_mrd": true,
24
+ "cqtd_filters": 128,
25
+ "cqtd_max_filters": 1024,
26
+ "cqtd_filters_scale": 1,
27
+ "cqtd_dilations": [1, 2, 4],
28
+ "cqtd_hop_lengths": [512, 256, 256],
29
+ "cqtd_n_octaves": [9, 9, 9],
30
+ "cqtd_bins_per_octaves": [24, 36, 48],
31
+
32
+ "mpd_reshapes": [2, 3, 5, 7, 11],
33
+ "use_spectral_norm": false,
34
+ "discriminator_channel_mult": 1,
35
+
36
+ "use_multiscale_melloss": true,
37
+ "lambda_melloss": 15,
38
+
39
+ "clip_grad_norm": 500,
40
+
41
+ "segment_size": 65536,
42
+ "num_mels": 80,
43
+ "num_freq": 1025,
44
+ "n_fft": 1024,
45
+ "hop_size": 256,
46
+ "win_size": 1024,
47
+
48
+ "sampling_rate": 22050,
49
+
50
+ "fmin": 0,
51
+ "fmax": null,
52
+ "fmax_for_loss": null,
53
+
54
+ "num_workers": 4,
55
+
56
+ "dist_config": {
57
+ "dist_backend": "nccl",
58
+ "dist_url": "tcp://localhost:54321",
59
+ "world_size": 1
60
+ }
61
+ }
GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_fmax8k_256x.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 4,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [4,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [8,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "use_tanh_at_final": false,
18
+ "use_bias_at_final": false,
19
+
20
+ "activation": "snakebeta",
21
+ "snake_logscale": true,
22
+
23
+ "use_cqtd_instead_of_mrd": true,
24
+ "cqtd_filters": 128,
25
+ "cqtd_max_filters": 1024,
26
+ "cqtd_filters_scale": 1,
27
+ "cqtd_dilations": [1, 2, 4],
28
+ "cqtd_hop_lengths": [512, 256, 256],
29
+ "cqtd_n_octaves": [9, 9, 9],
30
+ "cqtd_bins_per_octaves": [24, 36, 48],
31
+
32
+ "mpd_reshapes": [2, 3, 5, 7, 11],
33
+ "use_spectral_norm": false,
34
+ "discriminator_channel_mult": 1,
35
+
36
+ "use_multiscale_melloss": true,
37
+ "lambda_melloss": 15,
38
+
39
+ "clip_grad_norm": 500,
40
+
41
+ "segment_size": 65536,
42
+ "num_mels": 80,
43
+ "num_freq": 1025,
44
+ "n_fft": 1024,
45
+ "hop_size": 256,
46
+ "win_size": 1024,
47
+
48
+ "sampling_rate": 22050,
49
+
50
+ "fmin": 0,
51
+ "fmax": 8000,
52
+ "fmax_for_loss": null,
53
+
54
+ "num_workers": 4,
55
+
56
+ "dist_config": {
57
+ "dist_backend": "nccl",
58
+ "dist_url": "tcp://localhost:54321",
59
+ "world_size": 1
60
+ }
61
+ }
GPT_SoVITS/BigVGAN/configs/bigvgan_v2_24khz_100band_256x.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 4,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [4,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [8,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "use_tanh_at_final": false,
18
+ "use_bias_at_final": false,
19
+
20
+ "activation": "snakebeta",
21
+ "snake_logscale": true,
22
+
23
+ "use_cqtd_instead_of_mrd": true,
24
+ "cqtd_filters": 128,
25
+ "cqtd_max_filters": 1024,
26
+ "cqtd_filters_scale": 1,
27
+ "cqtd_dilations": [1, 2, 4],
28
+ "cqtd_hop_lengths": [512, 256, 256],
29
+ "cqtd_n_octaves": [9, 9, 9],
30
+ "cqtd_bins_per_octaves": [24, 36, 48],
31
+
32
+ "mpd_reshapes": [2, 3, 5, 7, 11],
33
+ "use_spectral_norm": false,
34
+ "discriminator_channel_mult": 1,
35
+
36
+ "use_multiscale_melloss": true,
37
+ "lambda_melloss": 15,
38
+
39
+ "clip_grad_norm": 500,
40
+
41
+ "segment_size": 65536,
42
+ "num_mels": 100,
43
+ "num_freq": 1025,
44
+ "n_fft": 1024,
45
+ "hop_size": 256,
46
+ "win_size": 1024,
47
+
48
+ "sampling_rate": 24000,
49
+
50
+ "fmin": 0,
51
+ "fmax": null,
52
+ "fmax_for_loss": null,
53
+
54
+ "num_workers": 4,
55
+
56
+ "dist_config": {
57
+ "dist_backend": "nccl",
58
+ "dist_url": "tcp://localhost:54321",
59
+ "world_size": 1
60
+ }
61
+ }
GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_256x.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 4,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [4,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [8,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "use_tanh_at_final": false,
18
+ "use_bias_at_final": false,
19
+
20
+ "activation": "snakebeta",
21
+ "snake_logscale": true,
22
+
23
+ "use_cqtd_instead_of_mrd": true,
24
+ "cqtd_filters": 128,
25
+ "cqtd_max_filters": 1024,
26
+ "cqtd_filters_scale": 1,
27
+ "cqtd_dilations": [1, 2, 4],
28
+ "cqtd_hop_lengths": [512, 256, 256],
29
+ "cqtd_n_octaves": [9, 9, 9],
30
+ "cqtd_bins_per_octaves": [24, 36, 48],
31
+
32
+ "mpd_reshapes": [2, 3, 5, 7, 11],
33
+ "use_spectral_norm": false,
34
+ "discriminator_channel_mult": 1,
35
+
36
+ "use_multiscale_melloss": true,
37
+ "lambda_melloss": 15,
38
+
39
+ "clip_grad_norm": 500,
40
+
41
+ "segment_size": 65536,
42
+ "num_mels": 128,
43
+ "num_freq": 1025,
44
+ "n_fft": 1024,
45
+ "hop_size": 256,
46
+ "win_size": 1024,
47
+
48
+ "sampling_rate": 44100,
49
+
50
+ "fmin": 0,
51
+ "fmax": null,
52
+ "fmax_for_loss": null,
53
+
54
+ "num_workers": 4,
55
+
56
+ "dist_config": {
57
+ "dist_backend": "nccl",
58
+ "dist_url": "tcp://localhost:54321",
59
+ "world_size": 1
60
+ }
61
+ }
GPT_SoVITS/BigVGAN/configs/bigvgan_v2_44khz_128band_512x.json ADDED
@@ -0,0 +1,61 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 4,
5
+ "learning_rate": 0.0001,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.9999996,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [8,4,2,2,2,2],
12
+ "upsample_kernel_sizes": [16,8,4,4,4,4],
13
+ "upsample_initial_channel": 1536,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+
17
+ "use_tanh_at_final": false,
18
+ "use_bias_at_final": false,
19
+
20
+ "activation": "snakebeta",
21
+ "snake_logscale": true,
22
+
23
+ "use_cqtd_instead_of_mrd": true,
24
+ "cqtd_filters": 128,
25
+ "cqtd_max_filters": 1024,
26
+ "cqtd_filters_scale": 1,
27
+ "cqtd_dilations": [1, 2, 4],
28
+ "cqtd_hop_lengths": [512, 256, 256],
29
+ "cqtd_n_octaves": [9, 9, 9],
30
+ "cqtd_bins_per_octaves": [24, 36, 48],
31
+
32
+ "mpd_reshapes": [2, 3, 5, 7, 11],
33
+ "use_spectral_norm": false,
34
+ "discriminator_channel_mult": 1,
35
+
36
+ "use_multiscale_melloss": true,
37
+ "lambda_melloss": 15,
38
+
39
+ "clip_grad_norm": 500,
40
+
41
+ "segment_size": 65536,
42
+ "num_mels": 128,
43
+ "num_freq": 2049,
44
+ "n_fft": 2048,
45
+ "hop_size": 512,
46
+ "win_size": 2048,
47
+
48
+ "sampling_rate": 44100,
49
+
50
+ "fmin": 0,
51
+ "fmax": null,
52
+ "fmax_for_loss": null,
53
+
54
+ "num_workers": 4,
55
+
56
+ "dist_config": {
57
+ "dist_backend": "nccl",
58
+ "dist_url": "tcp://localhost:54321",
59
+ "world_size": 1
60
+ }
61
+ }
GPT_SoVITS/BigVGAN/discriminators.py ADDED
@@ -0,0 +1,625 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.nn as nn
11
+ from torch.nn import Conv2d
12
+ from torch.nn.utils import weight_norm, spectral_norm
13
+ from torchaudio.transforms import Spectrogram, Resample
14
+
15
+ from env import AttrDict
16
+ from utils import get_padding
17
+ import typing
18
+ from typing import List, Tuple
19
+
20
+
21
+ class DiscriminatorP(torch.nn.Module):
22
+ def __init__(
23
+ self,
24
+ h: AttrDict,
25
+ period: int,
26
+ kernel_size: int = 5,
27
+ stride: int = 3,
28
+ use_spectral_norm: bool = False,
29
+ ):
30
+ super().__init__()
31
+ self.period = period
32
+ self.d_mult = h.discriminator_channel_mult
33
+ norm_f = weight_norm if not use_spectral_norm else spectral_norm
34
+
35
+ self.convs = nn.ModuleList(
36
+ [
37
+ norm_f(
38
+ Conv2d(
39
+ 1,
40
+ int(32 * self.d_mult),
41
+ (kernel_size, 1),
42
+ (stride, 1),
43
+ padding=(get_padding(5, 1), 0),
44
+ )
45
+ ),
46
+ norm_f(
47
+ Conv2d(
48
+ int(32 * self.d_mult),
49
+ int(128 * self.d_mult),
50
+ (kernel_size, 1),
51
+ (stride, 1),
52
+ padding=(get_padding(5, 1), 0),
53
+ )
54
+ ),
55
+ norm_f(
56
+ Conv2d(
57
+ int(128 * self.d_mult),
58
+ int(512 * self.d_mult),
59
+ (kernel_size, 1),
60
+ (stride, 1),
61
+ padding=(get_padding(5, 1), 0),
62
+ )
63
+ ),
64
+ norm_f(
65
+ Conv2d(
66
+ int(512 * self.d_mult),
67
+ int(1024 * self.d_mult),
68
+ (kernel_size, 1),
69
+ (stride, 1),
70
+ padding=(get_padding(5, 1), 0),
71
+ )
72
+ ),
73
+ norm_f(
74
+ Conv2d(
75
+ int(1024 * self.d_mult),
76
+ int(1024 * self.d_mult),
77
+ (kernel_size, 1),
78
+ 1,
79
+ padding=(2, 0),
80
+ )
81
+ ),
82
+ ]
83
+ )
84
+ self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
85
+
86
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
87
+ fmap = []
88
+
89
+ # 1d to 2d
90
+ b, c, t = x.shape
91
+ if t % self.period != 0: # pad first
92
+ n_pad = self.period - (t % self.period)
93
+ x = F.pad(x, (0, n_pad), "reflect")
94
+ t = t + n_pad
95
+ x = x.view(b, c, t // self.period, self.period)
96
+
97
+ for l in self.convs:
98
+ x = l(x)
99
+ x = F.leaky_relu(x, 0.1)
100
+ fmap.append(x)
101
+ x = self.conv_post(x)
102
+ fmap.append(x)
103
+ x = torch.flatten(x, 1, -1)
104
+
105
+ return x, fmap
106
+
107
+
108
+ class MultiPeriodDiscriminator(torch.nn.Module):
109
+ def __init__(self, h: AttrDict):
110
+ super().__init__()
111
+ self.mpd_reshapes = h.mpd_reshapes
112
+ print(f"mpd_reshapes: {self.mpd_reshapes}")
113
+ self.discriminators = nn.ModuleList(
114
+ [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
115
+ )
116
+
117
+ def forward(
118
+ self, y: torch.Tensor, y_hat: torch.Tensor
119
+ ) -> Tuple[
120
+ List[torch.Tensor],
121
+ List[torch.Tensor],
122
+ List[List[torch.Tensor]],
123
+ List[List[torch.Tensor]],
124
+ ]:
125
+ y_d_rs = []
126
+ y_d_gs = []
127
+ fmap_rs = []
128
+ fmap_gs = []
129
+ for i, d in enumerate(self.discriminators):
130
+ y_d_r, fmap_r = d(y)
131
+ y_d_g, fmap_g = d(y_hat)
132
+ y_d_rs.append(y_d_r)
133
+ fmap_rs.append(fmap_r)
134
+ y_d_gs.append(y_d_g)
135
+ fmap_gs.append(fmap_g)
136
+
137
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
138
+
139
+
140
+ class DiscriminatorR(nn.Module):
141
+ def __init__(self, cfg: AttrDict, resolution: List[int]):
142
+ super().__init__()
143
+
144
+ self.resolution = resolution
145
+ assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
146
+ self.lrelu_slope = 0.1
147
+
148
+ norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
149
+ if hasattr(cfg, "mrd_use_spectral_norm"):
150
+ print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
151
+ norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
152
+ self.d_mult = cfg.discriminator_channel_mult
153
+ if hasattr(cfg, "mrd_channel_mult"):
154
+ print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
155
+ self.d_mult = cfg.mrd_channel_mult
156
+
157
+ self.convs = nn.ModuleList(
158
+ [
159
+ norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
160
+ norm_f(
161
+ nn.Conv2d(
162
+ int(32 * self.d_mult),
163
+ int(32 * self.d_mult),
164
+ (3, 9),
165
+ stride=(1, 2),
166
+ padding=(1, 4),
167
+ )
168
+ ),
169
+ norm_f(
170
+ nn.Conv2d(
171
+ int(32 * self.d_mult),
172
+ int(32 * self.d_mult),
173
+ (3, 9),
174
+ stride=(1, 2),
175
+ padding=(1, 4),
176
+ )
177
+ ),
178
+ norm_f(
179
+ nn.Conv2d(
180
+ int(32 * self.d_mult),
181
+ int(32 * self.d_mult),
182
+ (3, 9),
183
+ stride=(1, 2),
184
+ padding=(1, 4),
185
+ )
186
+ ),
187
+ norm_f(
188
+ nn.Conv2d(
189
+ int(32 * self.d_mult),
190
+ int(32 * self.d_mult),
191
+ (3, 3),
192
+ padding=(1, 1),
193
+ )
194
+ ),
195
+ ]
196
+ )
197
+ self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
198
+
199
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
200
+ fmap = []
201
+
202
+ x = self.spectrogram(x)
203
+ x = x.unsqueeze(1)
204
+ for l in self.convs:
205
+ x = l(x)
206
+ x = F.leaky_relu(x, self.lrelu_slope)
207
+ fmap.append(x)
208
+ x = self.conv_post(x)
209
+ fmap.append(x)
210
+ x = torch.flatten(x, 1, -1)
211
+
212
+ return x, fmap
213
+
214
+ def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
215
+ n_fft, hop_length, win_length = self.resolution
216
+ x = F.pad(
217
+ x,
218
+ (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
219
+ mode="reflect",
220
+ )
221
+ x = x.squeeze(1)
222
+ x = torch.stft(
223
+ x,
224
+ n_fft=n_fft,
225
+ hop_length=hop_length,
226
+ win_length=win_length,
227
+ center=False,
228
+ return_complex=True,
229
+ )
230
+ x = torch.view_as_real(x) # [B, F, TT, 2]
231
+ mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
232
+
233
+ return mag
234
+
235
+
236
+ class MultiResolutionDiscriminator(nn.Module):
237
+ def __init__(self, cfg, debug=False):
238
+ super().__init__()
239
+ self.resolutions = cfg.resolutions
240
+ assert len(self.resolutions) == 3, (
241
+ f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
242
+ )
243
+ self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
244
+
245
+ def forward(
246
+ self, y: torch.Tensor, y_hat: torch.Tensor
247
+ ) -> Tuple[
248
+ List[torch.Tensor],
249
+ List[torch.Tensor],
250
+ List[List[torch.Tensor]],
251
+ List[List[torch.Tensor]],
252
+ ]:
253
+ y_d_rs = []
254
+ y_d_gs = []
255
+ fmap_rs = []
256
+ fmap_gs = []
257
+
258
+ for i, d in enumerate(self.discriminators):
259
+ y_d_r, fmap_r = d(x=y)
260
+ y_d_g, fmap_g = d(x=y_hat)
261
+ y_d_rs.append(y_d_r)
262
+ fmap_rs.append(fmap_r)
263
+ y_d_gs.append(y_d_g)
264
+ fmap_gs.append(fmap_g)
265
+
266
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
267
+
268
+
269
+ # Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
270
+ # Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
271
+ # LICENSE is in incl_licenses directory.
272
+ class DiscriminatorB(nn.Module):
273
+ def __init__(
274
+ self,
275
+ window_length: int,
276
+ channels: int = 32,
277
+ hop_factor: float = 0.25,
278
+ bands: Tuple[Tuple[float, float], ...] = (
279
+ (0.0, 0.1),
280
+ (0.1, 0.25),
281
+ (0.25, 0.5),
282
+ (0.5, 0.75),
283
+ (0.75, 1.0),
284
+ ),
285
+ ):
286
+ super().__init__()
287
+ self.window_length = window_length
288
+ self.hop_factor = hop_factor
289
+ self.spec_fn = Spectrogram(
290
+ n_fft=window_length,
291
+ hop_length=int(window_length * hop_factor),
292
+ win_length=window_length,
293
+ power=None,
294
+ )
295
+ n_fft = window_length // 2 + 1
296
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
297
+ self.bands = bands
298
+ convs = lambda: nn.ModuleList(
299
+ [
300
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
301
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
302
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
303
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
304
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
305
+ ]
306
+ )
307
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
308
+
309
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
310
+
311
+ def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
312
+ # Remove DC offset
313
+ x = x - x.mean(dim=-1, keepdims=True)
314
+ # Peak normalize the volume of input audio
315
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
316
+ x = self.spec_fn(x)
317
+ x = torch.view_as_real(x)
318
+ x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
319
+ # Split into bands
320
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
321
+ return x_bands
322
+
323
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
324
+ x_bands = self.spectrogram(x.squeeze(1))
325
+ fmap = []
326
+ x = []
327
+
328
+ for band, stack in zip(x_bands, self.band_convs):
329
+ for i, layer in enumerate(stack):
330
+ band = layer(band)
331
+ band = torch.nn.functional.leaky_relu(band, 0.1)
332
+ if i > 0:
333
+ fmap.append(band)
334
+ x.append(band)
335
+
336
+ x = torch.cat(x, dim=-1)
337
+ x = self.conv_post(x)
338
+ fmap.append(x)
339
+
340
+ return x, fmap
341
+
342
+
343
+ # Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
344
+ # Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
345
+ # LICENSE is in incl_licenses directory.
346
+ class MultiBandDiscriminator(nn.Module):
347
+ def __init__(
348
+ self,
349
+ h,
350
+ ):
351
+ """
352
+ Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec,
353
+ and the modified code adapted from https://github.com/gemelo-ai/vocos.
354
+ """
355
+ super().__init__()
356
+ # fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
357
+ self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
358
+ self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
359
+
360
+ def forward(
361
+ self, y: torch.Tensor, y_hat: torch.Tensor
362
+ ) -> Tuple[
363
+ List[torch.Tensor],
364
+ List[torch.Tensor],
365
+ List[List[torch.Tensor]],
366
+ List[List[torch.Tensor]],
367
+ ]:
368
+ y_d_rs = []
369
+ y_d_gs = []
370
+ fmap_rs = []
371
+ fmap_gs = []
372
+
373
+ for d in self.discriminators:
374
+ y_d_r, fmap_r = d(x=y)
375
+ y_d_g, fmap_g = d(x=y_hat)
376
+ y_d_rs.append(y_d_r)
377
+ fmap_rs.append(fmap_r)
378
+ y_d_gs.append(y_d_g)
379
+ fmap_gs.append(fmap_g)
380
+
381
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
382
+
383
+
384
+ # Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
385
+ # LICENSE is in incl_licenses directory.
386
+ class DiscriminatorCQT(nn.Module):
387
+ def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
388
+ super().__init__()
389
+ self.cfg = cfg
390
+
391
+ self.filters = cfg["cqtd_filters"]
392
+ self.max_filters = cfg["cqtd_max_filters"]
393
+ self.filters_scale = cfg["cqtd_filters_scale"]
394
+ self.kernel_size = (3, 9)
395
+ self.dilations = cfg["cqtd_dilations"]
396
+ self.stride = (1, 2)
397
+
398
+ self.in_channels = cfg["cqtd_in_channels"]
399
+ self.out_channels = cfg["cqtd_out_channels"]
400
+ self.fs = cfg["sampling_rate"]
401
+ self.hop_length = hop_length
402
+ self.n_octaves = n_octaves
403
+ self.bins_per_octave = bins_per_octave
404
+
405
+ # Lazy-load
406
+ from nnAudio import features
407
+
408
+ self.cqt_transform = features.cqt.CQT2010v2(
409
+ sr=self.fs * 2,
410
+ hop_length=self.hop_length,
411
+ n_bins=self.bins_per_octave * self.n_octaves,
412
+ bins_per_octave=self.bins_per_octave,
413
+ output_format="Complex",
414
+ pad_mode="constant",
415
+ )
416
+
417
+ self.conv_pres = nn.ModuleList()
418
+ for _ in range(self.n_octaves):
419
+ self.conv_pres.append(
420
+ nn.Conv2d(
421
+ self.in_channels * 2,
422
+ self.in_channels * 2,
423
+ kernel_size=self.kernel_size,
424
+ padding=self.get_2d_padding(self.kernel_size),
425
+ )
426
+ )
427
+
428
+ self.convs = nn.ModuleList()
429
+
430
+ self.convs.append(
431
+ nn.Conv2d(
432
+ self.in_channels * 2,
433
+ self.filters,
434
+ kernel_size=self.kernel_size,
435
+ padding=self.get_2d_padding(self.kernel_size),
436
+ )
437
+ )
438
+
439
+ in_chs = min(self.filters_scale * self.filters, self.max_filters)
440
+ for i, dilation in enumerate(self.dilations):
441
+ out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
442
+ self.convs.append(
443
+ weight_norm(
444
+ nn.Conv2d(
445
+ in_chs,
446
+ out_chs,
447
+ kernel_size=self.kernel_size,
448
+ stride=self.stride,
449
+ dilation=(dilation, 1),
450
+ padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
451
+ )
452
+ )
453
+ )
454
+ in_chs = out_chs
455
+ out_chs = min(
456
+ (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
457
+ self.max_filters,
458
+ )
459
+ self.convs.append(
460
+ weight_norm(
461
+ nn.Conv2d(
462
+ in_chs,
463
+ out_chs,
464
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
465
+ padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
466
+ )
467
+ )
468
+ )
469
+
470
+ self.conv_post = weight_norm(
471
+ nn.Conv2d(
472
+ out_chs,
473
+ self.out_channels,
474
+ kernel_size=(self.kernel_size[0], self.kernel_size[0]),
475
+ padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
476
+ )
477
+ )
478
+
479
+ self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
480
+ self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
481
+
482
+ self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
483
+ if self.cqtd_normalize_volume:
484
+ print(
485
+ "[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
486
+ )
487
+
488
+ def get_2d_padding(
489
+ self,
490
+ kernel_size: typing.Tuple[int, int],
491
+ dilation: typing.Tuple[int, int] = (1, 1),
492
+ ):
493
+ return (
494
+ ((kernel_size[0] - 1) * dilation[0]) // 2,
495
+ ((kernel_size[1] - 1) * dilation[1]) // 2,
496
+ )
497
+
498
+ def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
499
+ fmap = []
500
+
501
+ if self.cqtd_normalize_volume:
502
+ # Remove DC offset
503
+ x = x - x.mean(dim=-1, keepdims=True)
504
+ # Peak normalize the volume of input audio
505
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
506
+
507
+ x = self.resample(x)
508
+
509
+ z = self.cqt_transform(x)
510
+
511
+ z_amplitude = z[:, :, :, 0].unsqueeze(1)
512
+ z_phase = z[:, :, :, 1].unsqueeze(1)
513
+
514
+ z = torch.cat([z_amplitude, z_phase], dim=1)
515
+ z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
516
+
517
+ latent_z = []
518
+ for i in range(self.n_octaves):
519
+ latent_z.append(
520
+ self.conv_pres[i](
521
+ z[
522
+ :,
523
+ :,
524
+ :,
525
+ i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
526
+ ]
527
+ )
528
+ )
529
+ latent_z = torch.cat(latent_z, dim=-1)
530
+
531
+ for i, l in enumerate(self.convs):
532
+ latent_z = l(latent_z)
533
+
534
+ latent_z = self.activation(latent_z)
535
+ fmap.append(latent_z)
536
+
537
+ latent_z = self.conv_post(latent_z)
538
+
539
+ return latent_z, fmap
540
+
541
+
542
+ class MultiScaleSubbandCQTDiscriminator(nn.Module):
543
+ def __init__(self, cfg: AttrDict):
544
+ super().__init__()
545
+
546
+ self.cfg = cfg
547
+ # Using get with defaults
548
+ self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
549
+ self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
550
+ self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
551
+ self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
552
+ self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
553
+ self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
554
+ # Multi-scale params to loop over
555
+ self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
556
+ self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
557
+ self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
558
+
559
+ self.discriminators = nn.ModuleList(
560
+ [
561
+ DiscriminatorCQT(
562
+ self.cfg,
563
+ hop_length=self.cfg["cqtd_hop_lengths"][i],
564
+ n_octaves=self.cfg["cqtd_n_octaves"][i],
565
+ bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
566
+ )
567
+ for i in range(len(self.cfg["cqtd_hop_lengths"]))
568
+ ]
569
+ )
570
+
571
+ def forward(
572
+ self, y: torch.Tensor, y_hat: torch.Tensor
573
+ ) -> Tuple[
574
+ List[torch.Tensor],
575
+ List[torch.Tensor],
576
+ List[List[torch.Tensor]],
577
+ List[List[torch.Tensor]],
578
+ ]:
579
+ y_d_rs = []
580
+ y_d_gs = []
581
+ fmap_rs = []
582
+ fmap_gs = []
583
+
584
+ for disc in self.discriminators:
585
+ y_d_r, fmap_r = disc(y)
586
+ y_d_g, fmap_g = disc(y_hat)
587
+ y_d_rs.append(y_d_r)
588
+ fmap_rs.append(fmap_r)
589
+ y_d_gs.append(y_d_g)
590
+ fmap_gs.append(fmap_g)
591
+
592
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
593
+
594
+
595
+ class CombinedDiscriminator(nn.Module):
596
+ """
597
+ Wrapper for chaining multiple discriminator architectures.
598
+ Example: combine mbd and cqtd as a single class
599
+ """
600
+
601
+ def __init__(self, list_discriminator: List[nn.Module]):
602
+ super().__init__()
603
+ self.discrimiantor = nn.ModuleList(list_discriminator)
604
+
605
+ def forward(
606
+ self, y: torch.Tensor, y_hat: torch.Tensor
607
+ ) -> Tuple[
608
+ List[torch.Tensor],
609
+ List[torch.Tensor],
610
+ List[List[torch.Tensor]],
611
+ List[List[torch.Tensor]],
612
+ ]:
613
+ y_d_rs = []
614
+ y_d_gs = []
615
+ fmap_rs = []
616
+ fmap_gs = []
617
+
618
+ for disc in self.discrimiantor:
619
+ y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
620
+ y_d_rs.extend(y_d_r)
621
+ fmap_rs.extend(fmap_r)
622
+ y_d_gs.extend(y_d_g)
623
+ fmap_gs.extend(fmap_g)
624
+
625
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
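A minimal sketch of wiring the discriminators together via `CombinedDiscriminator`, as BigVGAN-v2 training does with the multi-band and CQT discriminators. It assumes the flat upstream import layout (env.py, utils.py, discriminators.py importable as top-level modules) and that `torchaudio` and `nnAudio` are installed; the hyperparameter values are placeholders:

```python
import torch
from env import AttrDict
from discriminators import (
    CombinedDiscriminator,
    MultiBandDiscriminator,
    MultiScaleSubbandCQTDiscriminator,
)

# Only sampling_rate is strictly required here; the cqtd_*/mbd_* keys fall back to
# the .get(...) defaults defined above. 24000 is a placeholder value.
h = AttrDict({"sampling_rate": 24000})

disc = CombinedDiscriminator([MultiBandDiscriminator(h), MultiScaleSubbandCQTDiscriminator(h)])

y = torch.randn(2, 1, 8192)      # real waveforms [B, 1, T]
y_hat = torch.randn(2, 1, 8192)  # generated waveforms
y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(y, y_hat)
print(len(y_d_rs))  # one logit tensor per sub-discriminator
```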
GPT_SoVITS/BigVGAN/env.py ADDED
@@ -0,0 +1,18 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import os
5
+ import shutil
6
+
7
+
8
+ class AttrDict(dict):
9
+ def __init__(self, *args, **kwargs):
10
+ super(AttrDict, self).__init__(*args, **kwargs)
11
+ self.__dict__ = self
12
+
13
+
14
+ def build_env(config, config_name, path):
15
+ t_path = os.path.join(path, config_name)
16
+ if config != t_path:
17
+ os.makedirs(path, exist_ok=True)
18
+ shutil.copyfile(config, os.path.join(path, config_name))
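`AttrDict` simply aliases the instance `__dict__` to the dict itself, so config keys become attributes. A small sketch of the typical loading pattern (the config path is an assumption):

```python
import json
from env import AttrDict  # assumes env.py is importable from the working directory

with open("configs/bigvgan_24khz_100band.json") as f:
    h = AttrDict(json.load(f))

# The same value is reachable as an item or an attribute, which is how the
# training and inference code reads hyperparameters.
assert h["sampling_rate"] == h.sampling_rate == 24000
```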
GPT_SoVITS/BigVGAN/inference.py ADDED
@@ -0,0 +1,85 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from __future__ import absolute_import, division, print_function, unicode_literals
5
+
6
+ import os
7
+ import argparse
8
+ import json
9
+ import torch
10
+ import librosa
11
+ from utils import load_checkpoint
12
+ from meldataset import get_mel_spectrogram
13
+ from scipy.io.wavfile import write
14
+ from env import AttrDict
15
+ from meldataset import MAX_WAV_VALUE
16
+ from bigvgan import BigVGAN as Generator
17
+
18
+ h = None
19
+ device = None
20
+ torch.backends.cudnn.benchmark = False
21
+
22
+
23
+ def inference(a, h):
24
+ generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
25
+
26
+ state_dict_g = load_checkpoint(a.checkpoint_file, device)
27
+ generator.load_state_dict(state_dict_g["generator"])
28
+
29
+ filelist = os.listdir(a.input_wavs_dir)
30
+
31
+ os.makedirs(a.output_dir, exist_ok=True)
32
+
33
+ generator.eval()
34
+ generator.remove_weight_norm()
35
+ with torch.no_grad():
36
+ for i, filname in enumerate(filelist):
37
+ # Load the ground truth audio and resample if necessary
38
+ wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
39
+ wav = torch.FloatTensor(wav).to(device)
40
+ # Compute mel spectrogram from the ground truth audio
41
+ x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
42
+
43
+ y_g_hat = generator(x)
44
+
45
+ audio = y_g_hat.squeeze()
46
+ audio = audio * MAX_WAV_VALUE
47
+ audio = audio.cpu().numpy().astype("int16")
48
+
49
+ output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
50
+ write(output_file, h.sampling_rate, audio)
51
+ print(output_file)
52
+
53
+
54
+ def main():
55
+ print("Initializing Inference Process..")
56
+
57
+ parser = argparse.ArgumentParser()
58
+ parser.add_argument("--input_wavs_dir", default="test_files")
59
+ parser.add_argument("--output_dir", default="generated_files")
60
+ parser.add_argument("--checkpoint_file", required=True)
61
+ parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
62
+
63
+ a = parser.parse_args()
64
+
65
+ config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
66
+ with open(config_file) as f:
67
+ data = f.read()
68
+
69
+ global h
70
+ json_config = json.loads(data)
71
+ h = AttrDict(json_config)
72
+
73
+ torch.manual_seed(h.seed)
74
+ global device
75
+ if torch.cuda.is_available():
76
+ torch.cuda.manual_seed(h.seed)
77
+ device = torch.device("cuda")
78
+ else:
79
+ device = torch.device("cpu")
80
+
81
+ inference(a, h)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
GPT_SoVITS/BigVGAN/inference_e2e.py ADDED
@@ -0,0 +1,100 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from __future__ import absolute_import, division, print_function, unicode_literals
5
+
6
+ import glob
7
+ import os
8
+ import numpy as np
9
+ import argparse
10
+ import json
11
+ import torch
12
+ from scipy.io.wavfile import write
13
+ from env import AttrDict
14
+ from meldataset import MAX_WAV_VALUE
15
+ from bigvgan import BigVGAN as Generator
16
+
17
+ h = None
18
+ device = None
19
+ torch.backends.cudnn.benchmark = False
20
+
21
+
22
+ def load_checkpoint(filepath, device):
23
+ assert os.path.isfile(filepath)
24
+ print(f"Loading '{filepath}'")
25
+ checkpoint_dict = torch.load(filepath, map_location=device)
26
+ print("Complete.")
27
+ return checkpoint_dict
28
+
29
+
30
+ def scan_checkpoint(cp_dir, prefix):
31
+ pattern = os.path.join(cp_dir, prefix + "*")
32
+ cp_list = glob.glob(pattern)
33
+ if len(cp_list) == 0:
34
+ return ""
35
+ return sorted(cp_list)[-1]
36
+
37
+
38
+ def inference(a, h):
39
+ generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
40
+
41
+ state_dict_g = load_checkpoint(a.checkpoint_file, device)
42
+ generator.load_state_dict(state_dict_g["generator"])
43
+
44
+ filelist = os.listdir(a.input_mels_dir)
45
+
46
+ os.makedirs(a.output_dir, exist_ok=True)
47
+
48
+ generator.eval()
49
+ generator.remove_weight_norm()
50
+ with torch.no_grad():
51
+ for i, filname in enumerate(filelist):
52
+ # Load the mel spectrogram in .npy format
53
+ x = np.load(os.path.join(a.input_mels_dir, filname))
54
+ x = torch.FloatTensor(x).to(device)
55
+ if len(x.shape) == 2:
56
+ x = x.unsqueeze(0)
57
+
58
+ y_g_hat = generator(x)
59
+
60
+ audio = y_g_hat.squeeze()
61
+ audio = audio * MAX_WAV_VALUE
62
+ audio = audio.cpu().numpy().astype("int16")
63
+
64
+ output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
65
+ write(output_file, h.sampling_rate, audio)
66
+ print(output_file)
67
+
68
+
69
+ def main():
70
+ print("Initializing Inference Process..")
71
+
72
+ parser = argparse.ArgumentParser()
73
+ parser.add_argument("--input_mels_dir", default="test_mel_files")
74
+ parser.add_argument("--output_dir", default="generated_files_from_mel")
75
+ parser.add_argument("--checkpoint_file", required=True)
76
+ parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
77
+
78
+ a = parser.parse_args()
79
+
80
+ config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
81
+ with open(config_file) as f:
82
+ data = f.read()
83
+
84
+ global h
85
+ json_config = json.loads(data)
86
+ h = AttrDict(json_config)
87
+
88
+ torch.manual_seed(h.seed)
89
+ global device
90
+ if torch.cuda.is_available():
91
+ torch.cuda.manual_seed(h.seed)
92
+ device = torch.device("cuda")
93
+ else:
94
+ device = torch.device("cpu")
95
+
96
+ inference(a, h)
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()
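For `inference_e2e.py`, each input is a NumPy `.npy` file holding a mel spectrogram; a 2-D array of shape `[num_mels, T_frames]` is expected (a batch dimension is added automatically). A throwaway sketch of producing such a file with dummy data, using the script's default `--input_mels_dir`; the filename and shape are illustrative only:

```python
import os
import numpy as np

os.makedirs("test_mel_files", exist_ok=True)  # default --input_mels_dir
dummy_mel = np.random.randn(100, 250).astype(np.float32)  # 100 mel bins, 250 frames of dummy data
np.save(os.path.join("test_mel_files", "utt0001.npy"), dummy_mel)
```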
GPT_SoVITS/BigVGAN/loss.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from librosa.filters import mel as librosa_mel_fn
11
+ from scipy import signal
12
+
13
+ import typing
14
+ from typing import List, Tuple
15
+ from collections import namedtuple
16
+ import math
17
+ import functools
18
+
19
+
20
+ # Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
21
+ # LICENSE is in incl_licenses directory.
22
+ class MultiScaleMelSpectrogramLoss(nn.Module):
23
+ """Compute distance between mel spectrograms. Can be used
24
+ in a multi-scale way.
25
+
26
+ Parameters
27
+ ----------
28
+ n_mels : List[int]
29
+ Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
30
+ window_lengths : List[int], optional
31
+ Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
32
+ loss_fn : typing.Callable, optional
33
+ How to compare each loss, by default nn.L1Loss()
34
+ clamp_eps : float, optional
35
+ Clamp on the log magnitude, below, by default 1e-5
36
+ mag_weight : float, optional
37
+ Weight of raw magnitude portion of loss, by default 0.0 (no amplification of the magnitude part)
38
+ log_weight : float, optional
39
+ Weight of log magnitude portion of loss, by default 1.0
40
+ pow : float, optional
41
+ Power to raise magnitude to before taking log, by default 1.0
42
+ weight : float, optional
43
+ Weight of this loss, by default 1.0
44
+ match_stride : bool, optional
45
+ Whether to match the stride of convolutional layers, by default False
46
+
47
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
48
+ Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ sampling_rate: int,
54
+ n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
55
+ window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
56
+ loss_fn: typing.Callable = nn.L1Loss(),
57
+ clamp_eps: float = 1e-5,
58
+ mag_weight: float = 0.0,
59
+ log_weight: float = 1.0,
60
+ pow: float = 1.0,
61
+ weight: float = 1.0,
62
+ match_stride: bool = False,
63
+ mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
64
+ mel_fmax: List[float] = [None, None, None, None, None, None, None],
65
+ window_type: str = "hann",
66
+ ):
67
+ super().__init__()
68
+ self.sampling_rate = sampling_rate
69
+
70
+ STFTParams = namedtuple(
71
+ "STFTParams",
72
+ ["window_length", "hop_length", "window_type", "match_stride"],
73
+ )
74
+
75
+ self.stft_params = [
76
+ STFTParams(
77
+ window_length=w,
78
+ hop_length=w // 4,
79
+ match_stride=match_stride,
80
+ window_type=window_type,
81
+ )
82
+ for w in window_lengths
83
+ ]
84
+ self.n_mels = n_mels
85
+ self.loss_fn = loss_fn
86
+ self.clamp_eps = clamp_eps
87
+ self.log_weight = log_weight
88
+ self.mag_weight = mag_weight
89
+ self.weight = weight
90
+ self.mel_fmin = mel_fmin
91
+ self.mel_fmax = mel_fmax
92
+ self.pow = pow
93
+
94
+ @staticmethod
95
+ @functools.lru_cache(None)
96
+ def get_window(
97
+ window_type,
98
+ window_length,
99
+ ):
100
+ return signal.get_window(window_type, window_length)
101
+
102
+ @staticmethod
103
+ @functools.lru_cache(None)
104
+ def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
105
+ return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
106
+
107
+ def mel_spectrogram(
108
+ self,
109
+ wav,
110
+ n_mels,
111
+ fmin,
112
+ fmax,
113
+ window_length,
114
+ hop_length,
115
+ match_stride,
116
+ window_type,
117
+ ):
118
+ """
119
+ Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
120
+ https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
121
+ """
122
+ B, C, T = wav.shape
123
+
124
+ if match_stride:
125
+ assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
126
+ right_pad = math.ceil(T / hop_length) * hop_length - T
127
+ pad = (window_length - hop_length) // 2
128
+ else:
129
+ right_pad = 0
130
+ pad = 0
131
+
132
+ wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")
133
+
134
+ window = self.get_window(window_type, window_length)
135
+ window = torch.from_numpy(window).to(wav.device).float()
136
+
137
+ stft = torch.stft(
138
+ wav.reshape(-1, T),
139
+ n_fft=window_length,
140
+ hop_length=hop_length,
141
+ window=window,
142
+ return_complex=True,
143
+ center=True,
144
+ )
145
+ _, nf, nt = stft.shape
146
+ stft = stft.reshape(B, C, nf, nt)
147
+ if match_stride:
148
+ """
149
+ Drop the first two and last two frames, which are added because of padding. Now num_frames * hop_length = num_samples.
150
+ """
151
+ stft = stft[..., 2:-2]
152
+ magnitude = torch.abs(stft)
153
+
154
+ nf = magnitude.shape[2]
155
+ mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
156
+ mel_basis = torch.from_numpy(mel_basis).to(wav.device)
157
+ mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
158
+ mel_spectrogram = mel_spectrogram.transpose(-1, 2)
159
+
160
+ return mel_spectrogram
161
+
162
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
163
+ """Computes mel loss between an estimate and a reference
164
+ signal.
165
+
166
+ Parameters
167
+ ----------
168
+ x : torch.Tensor
169
+ Estimate signal
170
+ y : torch.Tensor
171
+ Reference signal
172
+
173
+ Returns
174
+ -------
175
+ torch.Tensor
176
+ Mel loss.
177
+ """
178
+
179
+ loss = 0.0
180
+ for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
181
+ kwargs = {
182
+ "n_mels": n_mels,
183
+ "fmin": fmin,
184
+ "fmax": fmax,
185
+ "window_length": s.window_length,
186
+ "hop_length": s.hop_length,
187
+ "match_stride": s.match_stride,
188
+ "window_type": s.window_type,
189
+ }
190
+
191
+ x_mels = self.mel_spectrogram(x, **kwargs)
192
+ y_mels = self.mel_spectrogram(y, **kwargs)
193
+ x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
194
+ y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
195
+
196
+ loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
197
+ loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
198
+
199
+ return loss
200
+
201
+
202
+ # Loss functions
203
+ def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
204
+ loss = 0
205
+ for dr, dg in zip(fmap_r, fmap_g):
206
+ for rl, gl in zip(dr, dg):
207
+ loss += torch.mean(torch.abs(rl - gl))
208
+
209
+ return loss * 2 # This equates to lambda=2.0 for the feature matching loss
210
+
211
+
212
+ def discriminator_loss(
213
+ disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
214
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
215
+ loss = 0
216
+ r_losses = []
217
+ g_losses = []
218
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
219
+ r_loss = torch.mean((1 - dr) ** 2)
220
+ g_loss = torch.mean(dg**2)
221
+ loss += r_loss + g_loss
222
+ r_losses.append(r_loss.item())
223
+ g_losses.append(g_loss.item())
224
+
225
+ return loss, r_losses, g_losses
226
+
227
+
228
+ def generator_loss(
229
+ disc_outputs: List[torch.Tensor],
230
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
231
+ loss = 0
232
+ gen_losses = []
233
+ for dg in disc_outputs:
234
+ l = torch.mean((1 - dg) ** 2)
235
+ gen_losses.append(l)
236
+ loss += l
237
+
238
+ return loss, gen_losses
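A quick sketch of evaluating the multi-scale mel loss on random tensors, assuming `loss.py` is importable as a top-level module; in training this term is scaled by `lambda_melloss` (15 in the v2 configs):

```python
import torch
from loss import MultiScaleMelSpectrogramLoss

mel_loss_fn = MultiScaleMelSpectrogramLoss(sampling_rate=24000)  # other arguments keep the class defaults

y = torch.randn(2, 1, 16384)        # reference waveforms [B, 1, T]
y_g_hat = torch.randn(2, 1, 16384)  # generated waveforms
loss_mel = mel_loss_fn(y_g_hat, y)  # scalar tensor
```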
GPT_SoVITS/BigVGAN/meldataset.py ADDED
@@ -0,0 +1,370 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import math
8
+ import os
9
+ import random
10
+ import torch
11
+ import torch.utils.data
12
+ import numpy as np
13
+ import librosa
14
+ from librosa.filters import mel as librosa_mel_fn
15
+ import pathlib
16
+ from tqdm import tqdm
17
+ from typing import List, Tuple, Optional
18
+ from .env import AttrDict
19
+
20
+ MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
21
+
22
+
23
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
24
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
25
+
26
+
27
+ def dynamic_range_decompression(x, C=1):
28
+ return np.exp(x) / C
29
+
30
+
31
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
32
+ return torch.log(torch.clamp(x, min=clip_val) * C)
33
+
34
+
35
+ def dynamic_range_decompression_torch(x, C=1):
36
+ return torch.exp(x) / C
37
+
38
+
39
+ def spectral_normalize_torch(magnitudes):
40
+ return dynamic_range_compression_torch(magnitudes)
41
+
42
+
43
+ def spectral_de_normalize_torch(magnitudes):
44
+ return dynamic_range_decompression_torch(magnitudes)
45
+
46
+
47
+ mel_basis_cache = {}
48
+ hann_window_cache = {}
49
+
50
+
51
+ def mel_spectrogram(
52
+ y: torch.Tensor,
53
+ n_fft: int,
54
+ num_mels: int,
55
+ sampling_rate: int,
56
+ hop_size: int,
57
+ win_size: int,
58
+ fmin: int,
59
+ fmax: int = None,
60
+ center: bool = False,
61
+ ) -> torch.Tensor:
62
+ """
63
+ Calculate the mel spectrogram of an input signal.
64
+ This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
65
+
66
+ Args:
67
+ y (torch.Tensor): Input signal.
68
+ n_fft (int): FFT size.
69
+ num_mels (int): Number of mel bins.
70
+ sampling_rate (int): Sampling rate of the input signal.
71
+ hop_size (int): Hop size for STFT.
72
+ win_size (int): Window size for STFT.
73
+ fmin (int): Minimum frequency for mel filterbank.
74
+ fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
75
+ center (bool): Whether to pad the input to center the frames. Default is False.
76
+
77
+ Returns:
78
+ torch.Tensor: Mel spectrogram.
79
+ """
80
+ if torch.min(y) < -1.0:
81
+ print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
82
+ if torch.max(y) > 1.0:
83
+ print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
84
+
85
+ device = y.device
86
+ key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
87
+
88
+ if key not in mel_basis_cache:
89
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
90
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
91
+ hann_window_cache[key] = torch.hann_window(win_size).to(device)
92
+
93
+ mel_basis = mel_basis_cache[key]
94
+ hann_window = hann_window_cache[key]
95
+
96
+ padding = (n_fft - hop_size) // 2
97
+ y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
98
+
99
+ spec = torch.stft(
100
+ y,
101
+ n_fft,
102
+ hop_length=hop_size,
103
+ win_length=win_size,
104
+ window=hann_window,
105
+ center=center,
106
+ pad_mode="reflect",
107
+ normalized=False,
108
+ onesided=True,
109
+ return_complex=True,
110
+ )
111
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
112
+
113
+ mel_spec = torch.matmul(mel_basis, spec)
114
+ mel_spec = spectral_normalize_torch(mel_spec)
115
+
116
+ return mel_spec
117
+
118
+
119
+ def get_mel_spectrogram(wav, h):
120
+ """
121
+ Generate mel spectrogram from a waveform using given hyperparameters.
122
+
123
+ Args:
124
+ wav (torch.Tensor): Input waveform.
125
+ h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
126
+
127
+ Returns:
128
+ torch.Tensor: Mel spectrogram.
129
+ """
130
+ return mel_spectrogram(
131
+ wav,
132
+ h.n_fft,
133
+ h.num_mels,
134
+ h.sampling_rate,
135
+ h.hop_size,
136
+ h.win_size,
137
+ h.fmin,
138
+ h.fmax,
139
+ )
140
+
141
+
142
+ def get_dataset_filelist(a):
143
+ training_files = []
144
+ validation_files = []
145
+ list_unseen_validation_files = []
146
+
147
+ with open(a.input_training_file, "r", encoding="utf-8") as fi:
148
+ training_files = [
149
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
150
+ ]
151
+ print(f"first training file: {training_files[0]}")
152
+
153
+ with open(a.input_validation_file, "r", encoding="utf-8") as fi:
154
+ validation_files = [
155
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
156
+ ]
157
+ print(f"first validation file: {validation_files[0]}")
158
+
159
+ for i in range(len(a.list_input_unseen_validation_file)):
160
+ with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
161
+ unseen_validation_files = [
162
+ os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
163
+ for x in fi.read().split("\n")
164
+ if len(x) > 0
165
+ ]
166
+ print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
167
+ list_unseen_validation_files.append(unseen_validation_files)
168
+
169
+ return training_files, validation_files, list_unseen_validation_files
170
+
171
+
172
+ class MelDataset(torch.utils.data.Dataset):
173
+ def __init__(
174
+ self,
175
+ training_files: List[str],
176
+ hparams: AttrDict,
177
+ segment_size: int,
178
+ n_fft: int,
179
+ num_mels: int,
180
+ hop_size: int,
181
+ win_size: int,
182
+ sampling_rate: int,
183
+ fmin: int,
184
+ fmax: Optional[int],
185
+ split: bool = True,
186
+ shuffle: bool = True,
187
+ device: str = None,
188
+ fmax_loss: Optional[int] = None,
189
+ fine_tuning: bool = False,
190
+ base_mels_path: str = None,
191
+ is_seen: bool = True,
192
+ ):
193
+ self.audio_files = training_files
194
+ random.seed(1234)
195
+ if shuffle:
196
+ random.shuffle(self.audio_files)
197
+ self.hparams = hparams
198
+ self.is_seen = is_seen
199
+ if self.is_seen:
200
+ self.name = pathlib.Path(self.audio_files[0]).parts[0]
201
+ else:
202
+ self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
203
+
204
+ self.segment_size = segment_size
205
+ self.sampling_rate = sampling_rate
206
+ self.split = split
207
+ self.n_fft = n_fft
208
+ self.num_mels = num_mels
209
+ self.hop_size = hop_size
210
+ self.win_size = win_size
211
+ self.fmin = fmin
212
+ self.fmax = fmax
213
+ self.fmax_loss = fmax_loss
214
+ self.device = device
215
+ self.fine_tuning = fine_tuning
216
+ self.base_mels_path = base_mels_path
217
+
218
+ print("[INFO] checking dataset integrity...")
219
+ for i in tqdm(range(len(self.audio_files))):
220
+ assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
221
+
222
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
223
+ try:
224
+ filename = self.audio_files[index]
225
+
226
+ # Use librosa.load that ensures loading waveform into mono with [-1, 1] float values
227
+ # Audio is ndarray with shape [T_time]. Disable auto-resampling here to minimize overhead
228
+ # The on-the-fly resampling during training will be done only for the obtained random chunk
229
+ audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True)
230
+
231
+ # Main logic that uses <mel, audio> pair for training BigVGAN
232
+ if not self.fine_tuning:
233
+ if self.split: # Training step
234
+ # Obtain randomized audio chunk
235
+ if source_sampling_rate != self.sampling_rate:
236
+ # Adjust segment size to crop if the source sr is different
237
+ target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
238
+ else:
239
+ target_segment_size = self.segment_size
240
+
241
+ # Compute upper bound index for the random chunk
242
+ random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
243
+
244
+ # Crop or pad audio to obtain random chunk with target_segment_size
245
+ if audio.shape[0] >= target_segment_size:
246
+ audio_start = random.randint(0, random_chunk_upper_bound)
247
+ audio = audio[audio_start : audio_start + target_segment_size]
248
+ else:
249
+ audio = np.pad(
250
+ audio,
251
+ (0, target_segment_size - audio.shape[0]),
252
+ mode="constant",
253
+ )
254
+
255
+ # Resample audio chunk to self.sampling_rate
256
+ if source_sampling_rate != self.sampling_rate:
257
+ audio = librosa.resample(
258
+ audio,
259
+ orig_sr=source_sampling_rate,
260
+ target_sr=self.sampling_rate,
261
+ )
262
+ if audio.shape[0] > self.segment_size:
263
+ # trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384)
264
+ audio = audio[: self.segment_size]
265
+
266
+ else: # Validation step
267
+ # Resample full audio clip to target sampling rate
268
+ if source_sampling_rate != self.sampling_rate:
269
+ audio = librosa.resample(
270
+ audio,
271
+ orig_sr=source_sampling_rate,
272
+ target_sr=self.sampling_rate,
273
+ )
274
+ # Trim last elements to match audio length to self.hop_size * n for evaluation
275
+ if (audio.shape[0] % self.hop_size) != 0:
276
+ audio = audio[: -(audio.shape[0] % self.hop_size)]
277
+
278
+ # BigVGAN is trained using volume-normalized waveform
279
+ audio = librosa.util.normalize(audio) * 0.95
280
+
281
+ # Cast ndarray to torch tensor
282
+ audio = torch.FloatTensor(audio)
283
+ audio = audio.unsqueeze(0) # [B(1), self.segment_size]
284
+
285
+ # Compute mel spectrogram corresponding to audio
286
+ mel = mel_spectrogram(
287
+ audio,
288
+ self.n_fft,
289
+ self.num_mels,
290
+ self.sampling_rate,
291
+ self.hop_size,
292
+ self.win_size,
293
+ self.fmin,
294
+ self.fmax,
295
+ center=False,
296
+ ) # [B(1), self.num_mels, self.segment_size // self.hop_size]
297
+
298
+ # Fine-tuning logic that uses pre-computed mel. Example: Using TTS model-generated mel as input
299
+ else:
300
+ # For fine-tuning, assert that the waveform is in the defined sampling_rate
301
+ # To keep fine-tuning fool-proof, on-the-fly resampling is not supported (the dataset should have been prepared properly)
302
+ assert source_sampling_rate == self.sampling_rate, (
303
+ f"For fine_tuning, waveform must be in the specified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
304
+ )
305
+
306
+ # Cast ndarray to torch tensor
307
+ audio = torch.FloatTensor(audio)
308
+ audio = audio.unsqueeze(0) # [B(1), T_time]
309
+
310
+ # Load pre-computed mel from disk
311
+ mel = np.load(
312
+ os.path.join(
313
+ self.base_mels_path,
314
+ os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
315
+ )
316
+ )
317
+ mel = torch.from_numpy(mel)
318
+
319
+ if len(mel.shape) < 3:
320
+ mel = mel.unsqueeze(0) # ensure [B, C, T]
321
+
322
+ if self.split:
323
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
324
+
325
+ if audio.size(1) >= self.segment_size:
326
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
327
+ mel = mel[:, :, mel_start : mel_start + frames_per_seg]
328
+ audio = audio[
329
+ :,
330
+ mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
331
+ ]
332
+
333
+ # Pad pre-computed mel and audio to matching lengths so that fine-tuning runs without error.
334
+ # NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
335
+ # To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
336
+ mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
337
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
338
+
339
+ # Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
340
+ mel_loss = mel_spectrogram(
341
+ audio,
342
+ self.n_fft,
343
+ self.num_mels,
344
+ self.sampling_rate,
345
+ self.hop_size,
346
+ self.win_size,
347
+ self.fmin,
348
+ self.fmax_loss,
349
+ center=False,
350
+ ) # [B(1), self.num_mels, self.segment_size // self.hop_size]
351
+
352
+ # Shape sanity checks
353
+ assert (
354
+ audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
355
+ ), (
356
+ f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
357
+ )
358
+
359
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
360
+
361
+ # If it encounters error during loading the data, skip this sample and load random other sample to the batch
362
+ except Exception as e:
363
+ if self.fine_tuning:
364
+ raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
365
+ else:
366
+ print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
367
+ return self[random.randrange(len(self))]
368
+
369
+ def __len__(self):
370
+ return len(self.audio_files)
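For reference, a minimal sketch of how one item from this dataset is shaped, assuming hypothetical hyperparameters (`segment_size=16384`, `hop_size=256`, `num_mels=100`); the relationships follow the comments and asserts in `__getitem__` above:

```python
# Shape bookkeeping for one training item (values are hypothetical; the real
# ones come from the training config `h` passed to the MelDataset constructor).
segment_size, hop_size, num_mels = 16384, 256, 100
frames = segment_size // hop_size  # 64 mel frames per audio chunk

# mel, audio, filename, mel_loss = dataset[0]
# mel.shape      -> (num_mels, frames)   mel computed with fmax, used as generator input
# audio.shape    -> (segment_size,)      volume-normalized waveform chunk
# mel_loss.shape -> (num_mels, frames)   mel computed with fmax_loss for the regression loss
assert frames * hop_size == segment_size
```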
GPT_SoVITS/BigVGAN/nv-modelcard++/.gitkeep ADDED
@@ -0,0 +1 @@
 
 
1
+
GPT_SoVITS/BigVGAN/nv-modelcard++/bias.md ADDED
@@ -0,0 +1,4 @@
1
+ | Field | Response |
2
+ | :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- |
3
+ | Participation considerations from adversely impacted groups and protected classes in model design and testing: | None |
4
+ | Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. |
GPT_SoVITS/BigVGAN/nv-modelcard++/explainability.md ADDED
@@ -0,0 +1,13 @@
1
+ | Field | Response |
2
+ | :---------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
3
+ | Intended Application & Domain: | Generating waveform from mel spectrogram. |
4
+ | Model Type: | Convolutional Neural Network (CNN) |
5
+ | Intended Users: | This model is intended for developers to synthesize and generate waveforms from the AI-generated mel spectrograms. |
6
+ | Output: | Audio Waveform |
7
+ | Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. |
8
+ | Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable |
9
+ | Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. |
10
+ | Verified to have met prescribed NVIDIA quality standards: | Yes |
11
+ | Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voice/Unvoiced F1 Score (V/UV F1) |
12
+ | Potential Known Risks: | This model may generate low-quality or distorted soundwaves. |
13
+ | Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE |
GPT_SoVITS/BigVGAN/nv-modelcard++/overview.md ADDED
@@ -0,0 +1,126 @@
1
+ # Model Overview
2
+
3
+ ## Description:
4
+
5
+ BigVGAN is a generative AI model specialized in synthesizing audio waveforms using Mel spectrograms as input.
6
+
7
+ <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
8
+
9
+ BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolution followed by multiple residual dilated convolution layers.
10
+
11
+ BigVGAN consists of a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP is specialized in synthesizing high-frequency and periodic soundwaves drawing inspiration from audio signal processing principles.
12
+
13
+ It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture in generating periodic soundwaves. It also applies anti-aliasing filters to reduce undesired artifacts in the generated waveforms. <br>
14
+
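For readers unfamiliar with Snake, the activation is commonly written as snake_α(x) = x + (1/α)·sin²(αx), with a learnable per-channel α controlling the frequency of the periodic term; below is a minimal sketch, not the exact `activations.py` implementation shipped in this repository:

```python
import torch

def snake(x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    # Snake activation: x + (1/alpha) * sin^2(alpha * x).
    # A small epsilon keeps the division stable when alpha is near zero.
    return x + (1.0 / (alpha + 1e-9)) * torch.sin(alpha * x) ** 2

x = torch.randn(1, 80, 100)   # [batch, channels, time]
alpha = torch.ones(1, 80, 1)  # one learnable alpha per channel (illustrative init)
y = snake(x, alpha)
```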
15
+ This model is ready for commercial use.<br>
16
+
17
+ ## References(s):
18
+
19
+ - [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658) <br>
20
+ - [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) <br>
21
+ - [Audio Demo](https://bigvgan-demo.github.io/) <br>
22
+
23
+ ## Model Architecture:
24
+
25
+ **Architecture Type:** Convolutional Neural Network (CNN) <br>
26
+ **Network Architecture:** You can see the details of this model on this link: https://github.com/NVIDIA/BigVGAN and the related paper can be found here: https://arxiv.org/abs/2206.04658<br>
27
+ **Model Version:** 2.0 <br>
28
+
29
+ ## Input:
30
+
31
+ **Input Type:** Audio <br>
32
+ **Input Format:** Mel Spectrogram <br>
33
+ **Input Parameters:** None <br>
34
+ **Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrarily long `frames`, limited only by available GPU memory.
35
+
36
+ ## Output:
37
+
38
+ **Output Type:** Audio <br>
39
+ **Output Format:** Audio Waveform <br>
40
+ **Output Parameters:** None <br>
41
+ **Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channel and `time` refers to the temporal length. `time` is a fixed integer multiple of the input `frames`, determined by the upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consists of float values in the range `[-1, 1]`.
42
+
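As a concrete illustration of the input/output shape relationship stated above (a sketch; the factor 256 matches the `*_256x` configs and is only an example):

```python
# A mel input of shape [batch, num_mels, frames] yields mono audio of shape
# [batch, 1, frames * upsampling_ratio] with float values in [-1, 1].
batch, num_mels, frames = 1, 100, 500
upsampling_ratio = 256                      # example value (hop size of the mel analysis)
audio_samples = frames * upsampling_ratio   # 128000 samples, about 5.3 s at 24 kHz
print(batch, 1, audio_samples)
```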
43
+ ## Software Integration:
44
+
45
+ **Runtime Engine(s):** PyTorch
46
+
47
+ **Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta <br>
48
+
49
+ ## Preferred/Supported Operating System(s):
50
+
51
+ Linux
52
+
53
+ ## Model Version(s):
54
+
55
+ v2.0
56
+
57
+ ## Training, Testing, and Evaluation Datasets:
58
+
59
+ ### Training Dataset:
60
+
61
+ The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
62
+
63
+ **Links:**
64
+
65
+ - [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629)
66
+ - [AudioCaps](https://audiocaps.github.io/)
67
+ - [AudioSet](https://research.google.com/audioset/index.html)
68
+ - [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent)
69
+ - [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440)
70
+ - [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection)
71
+ - [FSDnoisy18k](https://zenodo.org/records/2529934)
72
+ - [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384)
73
+ - [Greatest Hits dataset](https://andrewowens.com/vis/)
74
+ - [GTZAN](https://ieeexplore.ieee.org/document/1021072)
75
+ - [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus)
76
+ - [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194)
77
+ - [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/)
78
+ - [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench)
79
+ - [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps)
80
+ - [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset)
81
+ - [NSynth](https://magenta.tensorflow.org/datasets/nsynth)
82
+ - [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset)
83
+ - [Audio Piano Triads Dataset](https://zenodo.org/records/4740877)
84
+ - [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097)
85
+ - [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543)
86
+ - [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433)
87
+ - [WavText5K](https://github.com/microsoft/WavText5K)
88
+ - [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10)
89
+ - [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/)
90
+ - [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/)
91
+ - [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875)
92
+ - [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60)
93
+ - [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/)
94
+ - [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353)
95
+ - [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/)
96
+ - [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus)
97
+ - [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443)
98
+
99
+ \*\* Data Collection Method by dataset <br>
100
+
101
+ - Human <br>
102
+
103
+ \*\* Labeling Method by dataset (for those with labels) <br>
104
+
105
+ - Hybrid: Automated, Human, Unknown <br>
106
+
107
+ ### Evaluating Dataset:
108
+
109
+ Properties: The audio generation quality of BigVGAN is evaluated using `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets include English-language speech with an equal balance of genders.
110
+
111
+ \*\* Data Collection Method by dataset <br>
112
+
113
+ - Human <br>
114
+
115
+ \*\* Labeling Method by dataset <br>
116
+
117
+ - Automated <br>
118
+
119
+ ## Inference:
120
+
121
+ **Engine:** PyTorch <br>
122
+ **Test Hardware:** NVIDIA A100 GPU <br>
123
+
124
+ ## Ethical Considerations:
125
+
126
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
GPT_SoVITS/BigVGAN/nv-modelcard++/privacy.md ADDED
@@ -0,0 +1,14 @@
1
+ | Field | Response |
2
+ | :------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------- |
3
+ | Generatable or reverse engineerable personal information? | None |
4
+ | Protected class data used to create this model? | None |
5
+ | Was consent obtained for any personal data used? | Not Applicable (No Personal Data) |
6
+ | How often is dataset reviewed? | Before Release |
7
+ | Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable |
8
+ | If personal data was collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable |
9
+ | If personal data was collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable |
10
+ | If personal data was collected for the development of this AI model, was it minimized to only what was required? | Not Applicable |
11
+ | Is data in dataset traceable? | Yes |
12
+ | Is there provenance for all datasets used in training? | Yes |
13
+ | Does data labeling (annotation, metadata) comply with privacy laws? | Yes |
14
+ | Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. |
GPT_SoVITS/BigVGAN/nv-modelcard++/safety.md ADDED
@@ -0,0 +1,6 @@
1
+ | Field | Response |
2
+ | :---------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
3
+ | Model Application(s): | Synthetic Audio Generation |
4
+ | Describe the life critical impact (if present). | Not Applicable |
5
+ | Use Case Restrictions: | None |
6
+ | Model and dataset restrictions: | The principle of least privilege (PoLP) is applied to limit access for dataset generation and model development. Dataset access is restricted during training, and dataset license constraints are adhered to. |
GPT_SoVITS/BigVGAN/requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ torch
2
+ numpy
3
+ librosa>=0.8.1
4
+ scipy
5
+ tensorboard
6
+ soundfile
7
+ matplotlib
8
+ pesq
9
+ auraloss
10
+ tqdm
11
+ nnAudio
12
+ ninja
13
+ huggingface_hub>=0.23.4
GPT_SoVITS/BigVGAN/train.py ADDED
@@ -0,0 +1,716 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+
8
+ import warnings
9
+
10
+ warnings.simplefilter(action="ignore", category=FutureWarning)
11
+ import itertools
12
+ import os
13
+ import time
14
+ import argparse
15
+ import json
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ from torch.utils.data import DistributedSampler, DataLoader
20
+ import torch.multiprocessing as mp
21
+ from torch.distributed import init_process_group
22
+ from torch.nn.parallel import DistributedDataParallel
23
+ from env import AttrDict, build_env
24
+ from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE
25
+
26
+ from bigvgan import BigVGAN
27
+ from discriminators import (
28
+ MultiPeriodDiscriminator,
29
+ MultiResolutionDiscriminator,
30
+ MultiBandDiscriminator,
31
+ MultiScaleSubbandCQTDiscriminator,
32
+ )
33
+ from loss import (
34
+ feature_loss,
35
+ generator_loss,
36
+ discriminator_loss,
37
+ MultiScaleMelSpectrogramLoss,
38
+ )
39
+
40
+ from utils import (
41
+ plot_spectrogram,
42
+ plot_spectrogram_clipped,
43
+ scan_checkpoint,
44
+ load_checkpoint,
45
+ save_checkpoint,
46
+ save_audio,
47
+ )
48
+ import torchaudio as ta
49
+ from pesq import pesq
50
+ from tqdm import tqdm
51
+ import auraloss
52
+
53
+ torch.backends.cudnn.benchmark = False
54
+
55
+
56
+ def train(rank, a, h):
57
+ if h.num_gpus > 1:
58
+ # initialize distributed
59
+ init_process_group(
60
+ backend=h.dist_config["dist_backend"],
61
+ init_method=h.dist_config["dist_url"],
62
+ world_size=h.dist_config["world_size"] * h.num_gpus,
63
+ rank=rank,
64
+ )
65
+
66
+ # Set seed and device
67
+ torch.cuda.manual_seed(h.seed)
68
+ torch.cuda.set_device(rank)
69
+ device = torch.device(f"cuda:{rank:d}")
70
+
71
+ # Define BigVGAN generator
72
+ generator = BigVGAN(h).to(device)
73
+
74
+ # Define discriminators. MPD is used by default
75
+ mpd = MultiPeriodDiscriminator(h).to(device)
76
+
77
+ # Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
78
+ # New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
79
+ if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
80
+ print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
81
+ # Variable name is kept as "mrd" for backward compatibility & minimal code change
82
+ mrd = MultiBandDiscriminator(h).to(device)
83
+ elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
84
+ print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
85
+ mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
86
+ else: # Fallback to original MRD in BigVGAN-v1
87
+ mrd = MultiResolutionDiscriminator(h).to(device)
88
+
89
+ # New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
90
+ if h.get("use_multiscale_melloss", False):
91
+ print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
92
+ fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
93
+ sampling_rate=h.sampling_rate
94
+ ) # NOTE: accepts waveform as input
95
+ else:
96
+ fn_mel_loss_singlescale = F.l1_loss
97
+
98
+ # Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory
99
+ if rank == 0:
100
+ print(generator)
101
+ print(mpd)
102
+ print(mrd)
103
+ print(f"Generator params: {sum(p.numel() for p in generator.parameters())}")
104
+ print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}")
105
+ print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}")
106
+ os.makedirs(a.checkpoint_path, exist_ok=True)
107
+ print(f"Checkpoints directory: {a.checkpoint_path}")
108
+
109
+ if os.path.isdir(a.checkpoint_path):
110
+ # New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
111
+ cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
112
+ cp_do = scan_checkpoint(
113
+ a.checkpoint_path,
114
+ prefix="do_",
115
+ renamed_file="bigvgan_discriminator_optimizer.pt",
116
+ )
117
+
118
+ # Load the latest checkpoint if exists
119
+ steps = 0
120
+ if cp_g is None or cp_do is None:
121
+ state_dict_do = None
122
+ last_epoch = -1
123
+ else:
124
+ state_dict_g = load_checkpoint(cp_g, device)
125
+ state_dict_do = load_checkpoint(cp_do, device)
126
+ generator.load_state_dict(state_dict_g["generator"])
127
+ mpd.load_state_dict(state_dict_do["mpd"])
128
+ mrd.load_state_dict(state_dict_do["mrd"])
129
+ steps = state_dict_do["steps"] + 1
130
+ last_epoch = state_dict_do["epoch"]
131
+
132
+ # Initialize DDP, optimizers, and schedulers
133
+ if h.num_gpus > 1:
134
+ generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
135
+ mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
136
+ mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
137
+
138
+ optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
139
+ optim_d = torch.optim.AdamW(
140
+ itertools.chain(mrd.parameters(), mpd.parameters()),
141
+ h.learning_rate,
142
+ betas=[h.adam_b1, h.adam_b2],
143
+ )
144
+
145
+ if state_dict_do is not None:
146
+ optim_g.load_state_dict(state_dict_do["optim_g"])
147
+ optim_d.load_state_dict(state_dict_do["optim_d"])
148
+
149
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
150
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
151
+
152
+ # Define training and validation datasets
153
+
154
+ """
155
+ unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
156
+ Example: trained on LibriTTS, validate on VCTK
157
+ """
158
+ training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
159
+
160
+ trainset = MelDataset(
161
+ training_filelist,
162
+ h,
163
+ h.segment_size,
164
+ h.n_fft,
165
+ h.num_mels,
166
+ h.hop_size,
167
+ h.win_size,
168
+ h.sampling_rate,
169
+ h.fmin,
170
+ h.fmax,
171
+ shuffle=False if h.num_gpus > 1 else True,
172
+ fmax_loss=h.fmax_for_loss,
173
+ device=device,
174
+ fine_tuning=a.fine_tuning,
175
+ base_mels_path=a.input_mels_dir,
176
+ is_seen=True,
177
+ )
178
+
179
+ train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
180
+
181
+ train_loader = DataLoader(
182
+ trainset,
183
+ num_workers=h.num_workers,
184
+ shuffle=False,
185
+ sampler=train_sampler,
186
+ batch_size=h.batch_size,
187
+ pin_memory=True,
188
+ drop_last=True,
189
+ )
190
+
191
+ if rank == 0:
192
+ validset = MelDataset(
193
+ validation_filelist,
194
+ h,
195
+ h.segment_size,
196
+ h.n_fft,
197
+ h.num_mels,
198
+ h.hop_size,
199
+ h.win_size,
200
+ h.sampling_rate,
201
+ h.fmin,
202
+ h.fmax,
203
+ False,
204
+ False,
205
+ fmax_loss=h.fmax_for_loss,
206
+ device=device,
207
+ fine_tuning=a.fine_tuning,
208
+ base_mels_path=a.input_mels_dir,
209
+ is_seen=True,
210
+ )
211
+ validation_loader = DataLoader(
212
+ validset,
213
+ num_workers=1,
214
+ shuffle=False,
215
+ sampler=None,
216
+ batch_size=1,
217
+ pin_memory=True,
218
+ drop_last=True,
219
+ )
220
+
221
+ list_unseen_validset = []
222
+ list_unseen_validation_loader = []
223
+ for i in range(len(list_unseen_validation_filelist)):
224
+ unseen_validset = MelDataset(
225
+ list_unseen_validation_filelist[i],
226
+ h,
227
+ h.segment_size,
228
+ h.n_fft,
229
+ h.num_mels,
230
+ h.hop_size,
231
+ h.win_size,
232
+ h.sampling_rate,
233
+ h.fmin,
234
+ h.fmax,
235
+ False,
236
+ False,
237
+ fmax_loss=h.fmax_for_loss,
238
+ device=device,
239
+ fine_tuning=a.fine_tuning,
240
+ base_mels_path=a.input_mels_dir,
241
+ is_seen=False,
242
+ )
243
+ unseen_validation_loader = DataLoader(
244
+ unseen_validset,
245
+ num_workers=1,
246
+ shuffle=False,
247
+ sampler=None,
248
+ batch_size=1,
249
+ pin_memory=True,
250
+ drop_last=True,
251
+ )
252
+ list_unseen_validset.append(unseen_validset)
253
+ list_unseen_validation_loader.append(unseen_validation_loader)
254
+
255
+ # Tensorboard logger
256
+ sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs"))
257
+ if a.save_audio: # Also save audio to disk if --save_audio is set to True
258
+ os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True)
259
+
260
+ """
261
+ Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset).
262
+ If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors
263
+ """
264
+
265
+ def validate(rank, a, h, loader, mode="seen"):
266
+ assert rank == 0, "validate should only run on rank=0"
267
+ generator.eval()
268
+ torch.cuda.empty_cache()
269
+
270
+ val_err_tot = 0
271
+ val_pesq_tot = 0
272
+ val_mrstft_tot = 0
273
+
274
+ # Modules for evaluation metrics
275
+ pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda()
276
+ loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda")
277
+
278
+ if a.save_audio: # Also save audio to disk if --save_audio is set to True
279
+ os.makedirs(
280
+ os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"),
281
+ exist_ok=True,
282
+ )
283
+ os.makedirs(
284
+ os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"),
285
+ exist_ok=True,
286
+ )
287
+
288
+ with torch.no_grad():
289
+ print(f"step {steps} {mode} speaker validation...")
290
+
291
+ # Loop over validation set and compute metrics
292
+ for j, batch in enumerate(tqdm(loader)):
293
+ x, y, _, y_mel = batch
294
+ y = y.to(device)
295
+ if hasattr(generator, "module"):
296
+ y_g_hat = generator.module(x.to(device))
297
+ else:
298
+ y_g_hat = generator(x.to(device))
299
+ y_mel = y_mel.to(device, non_blocking=True)
300
+ y_g_hat_mel = mel_spectrogram(
301
+ y_g_hat.squeeze(1),
302
+ h.n_fft,
303
+ h.num_mels,
304
+ h.sampling_rate,
305
+ h.hop_size,
306
+ h.win_size,
307
+ h.fmin,
308
+ h.fmax_for_loss,
309
+ )
310
+ min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
311
+ val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
312
+
313
+ # PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
314
+ if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
315
+ # Resample to 16000 for pesq
316
+ y_16k = pesq_resampler(y)
317
+ y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
318
+ y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
319
+ y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
320
+ val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
321
+
322
+ # MRSTFT calculation
323
+ min_t = min(y.size(-1), y_g_hat.size(-1))
324
+ val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
325
+
326
+ # Log audio and figures to Tensorboard
327
+ if j % a.eval_subsample == 0: # Subsample every nth from validation set
328
+ if steps >= 0:
329
+ sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
330
+ if a.save_audio: # Also save audio to disk if --save_audio is set to True
331
+ save_audio(
332
+ y[0],
333
+ os.path.join(
334
+ a.checkpoint_path,
335
+ "samples",
336
+ f"gt_{mode}",
337
+ f"{j:04d}.wav",
338
+ ),
339
+ h.sampling_rate,
340
+ )
341
+ sw.add_figure(
342
+ f"gt_{mode}/y_spec_{j}",
343
+ plot_spectrogram(x[0]),
344
+ steps,
345
+ )
346
+
347
+ sw.add_audio(
348
+ f"generated_{mode}/y_hat_{j}",
349
+ y_g_hat[0],
350
+ steps,
351
+ h.sampling_rate,
352
+ )
353
+ if a.save_audio: # Also save audio to disk if --save_audio is set to True
354
+ save_audio(
355
+ y_g_hat[0, 0],
356
+ os.path.join(
357
+ a.checkpoint_path,
358
+ "samples",
359
+ f"{mode}_{steps:08d}",
360
+ f"{j:04d}.wav",
361
+ ),
362
+ h.sampling_rate,
363
+ )
364
+ # Spectrogram of synthesized audio
365
+ y_hat_spec = mel_spectrogram(
366
+ y_g_hat.squeeze(1),
367
+ h.n_fft,
368
+ h.num_mels,
369
+ h.sampling_rate,
370
+ h.hop_size,
371
+ h.win_size,
372
+ h.fmin,
373
+ h.fmax,
374
+ )
375
+ sw.add_figure(
376
+ f"generated_{mode}/y_hat_spec_{j}",
377
+ plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()),
378
+ steps,
379
+ )
380
+
381
+ """
382
+ Visualization of spectrogram difference between GT and synthesized audio, difference higher than 1 is clipped for better visualization.
383
+ """
384
+ spec_delta = torch.clamp(
385
+ torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()),
386
+ min=1e-6,
387
+ max=1.0,
388
+ )
389
+ sw.add_figure(
390
+ f"delta_dclip1_{mode}/spec_{j}",
391
+ plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0),
392
+ steps,
393
+ )
394
+
395
+ val_err = val_err_tot / (j + 1)
396
+ val_pesq = val_pesq_tot / (j + 1)
397
+ val_mrstft = val_mrstft_tot / (j + 1)
398
+ # Log evaluation metrics to Tensorboard
399
+ sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps)
400
+ sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps)
401
+ sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps)
402
+
403
+ generator.train()
404
+
405
+ # If the checkpoint is loaded, start with validation loop
406
+ if steps != 0 and rank == 0 and not a.debug:
407
+ if not a.skip_seen:
408
+ validate(
409
+ rank,
410
+ a,
411
+ h,
412
+ validation_loader,
413
+ mode=f"seen_{train_loader.dataset.name}",
414
+ )
415
+ for i in range(len(list_unseen_validation_loader)):
416
+ validate(
417
+ rank,
418
+ a,
419
+ h,
420
+ list_unseen_validation_loader[i],
421
+ mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
422
+ )
423
+ # Exit the script if --evaluate is set to True
424
+ if a.evaluate:
425
+ exit()
426
+
427
+ # Main training loop
428
+ generator.train()
429
+ mpd.train()
430
+ mrd.train()
431
+ for epoch in range(max(0, last_epoch), a.training_epochs):
432
+ if rank == 0:
433
+ start = time.time()
434
+ print(f"Epoch: {epoch + 1}")
435
+
436
+ if h.num_gpus > 1:
437
+ train_sampler.set_epoch(epoch)
438
+
439
+ for i, batch in enumerate(train_loader):
440
+ if rank == 0:
441
+ start_b = time.time()
442
+ x, y, _, y_mel = batch
443
+
444
+ x = x.to(device, non_blocking=True)
445
+ y = y.to(device, non_blocking=True)
446
+ y_mel = y_mel.to(device, non_blocking=True)
447
+ y = y.unsqueeze(1)
448
+
449
+ y_g_hat = generator(x)
450
+ y_g_hat_mel = mel_spectrogram(
451
+ y_g_hat.squeeze(1),
452
+ h.n_fft,
453
+ h.num_mels,
454
+ h.sampling_rate,
455
+ h.hop_size,
456
+ h.win_size,
457
+ h.fmin,
458
+ h.fmax_for_loss,
459
+ )
460
+
461
+ optim_d.zero_grad()
462
+
463
+ # MPD
464
+ y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
465
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
466
+
467
+ # MRD
468
+ y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
469
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
470
+
471
+ loss_disc_all = loss_disc_s + loss_disc_f
472
+
473
+ # Set clip_grad_norm value
474
+ clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000
475
+
476
+ # Whether to freeze D for initial training steps
477
+ if steps >= a.freeze_step:
478
+ loss_disc_all.backward()
479
+ grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
480
+ grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
481
+ optim_d.step()
482
+ else:
483
+ print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
484
+ grad_norm_mpd = 0.0
485
+ grad_norm_mrd = 0.0
486
+
487
+ # Generator
488
+ optim_g.zero_grad()
489
+
490
+ # L1 Mel-Spectrogram Loss
491
+ lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
492
+ if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
493
+ loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
494
+ else: # Uses mel <y_mel, y_g_hat_mel> for loss
495
+ loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss
496
+
497
+ # MPD loss
498
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
499
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
500
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
501
+
502
+ # MRD loss
503
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat)
504
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
505
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
506
+
507
+ if steps >= a.freeze_step:
508
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
509
+ else:
510
+ print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
511
+ loss_gen_all = loss_mel
512
+
513
+ loss_gen_all.backward()
514
+ grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
515
+ optim_g.step()
516
+
517
+ if rank == 0:
518
+ # STDOUT logging
519
+ if steps % a.stdout_interval == 0:
520
+ mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
521
+ print(
522
+ f"Steps: {steps:d}, "
523
+ f"Gen Loss Total: {loss_gen_all:4.3f}, "
524
+ f"Mel Error: {mel_error:4.3f}, "
525
+ f"s/b: {time.time() - start_b:4.3f} "
526
+ f"lr: {optim_g.param_groups[0]['lr']:4.7f} "
527
+ f"grad_norm_g: {grad_norm_g:4.3f}"
528
+ )
529
+
530
+ # Checkpointing
531
+ if steps % a.checkpoint_interval == 0 and steps != 0:
532
+ checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
533
+ save_checkpoint(
534
+ checkpoint_path,
535
+ {"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
536
+ )
537
+ checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
538
+ save_checkpoint(
539
+ checkpoint_path,
540
+ {
541
+ "mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
542
+ "mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(),
543
+ "optim_g": optim_g.state_dict(),
544
+ "optim_d": optim_d.state_dict(),
545
+ "steps": steps,
546
+ "epoch": epoch,
547
+ },
548
+ )
549
+
550
+ # Tensorboard summary logging
551
+ if steps % a.summary_interval == 0:
552
+ mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
553
+ sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
554
+ sw.add_scalar("training/mel_spec_error", mel_error, steps)
555
+ sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
556
+ sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
557
+ sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
558
+ sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps)
559
+ sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
560
+ sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
561
+ sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
562
+ sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
563
+ sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
564
+ sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
565
+ sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
566
+ sw.add_scalar("training/epoch", epoch + 1, steps)
567
+
568
+ # Validation
569
+ if steps % a.validation_interval == 0:
570
+ # Plot training input x so far used
571
+ for i_x in range(x.shape[0]):
572
+ sw.add_figure(
573
+ f"training_input/x_{i_x}",
574
+ plot_spectrogram(x[i_x].cpu()),
575
+ steps,
576
+ )
577
+ sw.add_audio(
578
+ f"training_input/y_{i_x}",
579
+ y[i_x][0],
580
+ steps,
581
+ h.sampling_rate,
582
+ )
583
+
584
+ # Seen and unseen speakers validation loops
585
+ if not a.debug and steps != 0:
586
+ validate(
587
+ rank,
588
+ a,
589
+ h,
590
+ validation_loader,
591
+ mode=f"seen_{train_loader.dataset.name}",
592
+ )
593
+ for i in range(len(list_unseen_validation_loader)):
594
+ validate(
595
+ rank,
596
+ a,
597
+ h,
598
+ list_unseen_validation_loader[i],
599
+ mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
600
+ )
601
+ steps += 1
602
+
603
+ # BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level
604
+ scheduler_g.step()
605
+ scheduler_d.step()
606
+
607
+ if rank == 0:
608
+ print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
609
+
610
+
611
+ def main():
612
+ print("Initializing Training Process..")
613
+
614
+ parser = argparse.ArgumentParser()
615
+
616
+ parser.add_argument("--group_name", default=None)
617
+
618
+ parser.add_argument("--input_wavs_dir", default="LibriTTS")
619
+ parser.add_argument("--input_mels_dir", default="ft_dataset")
620
+ parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
621
+ parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
622
+
623
+ parser.add_argument(
624
+ "--list_input_unseen_wavs_dir",
625
+ nargs="+",
626
+ default=["tests/LibriTTS", "tests/LibriTTS"],
627
+ )
628
+ parser.add_argument(
629
+ "--list_input_unseen_validation_file",
630
+ nargs="+",
631
+ default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"],
632
+ )
633
+
634
+ parser.add_argument("--checkpoint_path", default="exp/bigvgan")
635
+ parser.add_argument("--config", default="")
636
+
637
+ parser.add_argument("--training_epochs", default=100000, type=int)
638
+ parser.add_argument("--stdout_interval", default=5, type=int)
639
+ parser.add_argument("--checkpoint_interval", default=50000, type=int)
640
+ parser.add_argument("--summary_interval", default=100, type=int)
641
+ parser.add_argument("--validation_interval", default=50000, type=int)
642
+
643
+ parser.add_argument(
644
+ "--freeze_step",
645
+ default=0,
646
+ type=int,
647
+ help="freeze D for the first specified steps. G only uses regression loss for these steps.",
648
+ )
649
+
650
+ parser.add_argument("--fine_tuning", default=False, type=bool)
651
+
652
+ parser.add_argument(
653
+ "--debug",
654
+ default=False,
655
+ type=bool,
656
+ help="debug mode. skips validation loop throughout training",
657
+ )
658
+ parser.add_argument(
659
+ "--evaluate",
660
+ default=False,
661
+ type=bool,
662
+ help="only run evaluation from checkpoint and exit",
663
+ )
664
+ parser.add_argument(
665
+ "--eval_subsample",
666
+ default=5,
667
+ type=int,
668
+ help="subsampling during evaluation loop",
669
+ )
670
+ parser.add_argument(
671
+ "--skip_seen",
672
+ default=False,
673
+ type=bool,
674
+ help="skip seen dataset. useful for test set inference",
675
+ )
676
+ parser.add_argument(
677
+ "--save_audio",
678
+ default=False,
679
+ type=bool,
680
+ help="save audio of test set inference to disk",
681
+ )
682
+
683
+ a = parser.parse_args()
684
+
685
+ with open(a.config) as f:
686
+ data = f.read()
687
+
688
+ json_config = json.loads(data)
689
+ h = AttrDict(json_config)
690
+
691
+ build_env(a.config, "config.json", a.checkpoint_path)
692
+
693
+ torch.manual_seed(h.seed)
694
+ if torch.cuda.is_available():
695
+ torch.cuda.manual_seed(h.seed)
696
+ h.num_gpus = torch.cuda.device_count()
697
+ h.batch_size = int(h.batch_size / h.num_gpus)
698
+ print(f"Batch size per GPU: {h.batch_size}")
699
+ else:
700
+ pass
701
+
702
+ if h.num_gpus > 1:
703
+ mp.spawn(
704
+ train,
705
+ nprocs=h.num_gpus,
706
+ args=(
707
+ a,
708
+ h,
709
+ ),
710
+ )
711
+ else:
712
+ train(0, a, h)
713
+
714
+
715
+ if __name__ == "__main__":
716
+ main()
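For orientation, a partial sketch of the config fields this script reads from `--config` (placeholder values; the shipped hyperparameters live in the JSON files under `configs/`):

```python
# Placeholder config; only fields referenced by train.py above are listed.
h_example = {
    "seed": 1234,
    "batch_size": 32,  # divided by num_gpus in main()
    "learning_rate": 1e-4, "adam_b1": 0.8, "adam_b2": 0.99, "lr_decay": 0.999,
    "segment_size": 16384, "n_fft": 1024, "num_mels": 100,
    "hop_size": 256, "win_size": 1024, "sampling_rate": 24000,
    "fmin": 0, "fmax": None, "fmax_for_loss": None,
    "num_workers": 4,
    "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
}
```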
GPT_SoVITS/BigVGAN/utils0.py ADDED
@@ -0,0 +1,99 @@
1
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import glob
5
+ import os
6
+ import matplotlib
7
+ import torch
8
+ from torch.nn.utils import weight_norm
9
+
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pylab as plt
12
+ from .meldataset import MAX_WAV_VALUE
13
+ from scipy.io.wavfile import write
14
+
15
+
16
+ def plot_spectrogram(spectrogram):
17
+ fig, ax = plt.subplots(figsize=(10, 2))
18
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
19
+ plt.colorbar(im, ax=ax)
20
+
21
+ fig.canvas.draw()
22
+ plt.close()
23
+
24
+ return fig
25
+
26
+
27
+ def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
28
+ fig, ax = plt.subplots(figsize=(10, 2))
29
+ im = ax.imshow(
30
+ spectrogram,
31
+ aspect="auto",
32
+ origin="lower",
33
+ interpolation="none",
34
+ vmin=1e-6,
35
+ vmax=clip_max,
36
+ )
37
+ plt.colorbar(im, ax=ax)
38
+
39
+ fig.canvas.draw()
40
+ plt.close()
41
+
42
+ return fig
43
+
44
+
45
+ def init_weights(m, mean=0.0, std=0.01):
46
+ classname = m.__class__.__name__
47
+ if classname.find("Conv") != -1:
48
+ m.weight.data.normal_(mean, std)
49
+
50
+
51
+ def apply_weight_norm(m):
52
+ classname = m.__class__.__name__
53
+ if classname.find("Conv") != -1:
54
+ weight_norm(m)
55
+
56
+
57
+ def get_padding(kernel_size, dilation=1):
58
+ return int((kernel_size * dilation - dilation) / 2)
59
+
60
+
61
+ def load_checkpoint(filepath, device):
62
+ assert os.path.isfile(filepath)
63
+ print(f"Loading '{filepath}'")
64
+ checkpoint_dict = torch.load(filepath, map_location=device)
65
+ print("Complete.")
66
+ return checkpoint_dict
67
+
68
+
69
+ def save_checkpoint(filepath, obj):
70
+ print(f"Saving checkpoint to {filepath}")
71
+ torch.save(obj, filepath)
72
+ print("Complete.")
73
+
74
+
75
+ def scan_checkpoint(cp_dir, prefix, renamed_file=None):
76
+ # Fallback to original scanning logic first
77
+ pattern = os.path.join(cp_dir, prefix + "????????")
78
+ cp_list = glob.glob(pattern)
79
+
80
+ if len(cp_list) > 0:
81
+ last_checkpoint_path = sorted(cp_list)[-1]
82
+ print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
83
+ return last_checkpoint_path
84
+
85
+ # If no pattern-based checkpoints are found, check for renamed file
86
+ if renamed_file:
87
+ renamed_path = os.path.join(cp_dir, renamed_file)
88
+ if os.path.isfile(renamed_path):
89
+ print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
90
+ return renamed_path
91
+
92
+ return None
93
+
94
+
95
+ def save_audio(audio, path, sr):
96
+ # wav: torch with 1d shape
97
+ audio = audio * MAX_WAV_VALUE
98
+ audio = audio.cpu().numpy().astype("int16")
99
+ write(path, sr, audio)
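A short sketch of the resume flow that train.py builds on top of these helpers (paths are hypothetical):

```python
# Find the newest generator/discriminator checkpoints, falling back to the
# renamed Hugging Face Hub file names, then load them the same way train.py does.
ckpt_dir = "exp/bigvgan"
cp_g = scan_checkpoint(ckpt_dir, prefix="g_", renamed_file="bigvgan_generator.pt")
cp_do = scan_checkpoint(ckpt_dir, prefix="do_", renamed_file="bigvgan_discriminator_optimizer.pt")
if cp_g is not None and cp_do is not None:
    state_dict_g = load_checkpoint(cp_g, "cpu")    # {"generator": ...}
    state_dict_do = load_checkpoint(cp_do, "cpu")  # {"mpd", "mrd", "optim_g", "optim_d", "steps", "epoch"}
```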
GPT_SoVITS/download.py ADDED
@@ -0,0 +1,13 @@
1
+ import os
2
+ import sys
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.insert(0, now_dir)
6
+ from text.g2pw import G2PWPinyin
7
+
8
+ g2pw = G2PWPinyin(
9
+ model_dir="GPT_SoVITS/text/G2PWModel",
10
+ model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
11
+ v_to_u=False,
12
+ neutral_tone_with_five=True,
13
+ )
GPT_SoVITS/export_torch_script.py ADDED
@@ -0,0 +1,861 @@
1
+ # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
2
+ # reference: https://github.com/lifeiteng/vall-e
3
+ import argparse
4
+ from typing import Optional
5
+ from my_utils import load_audio
6
+ import torch
7
+ import torchaudio
8
+
9
+ from torch import IntTensor, LongTensor, Tensor, nn
10
+ from torch.nn import functional as F
11
+
12
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
13
+ from feature_extractor import cnhubert
14
+
15
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
16
+ from module.models_onnx import SynthesizerTrn
17
+
18
+ from inference_webui import get_phones_and_bert
19
+
20
+ import os
21
+ import soundfile
22
+
23
+ default_config = {
24
+ "embedding_dim": 512,
25
+ "hidden_dim": 512,
26
+ "num_head": 8,
27
+ "num_layers": 12,
28
+ "num_codebook": 8,
29
+ "p_dropout": 0.0,
30
+ "vocab_size": 1024 + 1,
31
+ "phoneme_vocab_size": 512,
32
+ "EOS": 1024,
33
+ }
34
+
35
+
36
+ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
37
+ config = dict_s1["config"]
38
+ config["model"]["dropout"] = float(config["model"]["dropout"])
39
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
40
+ t2s_model.load_state_dict(dict_s1["weight"])
41
+ t2s_model = t2s_model.eval()
42
+ return t2s_model
43
+
44
+
45
+ @torch.jit.script
46
+ def logits_to_probs(
47
+ logits,
48
+ previous_tokens: Optional[torch.Tensor] = None,
49
+ temperature: float = 1.0,
50
+ top_k: Optional[int] = None,
51
+ top_p: Optional[int] = None,
52
+ repetition_penalty: float = 1.0,
53
+ ):
54
+ # if previous_tokens is not None:
55
+ # previous_tokens = previous_tokens.squeeze()
56
+ # print(logits.shape,previous_tokens.shape)
57
+ # pdb.set_trace()
58
+ if previous_tokens is not None and repetition_penalty != 1.0:
59
+ previous_tokens = previous_tokens.long()
60
+ score = torch.gather(logits, dim=1, index=previous_tokens)
61
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
62
+ logits.scatter_(dim=1, index=previous_tokens, src=score)
63
+
64
+ if top_p is not None and top_p < 1.0:
65
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
66
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
67
+ sorted_indices_to_remove = cum_probs > top_p
68
+ sorted_indices_to_remove[:, 0] = False # keep at least one option
69
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
70
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
71
+
72
+ logits = logits / max(temperature, 1e-5)
73
+
74
+ if top_k is not None:
75
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
76
+ pivot = v[:, -1].unsqueeze(-1)
77
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
78
+
79
+ probs = torch.nn.functional.softmax(logits, dim=-1)
80
+ return probs
81
+
82
+
83
+ @torch.jit.script
84
+ def multinomial_sample_one_no_sync(probs_sort):
85
+ # Does multinomial sampling without a cuda synchronization
86
+ q = torch.randn_like(probs_sort)
87
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
88
+
89
+
90
+ @torch.jit.script
91
+ def sample(
92
+ logits,
93
+ previous_tokens,
94
+ temperature: float = 1.0,
95
+ top_k: Optional[int] = None,
96
+ top_p: Optional[int] = None,
97
+ repetition_penalty: float = 1.0,
98
+ ):
99
+ probs = logits_to_probs(
100
+ logits=logits,
101
+ previous_tokens=previous_tokens,
102
+ temperature=temperature,
103
+ top_k=top_k,
104
+ top_p=top_p,
105
+ repetition_penalty=repetition_penalty,
106
+ )
107
+ idx_next = multinomial_sample_one_no_sync(probs)
108
+ return idx_next, probs
109
+
110
+
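A brief usage sketch of `sample` (shapes and hyperparameter values are illustrative only):

```python
import torch

# batch of 1, vocabulary of 1025 semantic tokens (1024 codes + EOS)
logits = torch.randn(1, 1025)
previous_tokens = torch.randint(0, 1024, (1, 16))  # tokens generated so far

idx_next, probs = sample(
    logits,
    previous_tokens,
    temperature=1.0,
    top_k=15,
    repetition_penalty=1.35,
)
# idx_next: [1, 1] int tensor holding the next semantic token id
```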
111
+ @torch.jit.script
112
+ def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
113
+ hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
114
+ y = torch.nn.functional.pad(
115
+ y.unsqueeze(1),
116
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
117
+ mode="reflect",
118
+ )
119
+ y = y.squeeze(1)
120
+ spec = torch.stft(
121
+ y,
122
+ n_fft,
123
+ hop_length=hop_size,
124
+ win_length=win_size,
125
+ window=hann_window,
126
+ center=center,
127
+ pad_mode="reflect",
128
+ normalized=False,
129
+ onesided=True,
130
+ return_complex=False,
131
+ )
132
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
133
+ return spec
134
+
135
+
136
+ class DictToAttrRecursive(dict):
137
+ def __init__(self, input_dict):
138
+ super().__init__(input_dict)
139
+ for key, value in input_dict.items():
140
+ if isinstance(value, dict):
141
+ value = DictToAttrRecursive(value)
142
+ self[key] = value
143
+ setattr(self, key, value)
144
+
145
+ def __getattr__(self, item):
146
+ try:
147
+ return self[item]
148
+ except KeyError:
149
+ raise AttributeError(f"Attribute {item} not found")
150
+
151
+ def __setattr__(self, key, value):
152
+ if isinstance(value, dict):
153
+ value = DictToAttrRecursive(value)
154
+ super(DictToAttrRecursive, self).__setitem__(key, value)
155
+ super().__setattr__(key, value)
156
+
157
+ def __delattr__(self, item):
158
+ try:
159
+ del self[item]
160
+ except KeyError:
161
+ raise AttributeError(f"Attribute {item} not found")
162
+
163
+
164
+ @torch.jit.script
165
+ class T2SMLP:
166
+ def __init__(self, w1, b1, w2, b2):
167
+ self.w1 = w1
168
+ self.b1 = b1
169
+ self.w2 = w2
170
+ self.b2 = b2
171
+
172
+ def forward(self, x):
173
+ x = F.relu(F.linear(x, self.w1, self.b1))
174
+ x = F.linear(x, self.w2, self.b2)
175
+ return x
176
+
177
+
178
+ @torch.jit.script
179
+ class T2SBlock:
180
+ def __init__(
181
+ self,
182
+ num_heads: int,
183
+ hidden_dim: int,
184
+ mlp: T2SMLP,
185
+ qkv_w,
186
+ qkv_b,
187
+ out_w,
188
+ out_b,
189
+ norm_w1,
190
+ norm_b1,
191
+ norm_eps1: float,
192
+ norm_w2,
193
+ norm_b2,
194
+ norm_eps2: float,
195
+ ):
196
+ self.num_heads = num_heads
197
+ self.mlp = mlp
198
+ self.hidden_dim: int = hidden_dim
199
+ self.qkv_w = qkv_w
200
+ self.qkv_b = qkv_b
201
+ self.out_w = out_w
202
+ self.out_b = out_b
203
+ self.norm_w1 = norm_w1
204
+ self.norm_b1 = norm_b1
205
+ self.norm_eps1 = norm_eps1
206
+ self.norm_w2 = norm_w2
207
+ self.norm_b2 = norm_b2
208
+ self.norm_eps2 = norm_eps2
209
+
210
+ self.false = torch.tensor(False, dtype=torch.bool)
211
+
212
+ @torch.jit.ignore
213
+ def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
214
+ if padding_mask is None:
215
+ return x
216
+
217
+ if padding_mask.dtype == torch.bool:
218
+ return x.masked_fill(padding_mask, 0)
219
+ else:
220
+ return x * padding_mask
221
+
222
+ def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
223
+ q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
224
+
225
+ batch_size = q.shape[0]
226
+ q_len = q.shape[1]
227
+ kv_len = k.shape[1]
228
+
229
+ q = self.to_mask(q, padding_mask)
230
+ k_cache = self.to_mask(k, padding_mask)
231
+ v_cache = self.to_mask(v, padding_mask)
232
+
233
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
234
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
235
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
236
+
237
+ attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
238
+
239
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
240
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
241
+ attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
242
+
243
+ if padding_mask is not None:
244
+ for i in range(batch_size):
245
+ # mask = padding_mask[i,:,0]
246
+ if self.false.device != padding_mask.device:
247
+ self.false = self.false.to(padding_mask.device)
248
+ idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
249
+ x_item = x[i, idx, :].unsqueeze(0)
250
+ attn_item = attn[i, idx, :].unsqueeze(0)
251
+ x_item = x_item + attn_item
252
+ x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
253
+ x_item = x_item + self.mlp.forward(x_item)
254
+ x_item = F.layer_norm(
255
+ x_item,
256
+ [self.hidden_dim],
257
+ self.norm_w2,
258
+ self.norm_b2,
259
+ self.norm_eps2,
260
+ )
261
+ x[i, idx, :] = x_item.squeeze(0)
262
+ x = self.to_mask(x, padding_mask)
263
+ else:
264
+ x = x + attn
265
+ x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
266
+ x = x + self.mlp.forward(x)
267
+ x = F.layer_norm(
268
+ x,
269
+ [self.hidden_dim],
270
+ self.norm_w2,
271
+ self.norm_b2,
272
+ self.norm_eps2,
273
+ )
274
+ return x, k_cache, v_cache
275
+
276
+ def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
277
+ q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
278
+
279
+ k_cache = torch.cat([k_cache, k], dim=1)
280
+ v_cache = torch.cat([v_cache, v], dim=1)
281
+
282
+ batch_size = q.shape[0]
283
+ q_len = q.shape[1]
284
+ kv_len = k_cache.shape[1]
285
+
286
+ q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
287
+ k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
288
+ v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
289
+
290
+ attn = F.scaled_dot_product_attention(q, k, v)
291
+
292
+ attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
293
+ attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
294
+ attn = F.linear(attn, self.out_w, self.out_b)
295
+
296
+ x = x + attn
297
+ x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
298
+ x = x + self.mlp.forward(x)
299
+ x = F.layer_norm(
300
+ x,
301
+ [self.hidden_dim],
302
+ self.norm_w2,
303
+ self.norm_b2,
304
+ self.norm_eps2,
305
+ )
306
+ return x, k_cache, v_cache
307
+
308
+
309
+ @torch.jit.script
310
+ class T2STransformer:
311
+ def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
312
+ self.num_blocks: int = num_blocks
313
+ self.blocks = blocks
314
+
315
+ def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
316
+ k_cache: list[torch.Tensor] = []
317
+ v_cache: list[torch.Tensor] = []
318
+ for i in range(self.num_blocks):
319
+ x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
320
+ k_cache.append(k_cache_)
321
+ v_cache.append(v_cache_)
322
+ return x, k_cache, v_cache
323
+
324
+ def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
325
+ for i in range(self.num_blocks):
326
+ x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
327
+ return x, k_cache, v_cache
328
+
329
+
330
+ class VitsModel(nn.Module):
331
+ def __init__(self, vits_path):
332
+ super().__init__()
333
+ # dict_s2 = torch.load(vits_path,map_location="cpu")
334
+ dict_s2 = torch.load(vits_path)
335
+ self.hps = dict_s2["config"]
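+ # Infer the model version from the phoneme-embedding table size: a 322-symbol table means v1, otherwise v2.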
336
+ if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
337
+ self.hps["model"]["version"] = "v1"
338
+ else:
339
+ self.hps["model"]["version"] = "v2"
340
+
341
+ self.hps = DictToAttrRecursive(self.hps)
342
+ self.hps.model.semantic_frame_rate = "25hz"
343
+ self.vq_model = SynthesizerTrn(
344
+ self.hps.data.filter_length // 2 + 1,
345
+ self.hps.train.segment_size // self.hps.data.hop_length,
346
+ n_speakers=self.hps.data.n_speakers,
347
+ **self.hps.model,
348
+ )
349
+ self.vq_model.eval()
350
+ self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
351
+
352
+ def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
353
+ refer = spectrogram_torch(
354
+ ref_audio,
355
+ self.hps.data.filter_length,
356
+ self.hps.data.sampling_rate,
357
+ self.hps.data.hop_length,
358
+ self.hps.data.win_length,
359
+ center=False,
360
+ )
361
+ return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
362
+
363
+
364
+ class T2SModel(nn.Module):
365
+ def __init__(self, raw_t2s: Text2SemanticLightningModule):
366
+ super(T2SModel, self).__init__()
367
+ self.model_dim = raw_t2s.model.model_dim
368
+ self.embedding_dim = raw_t2s.model.embedding_dim
369
+ self.num_head = raw_t2s.model.num_head
370
+ self.num_layers = raw_t2s.model.num_layers
371
+ self.vocab_size = raw_t2s.model.vocab_size
372
+ self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
373
+ # self.p_dropout = float(raw_t2s.model.p_dropout)
374
+ self.EOS: int = int(raw_t2s.model.EOS)
375
+ self.norm_first = raw_t2s.model.norm_first
376
+ assert self.EOS == self.vocab_size - 1
377
+ self.hz = 50
378
+
379
+ self.bert_proj = raw_t2s.model.bert_proj
380
+ self.ar_text_embedding = raw_t2s.model.ar_text_embedding
381
+ self.ar_text_position = raw_t2s.model.ar_text_position
382
+ self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
383
+ self.ar_audio_position = raw_t2s.model.ar_audio_position
384
+
385
+ # self.t2s_transformer = T2STransformer(self.num_layers, blocks)
386
+ # self.t2s_transformer = raw_t2s.model.t2s_transformer
387
+
388
+ blocks = []
389
+ h = raw_t2s.model.h
390
+
391
+ for i in range(self.num_layers):
392
+ layer = h.layers[i]
393
+ t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
394
+
395
+ block = T2SBlock(
396
+ self.num_head,
397
+ self.model_dim,
398
+ t2smlp,
399
+ layer.self_attn.in_proj_weight,
400
+ layer.self_attn.in_proj_bias,
401
+ layer.self_attn.out_proj.weight,
402
+ layer.self_attn.out_proj.bias,
403
+ layer.norm1.weight,
404
+ layer.norm1.bias,
405
+ layer.norm1.eps,
406
+ layer.norm2.weight,
407
+ layer.norm2.bias,
408
+ layer.norm2.eps,
409
+ )
410
+
411
+ blocks.append(block)
412
+
413
+ self.t2s_transformer = T2STransformer(self.num_layers, blocks)
414
+
415
+ # self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
416
+ self.ar_predict_layer = raw_t2s.model.ar_predict_layer
417
+ # self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
418
+ self.max_sec = raw_t2s.config["data"]["max_sec"]
419
+ self.top_k = int(raw_t2s.config["inference"]["top_k"])
420
+ self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
421
+
422
+ def forward(
423
+ self,
424
+ prompts: LongTensor,
425
+ ref_seq: LongTensor,
426
+ text_seq: LongTensor,
427
+ ref_bert: torch.Tensor,
428
+ text_bert: torch.Tensor,
429
+ top_k: LongTensor,
430
+ ):
431
+ bert = torch.cat([ref_bert.T, text_bert.T], 1)
432
+ all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
433
+ bert = bert.unsqueeze(0)
434
+
435
+ x = self.ar_text_embedding(all_phoneme_ids)
436
+ x = x + self.bert_proj(bert.transpose(1, 2))
437
+ x: torch.Tensor = self.ar_text_position(x)
438
+
439
+ early_stop_num = self.early_stop_num
440
+
441
+ # [1,N,512] [1,N]
442
+ # y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
443
+ y = prompts
444
+ # x_example = x[:,:,0] * 0.0
445
+
446
+ x_len = x.shape[1]
447
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
448
+
449
+ y_emb = self.ar_audio_embedding(y)
450
+ y_len = y_emb.shape[1]
451
+ prefix_len = y.shape[1]
452
+ y_pos = self.ar_audio_position(y_emb)
453
+ xy_pos = torch.concat([x, y_pos], dim=1)
454
+
455
+ bsz = x.shape[0]
456
+ src_len = x_len + y_len
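+ # Build the (x+y) x (x+y) boolean mask (True = blocked): text positions see all text but no audio; audio positions see all text plus earlier audio.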
457
+ x_attn_mask_pad = F.pad(
458
+ x_attn_mask,
459
+ (0, y_len),  ### extend the all-zero xx block to all-zero xx + all-one xy, shape (x, x+y)
460
+ value=True,
461
+ )
462
+ y_attn_mask = F.pad(  ### extend yy's upper-triangular ones with zeros on the left xy side, shape (y, x+y)
463
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
464
+ (x_len, 0),
465
+ value=False,
466
+ )
467
+ xy_attn_mask = (
468
+ torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
469
+ .unsqueeze(0)
470
+ .expand(bsz * self.num_head, -1, -1)
471
+ .view(bsz, self.num_head, src_len, src_len)
472
+ .to(device=x.device, dtype=torch.bool)
473
+ )
474
+
475
+ idx = 0
476
+ top_k = int(top_k)
477
+
478
+ xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
479
+
480
+ logits = self.ar_predict_layer(xy_dec[:, -1])
481
+ logits = logits[:, :-1]
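+ # Drop the EOS logit so the very first sampled token cannot be EOS.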
482
+ samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
483
+ y = torch.concat([y, samples], dim=1)
484
+ y_emb = self.ar_audio_embedding(y[:, -1:])
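+ # Re-embed only the newest token and add its positional encoding at offset y_len + idx.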
485
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
486
+ :, y_len + idx
487
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
488
+
489
+ stop = False
490
+ # for idx in range(1, 50):
491
+ for idx in range(1, 1500):
492
+ # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
493
+ # y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
494
+ xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
495
+ logits = self.ar_predict_layer(xy_dec[:, -1])
496
+
497
+ if idx < 11:  ### require at least 10 predicted tokens before allowing a stop (~0.4 s)
498
+ logits = logits[:, :-1]
499
+
500
+ samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
501
+
502
+ y = torch.concat([y, samples], dim=1)
503
+
504
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
505
+ stop = True
506
+ if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
507
+ stop = True
508
+ if stop:
509
+ if y.shape[1] == 0:
510
+ y = torch.concat([y, torch.zeros_like(samples)], dim=1)
511
+ break
512
+
513
+ y_emb = self.ar_audio_embedding(y[:, -1:])
514
+ xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
515
+ :, y_len + idx
516
+ ].to(dtype=y_emb.dtype, device=y_emb.device)
517
+
518
+ y[0, -1] = 0
519
+
520
+ return y[:, -idx:].unsqueeze(0)
521
+
522
+
523
+ bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
524
+ cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
525
+ cnhubert.cnhubert_base_path = cnhubert_base_path
526
+
527
+
528
+ @torch.jit.script
529
+ def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
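+ # Expand character-level BERT features to phone level by repeating each row word2ph[i] times.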
530
+ phone_level_feature = []
531
+ for i in range(word2ph.shape[0]):
532
+ repeat_feature = res[i].repeat(word2ph[i].item(), 1)
533
+ phone_level_feature.append(repeat_feature)
534
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
535
+ # [sum(word2ph), 1024]
536
+ return phone_level_feature
537
+
538
+
539
+ class MyBertModel(torch.nn.Module):
540
+ def __init__(self, bert_model):
541
+ super(MyBertModel, self).__init__()
542
+ self.bert = bert_model
543
+
544
+ def forward(
545
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
546
+ ):
547
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
548
+ # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
549
+ res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
550
+ return build_phone_level_feature(res, word2ph)
551
+
552
+
553
+ class SSLModel(torch.nn.Module):
554
+ def __init__(self):
555
+ super().__init__()
556
+ self.ssl = cnhubert.get_model().model
557
+
558
+ def forward(self, ref_audio_16k) -> torch.Tensor:
559
+ ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
560
+ return ssl_content
561
+
562
+
563
+ class ExportSSLModel(torch.nn.Module):
564
+ def __init__(self, ssl: SSLModel):
565
+ super().__init__()
566
+ self.ssl = ssl
567
+
568
+ def forward(self, ref_audio: torch.Tensor):
569
+ return self.ssl(ref_audio)
570
+
571
+ @torch.jit.export
572
+ def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
573
+ audio = resamplex(ref_audio, src_sr, dst_sr).float()
574
+ return audio
575
+
576
+
577
+ def export_bert(output_path):
578
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
579
+
580
+ text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
581
+ ref_bert_inputs = tokenizer(text, return_tensors="pt")
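+ # Approximate word2ph used only for tracing: 1 phone per punctuation mark, 2 per Chinese character.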
582
+ word2ph = []
583
+ for c in text:
584
+ if c in [",", "。", ":", "?", ",", ".", "?"]:
585
+ word2ph.append(1)
586
+ else:
587
+ word2ph.append(2)
588
+ ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
589
+
590
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
591
+ my_bert_model = MyBertModel(bert_model)
592
+
593
+ ref_bert_inputs = {
594
+ "input_ids": ref_bert_inputs["input_ids"],
595
+ "attention_mask": ref_bert_inputs["attention_mask"],
596
+ "token_type_ids": ref_bert_inputs["token_type_ids"],
597
+ "word2ph": ref_bert_inputs["word2ph"],
598
+ }
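+ # Mark the varying sequence dimensions as dynamic so the trace is not specialized to this example's length.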
599
+
600
+ torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
601
+ torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
602
+ torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
603
+ torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
604
+
605
+ my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
606
+ output_path = os.path.join(output_path, "bert_model.pt")
607
+ my_bert_model.save(output_path)
608
+ print("#### exported bert ####")
609
+
610
+
611
+ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
612
+ if not os.path.exists(output_path):
613
+ os.makedirs(output_path)
614
+ print(f"目录已创建: {output_path}")
615
+ else:
616
+ print(f"目录已存在: {output_path}")
617
+
618
+ ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
619
+ ssl = SSLModel()
620
+ if export_bert_and_ssl:
621
+ s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
622
+ ssl_path = os.path.join(output_path, "ssl_model.pt")
623
+ torch.jit.script(s).save(ssl_path)
624
+ print("#### exported ssl ####")
625
+ export_bert(output_path)
626
+ else:
627
+ s = ExportSSLModel(ssl)
628
+
629
+ print(f"device: {device}")
630
+
631
+ ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
632
+ ref_seq = torch.LongTensor([ref_seq_id]).to(device)
633
+ ref_bert = ref_bert_T.T.to(ref_seq.device)
634
+ text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
635
+ "这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
636
+ )
637
+ text_seq = torch.LongTensor([text_seq_id]).to(device)
638
+ text_bert = text_bert_T.T.to(text_seq.device)
639
+
640
+ ssl_content = ssl(ref_audio).to(device)
641
+
642
+ # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
643
+ vits = VitsModel(vits_path).to(device)
644
+ vits.eval()
645
+
646
+ # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
647
+ # dict_s1 = torch.load(gpt_path, map_location=device)
648
+ dict_s1 = torch.load(gpt_path)
649
+ raw_t2s = get_raw_t2s_model(dict_s1).to(device)
650
+ print("#### get_raw_t2s_model ####")
651
+ print(raw_t2s.config)
652
+ t2s_m = T2SModel(raw_t2s)
653
+ t2s_m.eval()
654
+ t2s = torch.jit.script(t2s_m).to(device)
655
+ print("#### script t2s_m ####")
656
+
657
+ print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
658
+ gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
659
+ gpt_sovits.eval()
660
+
661
+ ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
662
+
663
+ torch._dynamo.mark_dynamic(ssl_content, 2)
664
+ torch._dynamo.mark_dynamic(ref_audio_sr, 1)
665
+ torch._dynamo.mark_dynamic(ref_seq, 1)
666
+ torch._dynamo.mark_dynamic(text_seq, 1)
667
+ torch._dynamo.mark_dynamic(ref_bert, 0)
668
+ torch._dynamo.mark_dynamic(text_bert, 0)
669
+
670
+ top_k = torch.LongTensor([5]).to(device)
671
+
672
+ with torch.no_grad():
673
+ gpt_sovits_export = torch.jit.trace(
674
+ gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
675
+ )
676
+
677
+ gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
678
+ gpt_sovits_export.save(gpt_sovits_path)
679
+ print("#### exported gpt_sovits ####")
680
+
681
+
682
+ @torch.jit.script
683
+ def parse_audio(ref_audio):
684
+ ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
685
+ ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
686
+ return ref_audio_16k, ref_audio_sr
687
+
688
+
689
+ @torch.jit.script
690
+ def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
691
+ return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
692
+
693
+
694
+ class GPT_SoVITS(nn.Module):
695
+ def __init__(self, t2s: T2SModel, vits: VitsModel):
696
+ super().__init__()
697
+ self.t2s = t2s
698
+ self.vits = vits
699
+
700
+ def forward(
701
+ self,
702
+ ssl_content: torch.Tensor,
703
+ ref_audio_sr: torch.Tensor,
704
+ ref_seq: Tensor,
705
+ text_seq: Tensor,
706
+ ref_bert: Tensor,
707
+ text_bert: Tensor,
708
+ top_k: LongTensor,
709
+ speed=1.0,
710
+ ):
711
+ codes = self.vits.vq_model.extract_latent(ssl_content)
712
+ prompt_semantic = codes[0, 0]
713
+ prompts = prompt_semantic.unsqueeze(0)
714
+
715
+ pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
716
+ audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
717
+ return audio
718
+
719
+
720
+ def test():
721
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
722
+ parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
723
+ parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
724
+ parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
725
+ parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
726
+ parser.add_argument("--output_path", required=True, help="Path to the output directory")
727
+
728
+ args = parser.parse_args()
729
+ gpt_path = args.gpt_model
730
+ vits_path = args.sovits_model
731
+ ref_audio_path = args.ref_audio
732
+ ref_text = args.ref_text
733
+
734
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
735
+ # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
736
+ # bert = MyBertModel(bert_model)
737
+ my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
738
+
739
+ # dict_s1 = torch.load(gpt_path, map_location="cuda")
740
+ # raw_t2s = get_raw_t2s_model(dict_s1)
741
+ # t2s = T2SModel(raw_t2s)
742
+ # t2s.eval()
743
+ # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
744
+
745
+ # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
746
+ # vits = VitsModel(vits_path)
747
+ # vits.eval()
748
+
749
+ # ssl = ExportSSLModel(SSLModel()).to('cuda')
750
+ # ssl.eval()
751
+ ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
752
+
753
+ # gpt_sovits = GPT_SoVITS(t2s,vits)
754
+ gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
755
+
756
+ ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
757
+ ref_seq = torch.LongTensor([ref_seq_id])
758
+ ref_bert = ref_bert_T.T.to(ref_seq.device)
759
+ # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
760
+ text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
761
+
762
+ text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
763
+
764
+ test_bert = tokenizer(text, return_tensors="pt")
765
+ word2ph = []
766
+ for c in text:
767
+ if c in [",", "。", ":", "?", "?", ",", "."]:
768
+ word2ph.append(1)
769
+ else:
770
+ word2ph.append(2)
771
+ test_bert["word2ph"] = torch.Tensor(word2ph).int()
772
+
773
+ test_bert = my_bert(
774
+ test_bert["input_ids"].to("cuda"),
775
+ test_bert["attention_mask"].to("cuda"),
776
+ test_bert["token_type_ids"].to("cuda"),
777
+ test_bert["word2ph"].to("cuda"),
778
+ )
779
+
780
+ text_seq = torch.LongTensor([text_seq_id])
781
+ text_bert = text_bert_T.T.to(text_seq.device)
782
+
783
+ print("text_bert:", text_bert.shape, text_bert)
784
+ print("test_bert:", test_bert.shape, test_bert)
785
+ print(torch.allclose(text_bert.to("cuda"), test_bert))
786
+
787
+ print("text_seq:", text_seq.shape)
788
+ print("text_bert:", text_bert.shape, text_bert.type())
789
+
790
+ # [1,N]
791
+ ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
792
+ print("ref_audio:", ref_audio.shape)
793
+
794
+ ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
795
+ print("start ssl")
796
+ ssl_content = ssl(ref_audio)
797
+
798
+ print("start gpt_sovits:")
799
+ print("ssl_content:", ssl_content.shape)
800
+ print("ref_audio_sr:", ref_audio_sr.shape)
801
+ print("ref_seq:", ref_seq.shape)
802
+ ref_seq = ref_seq.to("cuda")
803
+ print("text_seq:", text_seq.shape)
804
+ text_seq = text_seq.to("cuda")
805
+ print("ref_bert:", ref_bert.shape)
806
+ ref_bert = ref_bert.to("cuda")
807
+ print("text_bert:", text_bert.shape)
808
+ text_bert = text_bert.to("cuda")
809
+
810
+ top_k = torch.LongTensor([5]).to("cuda")
811
+
812
+ with torch.no_grad():
813
+ audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
814
+ print("start write wav")
815
+ soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
816
+
817
+
818
+ import text
819
+ import json
820
+
821
+
822
+ def export_symbel(version="v2"):
823
+ if version == "v1":
824
+ symbols = text._symbol_to_id_v1
825
+ with open("onnx/symbols_v1.json", "w") as file:
826
+ json.dump(symbols, file, indent=4)
827
+ else:
828
+ symbols = text._symbol_to_id_v2
829
+ with open("onnx/symbols_v2.json", "w") as file:
830
+ json.dump(symbols, file, indent=4)
831
+
832
+
833
+ def main():
834
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
835
+ parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
836
+ parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
837
+ parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
838
+ parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
839
+ parser.add_argument("--output_path", required=True, help="Path to the output directory")
840
+ parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
841
+ parser.add_argument("--device", help="Device to use")
842
+
843
+ args = parser.parse_args()
844
+ export(
845
+ gpt_path=args.gpt_model,
846
+ vits_path=args.sovits_model,
847
+ ref_audio_path=args.ref_audio,
848
+ ref_text=args.ref_text,
849
+ output_path=args.output_path,
850
+ device=args.device,
851
+ export_bert_and_ssl=args.export_common_model,
852
+ )
853
+
854
+
855
+ import inference_webui
856
+
857
+ if __name__ == "__main__":
858
+ inference_webui.is_half = False
859
+ inference_webui.dtype = torch.float32
860
+ main()
861
+ # test()
GPT_SoVITS/export_torch_script_v3.py ADDED
@@ -0,0 +1,1035 @@
1
+ import os
2
+ from export_torch_script import (
3
+ T2SModel,
4
+ get_raw_t2s_model,
5
+ resamplex,
6
+ spectrogram_torch,
7
+ )
8
+ from f5_tts.model.backbones.dit import DiT
9
+ from inference_webui import get_phones_and_bert
10
+ import librosa
11
+ from module import commons
12
+ from module.mel_processing import mel_spectrogram_torch
13
+ from module.models_onnx import CFM, SynthesizerTrnV3
14
+ import numpy as np
15
+ import torch._dynamo.config
16
+ import torchaudio
17
+ import logging
18
+ import uvicorn
19
+ import torch
20
+ import soundfile
21
+ from librosa.filters import mel as librosa_mel_fn
22
+
23
+
24
+ from inference_webui import get_spepc, norm_spec, resample, ssl_model
25
+
26
+ logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
27
+ logger = logging.getLogger("uvicorn")
28
+
29
+ is_half = True
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ now_dir = os.getcwd()
32
+
33
+
34
+ class MelSpectrgram(torch.nn.Module):
35
+ def __init__(
36
+ self,
37
+ dtype,
38
+ device,
39
+ n_fft,
40
+ num_mels,
41
+ sampling_rate,
42
+ hop_size,
43
+ win_size,
44
+ fmin,
45
+ fmax,
46
+ center=False,
47
+ ):
48
+ super().__init__()
49
+ self.hann_window = torch.hann_window(1024).to(device=device, dtype=dtype)
50
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
51
+ self.mel_basis = torch.from_numpy(mel).to(dtype=dtype, device=device)
52
+ self.n_fft: int = n_fft
53
+ self.hop_size: int = hop_size
54
+ self.win_size: int = win_size
55
+ self.center: bool = center
56
+
57
+ def forward(self, y):
58
+ y = torch.nn.functional.pad(
59
+ y.unsqueeze(1),
60
+ (
61
+ int((self.n_fft - self.hop_size) / 2),
62
+ int((self.n_fft - self.hop_size) / 2),
63
+ ),
64
+ mode="reflect",
65
+ )
66
+ y = y.squeeze(1)
67
+ spec = torch.stft(
68
+ y,
69
+ self.n_fft,
70
+ hop_length=self.hop_size,
71
+ win_length=self.win_size,
72
+ window=self.hann_window,
73
+ center=self.center,
74
+ pad_mode="reflect",
75
+ normalized=False,
76
+ onesided=True,
77
+ return_complex=False,
78
+ )
79
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
80
+ spec = torch.matmul(self.mel_basis, spec)
81
+ # spec = spectral_normalize_torch(spec)
82
+ spec = torch.log(torch.clamp(spec, min=1e-5))
83
+ return spec
84
+
85
+
86
+ class ExportDitBlocks(torch.nn.Module):
87
+ def __init__(self, dit: DiT):
88
+ super().__init__()
89
+ self.transformer_blocks = dit.transformer_blocks
90
+ self.norm_out = dit.norm_out
91
+ self.proj_out = dit.proj_out
92
+ self.depth = dit.depth
93
+
94
+ def forward(self, x, t, mask, rope):
95
+ for block in self.transformer_blocks:
96
+ x = block(x, t, mask=mask, rope=(rope, 1.0))
97
+ x = self.norm_out(x, t)
98
+ output = self.proj_out(x)
99
+ return output
100
+
101
+
102
+ class ExportDitEmbed(torch.nn.Module):
103
+ def __init__(self, dit: DiT):
104
+ super().__init__()
105
+ self.time_embed = dit.time_embed
106
+ self.d_embed = dit.d_embed
107
+ self.text_embed = dit.text_embed
108
+ self.input_embed = dit.input_embed
109
+ self.rotary_embed = dit.rotary_embed
110
+ self.rotary_embed.inv_freq.to(device)
111
+
112
+ def forward(
113
+ self,
114
+ x0: torch.Tensor,  # noised input audio  # noqa: F722
115
+ cond0: torch.Tensor, # masked cond audio # noqa: F722
116
+ x_lens: torch.Tensor,
117
+ time: torch.Tensor, # time step # noqa: F821 F722
118
+ dt_base_bootstrap: torch.Tensor,
119
+ text0: torch.Tensor,  # noqa: F722  ### condition feature
120
+ ):
121
+ x = x0.transpose(2, 1)
122
+ cond = cond0.transpose(2, 1)
123
+ text = text0.transpose(2, 1)
124
+ mask = commons.sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
125
+
126
+ t = self.time_embed(time) + self.d_embed(dt_base_bootstrap)
127
+ text_embed = self.text_embed(text, x.shape[1])
128
+ rope_t = torch.arange(x.shape[1], device=device)
129
+ rope, _ = self.rotary_embed(rope_t)
130
+ x = self.input_embed(x, cond, text_embed)
131
+ return x, t, mask, rope
132
+
133
+
134
+ class ExportDiT(torch.nn.Module):
135
+ def __init__(self, dit: DiT):
136
+ super().__init__()
137
+ if dit is not None:
138
+ self.embed = ExportDitEmbed(dit)
139
+ self.blocks = ExportDitBlocks(dit)
140
+ else:
141
+ self.embed = None
142
+ self.blocks = None
143
+
144
+ def forward( # x, prompt_x, x_lens, t, style,cond
145
+ self,  # d is channels, n is time steps
146
+ x0: torch.Tensor,  # noised input audio  # noqa: F722
147
+ cond0: torch.Tensor, # masked cond audio # noqa: F722
148
+ x_lens: torch.Tensor,
149
+ time: torch.Tensor, # time step # noqa: F821 F722
150
+ dt_base_bootstrap: torch.Tensor,
151
+ text0: torch.Tensor,  # noqa: F722  ### condition feature
152
+ ):
153
+ x, t, mask, rope = self.embed(x0, cond0, x_lens, time, dt_base_bootstrap, text0)
154
+ output = self.blocks(x, t, mask, rope)
155
+ return output
156
+
157
+
158
+ class ExportCFM(torch.nn.Module):
159
+ def __init__(self, cfm: CFM):
160
+ super().__init__()
161
+ self.cfm = cfm
162
+
163
+ def forward(
164
+ self,
165
+ fea_ref: torch.Tensor,
166
+ fea_todo_chunk: torch.Tensor,
167
+ mel2: torch.Tensor,
168
+ sample_steps: torch.LongTensor,
169
+ ):
170
+ T_min = fea_ref.size(2)
171
+ fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
172
+ cfm_res = self.cfm(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps)
173
+ cfm_res = cfm_res[:, :, mel2.shape[2] :]
174
+ mel2 = cfm_res[:, :, -T_min:]
175
+ fea_ref = fea_todo_chunk[:, :, -T_min:]
176
+ return cfm_res, fea_ref, mel2
177
+
178
+
179
+ mel_fn = lambda x: mel_spectrogram_torch(
180
+ x,
181
+ **{
182
+ "n_fft": 1024,
183
+ "win_size": 1024,
184
+ "hop_size": 256,
185
+ "num_mels": 100,
186
+ "sampling_rate": 24000,
187
+ "fmin": 0,
188
+ "fmax": None,
189
+ "center": False,
190
+ },
191
+ )
192
+
193
+ spec_min = -12
194
+ spec_max = 2
195
+
196
+
197
+ @torch.jit.script
198
+ def norm_spec(x):
199
+ spec_min = -12
200
+ spec_max = 2
201
+ return (x - spec_min) / (spec_max - spec_min) * 2 - 1
202
+
203
+
204
+ def denorm_spec(x):
205
+ spec_min = -12
206
+ spec_max = 2
207
+ return (x + 1) / 2 * (spec_max - spec_min) + spec_min
208
+
209
+
210
+ class ExportGPTSovitsHalf(torch.nn.Module):
211
+ def __init__(self, hps, t2s_m: T2SModel, vq_model: SynthesizerTrnV3):
212
+ super().__init__()
213
+ self.hps = hps
214
+ self.t2s_m = t2s_m
215
+ self.vq_model = vq_model
216
+ self.mel2 = MelSpectrgram(
217
+ dtype=torch.float32,
218
+ device=device,
219
+ n_fft=1024,
220
+ num_mels=100,
221
+ sampling_rate=24000,
222
+ hop_size=256,
223
+ win_size=1024,
224
+ fmin=0,
225
+ fmax=None,
226
+ center=False,
227
+ )
228
+ # self.dtype = dtype
229
+ self.filter_length: int = hps.data.filter_length
230
+ self.sampling_rate: int = hps.data.sampling_rate
231
+ self.hop_length: int = hps.data.hop_length
232
+ self.win_length: int = hps.data.win_length
233
+
234
+ def forward(
235
+ self,
236
+ ssl_content,
237
+ ref_audio_32k: torch.FloatTensor,
238
+ phoneme_ids0,
239
+ phoneme_ids1,
240
+ bert1,
241
+ bert2,
242
+ top_k,
243
+ ):
244
+ refer = spectrogram_torch(
245
+ ref_audio_32k,
246
+ self.filter_length,
247
+ self.sampling_rate,
248
+ self.hop_length,
249
+ self.win_length,
250
+ center=False,
251
+ ).to(ssl_content.dtype)
252
+
253
+ codes = self.vq_model.extract_latent(ssl_content)
254
+ prompt_semantic = codes[0, 0]
255
+ prompt = prompt_semantic.unsqueeze(0)
256
+ # print('extract_latent',codes.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
257
+
258
+ pred_semantic = self.t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
259
+ # print('t2s_m',pred_semantic.shape,datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
260
+
261
+ ge = self.vq_model.create_ge(refer)
262
+ # print('create_ge',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
263
+
264
+ prompt_ = prompt.unsqueeze(0)
265
+ fea_ref = self.vq_model(prompt_, phoneme_ids0, ge)
266
+ # print('fea_ref',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
267
+ # print(prompt_.shape, phoneme_ids0.shape, ge.shape)
268
+ # print(fea_ref.shape)
269
+
270
+ ref_24k = resamplex(ref_audio_32k, 32000, 24000)
271
+ mel2 = norm_spec(self.mel2(ref_24k)).to(ssl_content.dtype)
272
+ T_min = min(mel2.shape[2], fea_ref.shape[2])
273
+ mel2 = mel2[:, :, :T_min]
274
+ fea_ref = fea_ref[:, :, :T_min]
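+ # Cap the reference context at the last 468 frames (~5 s of 24 kHz mel with hop 256).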
275
+ if T_min > 468:
276
+ mel2 = mel2[:, :, -468:]
277
+ fea_ref = fea_ref[:, :, -468:]
278
+ T_min = 468
279
+
280
+ fea_todo = self.vq_model(pred_semantic, phoneme_ids1, ge)
281
+ # print('fea_todo',datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
282
+ # print(pred_semantic.shape, phoneme_ids1.shape, ge.shape)
283
+ # print(fea_todo.shape)
284
+
285
+ return fea_ref, fea_todo, mel2
286
+
287
+
288
+ class GPTSoVITSV3(torch.nn.Module):
289
+ def __init__(self, gpt_sovits_half, cfm, bigvgan):
290
+ super().__init__()
291
+ self.gpt_sovits_half = gpt_sovits_half
292
+ self.cfm = cfm
293
+ self.bigvgan = bigvgan
294
+
295
+ def forward(
296
+ self,
297
+ ssl_content,
298
+ ref_audio_32k: torch.FloatTensor,
299
+ phoneme_ids0: torch.LongTensor,
300
+ phoneme_ids1: torch.LongTensor,
301
+ bert1,
302
+ bert2,
303
+ top_k: torch.LongTensor,
304
+ sample_steps: torch.LongTensor,
305
+ ):
306
+ # current_time = datetime.now()
307
+ # print("gpt_sovits_half",current_time.strftime("%Y-%m-%d %H:%M:%S"))
308
+ fea_ref, fea_todo, mel2 = self.gpt_sovits_half(
309
+ ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
310
+ )
311
+ chunk_len = 934 - fea_ref.shape[2]
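+ # The flow-matching window is 934 frames, so each fea_todo chunk fills whatever the reference features leave free.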
312
+ wav_gen_list = []
313
+ idx = 0
314
+ wav_gen_length = fea_todo.shape[2] * 256
315
+ while 1:
316
+ # current_time = datetime.now()
317
+ # print("idx:",idx,current_time.strftime("%Y-%m-%d %H:%M:%S"))
318
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
319
+ if fea_todo_chunk.shape[-1] == 0:
320
+ break
321
+
322
+ # The exported model seems to recompile when the input shape changes, stalling for roughly 10 s,
323
+ # so the chunk is zero-padded here to keep its shape fixed.
324
+ # The padding makes the generated audio too long, so it is trimmed at the end;
325
+ # after BigVGAN the audio length is fea_todo.shape[2] * 256.
326
+ complete_len = chunk_len - fea_todo_chunk.shape[-1]
327
+ if complete_len != 0:
328
+ fea_todo_chunk = torch.cat(
329
+ [
330
+ fea_todo_chunk,
331
+ torch.zeros(1, 512, complete_len).to(fea_todo_chunk.device).to(fea_todo_chunk.dtype),
332
+ ],
333
+ 2,
334
+ )
335
+
336
+ cfm_res, fea_ref, mel2 = self.cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
337
+ idx += chunk_len
338
+
339
+ cfm_res = denorm_spec(cfm_res)
340
+ bigvgan_res = self.bigvgan(cfm_res)
341
+ wav_gen_list.append(bigvgan_res)
342
+
343
+ wav_gen = torch.cat(wav_gen_list, 2)
344
+ return wav_gen[0][0][:wav_gen_length]
345
+
346
+
347
+ def init_bigvgan():
348
+ global bigvgan_model
349
+ from BigVGAN import bigvgan
350
+
351
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained(
352
+ "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
353
+ use_cuda_kernel=False,
354
+ ) # if True, RuntimeError: Ninja is required to load C++ extensions
355
+ # remove weight norm in the model and set to eval mode
356
+ bigvgan_model.remove_weight_norm()
357
+ bigvgan_model = bigvgan_model.eval()
358
+ if is_half == True:
359
+ bigvgan_model = bigvgan_model.half().to(device)
360
+ else:
361
+ bigvgan_model = bigvgan_model.to(device)
362
+
363
+
364
+ class Sovits:
365
+ def __init__(self, vq_model: SynthesizerTrnV3, cfm: CFM, hps):
366
+ self.vq_model = vq_model
367
+ self.hps = hps
368
+ cfm.estimator = ExportDiT(cfm.estimator)
369
+ self.cfm = cfm
370
+
371
+
372
+ class DictToAttrRecursive(dict):
373
+ def __init__(self, input_dict):
374
+ super().__init__(input_dict)
375
+ for key, value in input_dict.items():
376
+ if isinstance(value, dict):
377
+ value = DictToAttrRecursive(value)
378
+ self[key] = value
379
+ setattr(self, key, value)
380
+
381
+ def __getattr__(self, item):
382
+ try:
383
+ return self[item]
384
+ except KeyError:
385
+ raise AttributeError(f"Attribute {item} not found")
386
+
387
+ def __setattr__(self, key, value):
388
+ if isinstance(value, dict):
389
+ value = DictToAttrRecursive(value)
390
+ super(DictToAttrRecursive, self).__setitem__(key, value)
391
+ super().__setattr__(key, value)
392
+
393
+ def __delattr__(self, item):
394
+ try:
395
+ del self[item]
396
+ except KeyError:
397
+ raise AttributeError(f"Attribute {item} not found")
398
+
399
+
400
+ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
401
+
402
+
403
+ def get_sovits_weights(sovits_path):
404
+ path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
405
+ is_exist_s2gv3 = os.path.exists(path_sovits_v3)
406
+
407
+ version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
408
+ if if_lora_v3 == True and is_exist_s2gv3 == False:
409
+ logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
410
+
411
+ dict_s2 = load_sovits_new(sovits_path)
412
+ hps = dict_s2["config"]
413
+ hps = DictToAttrRecursive(hps)
414
+ hps.model.semantic_frame_rate = "25hz"
415
+ if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
416
+ hps.model.version = "v2" # v3model,v2sybomls
417
+ elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
418
+ hps.model.version = "v1"
419
+ else:
420
+ hps.model.version = "v2"
421
+
422
+ if model_version == "v3":
423
+ hps.model.version = "v3"
424
+
425
+ logger.info(f"hps: {hps}")
426
+
427
+ vq_model = SynthesizerTrnV3(
428
+ hps.data.filter_length // 2 + 1,
429
+ hps.train.segment_size // hps.data.hop_length,
430
+ n_speakers=hps.data.n_speakers,
431
+ **hps.model,
432
+ )
433
+ # init_bigvgan()
434
+ model_version = hps.model.version
435
+ logger.info(f"模型版本: {model_version}")
436
+
437
+ if is_half == True:
438
+ vq_model = vq_model.half().to(device)
439
+ else:
440
+ vq_model = vq_model.to(device)
441
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
442
+ vq_model.eval()
443
+
444
+ cfm = vq_model.cfm
445
+ del vq_model.cfm
446
+
447
+ sovits = Sovits(vq_model, cfm, hps)
448
+ return sovits
449
+
450
+
451
+ logger.info(f"torch version {torch.__version__}")
452
+ # ssl_model = cnhubert.get_model()
453
+ # if is_half:
454
+ # ssl_model = ssl_model.half().to(device)
455
+ # else:
456
+ # ssl_model = ssl_model.to(device)
457
+
458
+
459
+ def export_cfm(
460
+ e_cfm: ExportCFM,
461
+ mu: torch.Tensor,
462
+ x_lens: torch.LongTensor,
463
+ prompt: torch.Tensor,
464
+ n_timesteps: torch.IntTensor,
465
+ temperature=1.0,
466
+ ):
467
+ cfm = e_cfm.cfm
468
+
469
+ B, T = mu.size(0), mu.size(1)
470
+ x = torch.randn([B, cfm.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
471
+ print("x:", x.shape, x.dtype)
472
+ prompt_len = prompt.size(-1)
473
+ prompt_x = torch.zeros_like(x, dtype=mu.dtype)
474
+ prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
475
+ x[..., :prompt_len] = 0.0
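+ # x starts as pure noise scaled by temperature; the prompt mel is supplied through prompt_x while the matching region of x is zeroed.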
476
+ mu = mu.transpose(2, 1)
477
+
478
+ ntimestep = int(n_timesteps)
479
+
480
+ t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
481
+ d = torch.tensor(1.0 / ntimestep, dtype=x.dtype, device=x.device)
482
+
483
+ t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
484
+ d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
485
+
486
+ print(
487
+ "cfm input shapes:",
488
+ x.shape,
489
+ prompt_x.shape,
490
+ x_lens.shape,
491
+ t_tensor.shape,
492
+ d_tensor.shape,
493
+ mu.shape,
494
+ )
495
+
496
+ print("cfm input dtypes:", x.dtype, prompt_x.dtype, x_lens.dtype, t_tensor.dtype, d_tensor.dtype, mu.dtype)
497
+
498
+ estimator: ExportDiT = torch.jit.trace(
499
+ cfm.estimator,
500
+ optimize=True,
501
+ example_inputs=(x, prompt_x, x_lens, t_tensor, d_tensor, mu),
502
+ )
503
+ estimator.save("onnx/ad/estimator.pt")
504
+ # torch.onnx.export(
505
+ # cfm.estimator,
506
+ # (x, prompt_x, x_lens, t_tensor, d_tensor, mu),
507
+ # "onnx/ad/dit.onnx",
508
+ # input_names=["x", "prompt_x", "x_lens", "t", "d", "mu"],
509
+ # output_names=["output"],
510
+ # dynamic_axes={
511
+ # "x": [2],
512
+ # "prompt_x": [2],
513
+ # "mu": [2],
514
+ # },
515
+ # )
516
+ print("save estimator ok")
517
+ cfm.estimator = estimator
518
+ export_cfm = torch.jit.script(e_cfm)
519
+ export_cfm.save("onnx/ad/cfm.pt")
520
+ # sovits.cfm = cfm
521
+ # cfm.save("onnx/ad/cfm.pt")
522
+ return export_cfm
523
+
524
+
525
+ def export():
526
+ sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
527
+
528
+ init_bigvgan()
529
+
530
+ dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
531
+ raw_t2s = get_raw_t2s_model(dict_s1).to(device)
532
+ print("#### get_raw_t2s_model ####")
533
+ print(raw_t2s.config)
534
+
535
+ if is_half:
536
+ raw_t2s = raw_t2s.half().to(device)
537
+
538
+ t2s_m = T2SModel(raw_t2s)
539
+ t2s_m.eval()
540
+ script_t2s = torch.jit.script(t2s_m).to(device)
541
+
542
+ hps = sovits.hps
543
+ ref_wav_path = "onnx/ad/ref.wav"
544
+ speed = 1.0
545
+ sample_steps = 32
546
+ dtype = torch.float16 if is_half == True else torch.float32
547
+ refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
548
+ zero_wav = np.zeros(
549
+ int(hps.data.sampling_rate * 0.3),
550
+ dtype=np.float16 if is_half == True else np.float32,
551
+ )
552
+
553
+ with torch.no_grad():
554
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
555
+ wav16k = torch.from_numpy(wav16k)
556
+ zero_wav_torch = torch.from_numpy(zero_wav)
557
+
558
+ if is_half == True:
559
+ wav16k = wav16k.half().to(device)
560
+ zero_wav_torch = zero_wav_torch.half().to(device)
561
+ else:
562
+ wav16k = wav16k.to(device)
563
+ zero_wav_torch = zero_wav_torch.to(device)
564
+ wav16k = torch.cat([wav16k, zero_wav_torch])
565
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
566
+ codes = sovits.vq_model.extract_latent(ssl_content)
567
+ prompt_semantic = codes[0, 0]
568
+ prompt = prompt_semantic.unsqueeze(0).to(device)
569
+
570
+ phones1, bert1, norm_text1 = get_phones_and_bert(
571
+ "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
572
+ )
573
+ phones2, bert2, norm_text2 = get_phones_and_bert(
574
+ "这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
575
+ "auto",
576
+ "v3",
577
+ )
578
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
579
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
580
+
581
+ # codes = sovits.vq_model.extract_latent(ssl_content)
582
+ # prompt_semantic = codes[0, 0]
583
+ # prompts = prompt_semantic.unsqueeze(0)
584
+
585
+ top_k = torch.LongTensor([15]).to(device)
586
+ print("topk", top_k)
587
+
588
+ bert1 = bert1.T.to(device)
589
+ bert2 = bert2.T.to(device)
590
+ print(
591
+ prompt.dtype,
592
+ phoneme_ids0.dtype,
593
+ phoneme_ids1.dtype,
594
+ bert1.dtype,
595
+ bert2.dtype,
596
+ top_k.dtype,
597
+ )
598
+ print(
599
+ prompt.shape,
600
+ phoneme_ids0.shape,
601
+ phoneme_ids1.shape,
602
+ bert1.shape,
603
+ bert2.shape,
604
+ top_k.shape,
605
+ )
606
+ pred_semantic = t2s_m(prompt, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k)
607
+
608
+ ge = sovits.vq_model.create_ge(refer)
609
+ prompt_ = prompt.unsqueeze(0)
610
+
611
+ torch._dynamo.mark_dynamic(prompt_, 2)
612
+ torch._dynamo.mark_dynamic(phoneme_ids0, 1)
613
+
614
+ fea_ref = sovits.vq_model(prompt_, phoneme_ids0, ge)
615
+
616
+ inputs = {
617
+ "forward": (prompt_, phoneme_ids0, ge),
618
+ "extract_latent": ssl_content,
619
+ "create_ge": refer,
620
+ }
621
+
622
+ trace_vq_model = torch.jit.trace_module(sovits.vq_model, inputs, optimize=True)
623
+ trace_vq_model.save("onnx/ad/vq_model.pt")
624
+
625
+ print(fea_ref.shape, fea_ref.dtype, ge.shape)
626
+ print(prompt_.shape, phoneme_ids0.shape, ge.shape)
627
+
628
+ # vq_model = torch.jit.trace(
629
+ # sovits.vq_model,
630
+ # optimize=True,
631
+ # # strict=False,
632
+ # example_inputs=(prompt_, phoneme_ids0, ge),
633
+ # )
634
+ # vq_model = sovits.vq_model
635
+ vq_model = trace_vq_model
636
+
637
+ gpt_sovits_half = ExportGPTSovitsHalf(sovits.hps, script_t2s, trace_vq_model)
638
+ torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v3_half.pt")
639
+
640
+ ref_audio, sr = torchaudio.load(ref_wav_path)
641
+ ref_audio = ref_audio.to(device).float()
642
+ if ref_audio.shape[0] == 2:
643
+ ref_audio = ref_audio.mean(0).unsqueeze(0)
644
+ if sr != 24000:
645
+ ref_audio = resample(ref_audio, sr)
646
+ # mel2 = mel_fn(ref_audio)
647
+ mel2 = norm_spec(mel_fn(ref_audio))
648
+ T_min = min(mel2.shape[2], fea_ref.shape[2])
649
+ fea_ref = fea_ref[:, :, :T_min]
650
+ print("fea_ref:", fea_ref.shape, T_min)
651
+ if T_min > 468:
652
+ mel2 = mel2[:, :, -468:]
653
+ fea_ref = fea_ref[:, :, -468:]
654
+ T_min = 468
655
+ chunk_len = 934 - T_min
656
+ mel2 = mel2.to(dtype)
657
+
658
+ # fea_todo, ge = sovits.vq_model(pred_semantic,y_lengths, phoneme_ids1, ge)
659
+ fea_todo = vq_model(pred_semantic, phoneme_ids1, ge)
660
+
661
+ cfm_resss = []
662
+ idx = 0
663
+ sample_steps = torch.LongTensor([sample_steps]).to(device)
664
+ export_cfm_ = ExportCFM(sovits.cfm)
665
+ while 1:
666
+ print("idx:", idx)
667
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
668
+ if fea_todo_chunk.shape[-1] == 0:
669
+ break
670
+
671
+ print(
672
+ "export_cfm:",
673
+ fea_ref.shape,
674
+ fea_todo_chunk.shape,
675
+ mel2.shape,
676
+ sample_steps.shape,
677
+ )
678
+ if idx == 0:
679
+ fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
680
+ export_cfm_ = export_cfm(
681
+ export_cfm_,
682
+ fea,
683
+ torch.LongTensor([fea.size(1)]).to(fea.device),
684
+ mel2,
685
+ sample_steps,
686
+ )
687
+ # torch.onnx.export(
688
+ # export_cfm_,
689
+ # (
690
+ # fea_ref,
691
+ # fea_todo_chunk,
692
+ # mel2,
693
+ # sample_steps,
694
+ # ),
695
+ # "onnx/ad/cfm.onnx",
696
+ # input_names=["fea_ref", "fea_todo_chunk", "mel2", "sample_steps"],
697
+ # output_names=["cfm_res", "fea_ref_", "mel2_"],
698
+ # dynamic_axes={
699
+ # "fea_ref": [2],
700
+ # "fea_todo_chunk": [2],
701
+ # "mel2": [2],
702
+ # },
703
+ # )
704
+
705
+ idx += chunk_len
706
+
707
+ cfm_res, fea_ref, mel2 = export_cfm_(fea_ref, fea_todo_chunk, mel2, sample_steps)
708
+ cfm_resss.append(cfm_res)
709
+ continue
710
+
711
+ cmf_res = torch.cat(cfm_resss, 2)
712
+ cmf_res = denorm_spec(cmf_res).to(device)
713
+ print("cmf_res:", cmf_res.shape, cmf_res.dtype)
714
+ with torch.inference_mode():
715
+ cmf_res_rand = torch.randn(1, 100, 934).to(device).to(dtype)
716
+ torch._dynamo.mark_dynamic(cmf_res_rand, 2)
717
+ bigvgan_model_ = torch.jit.trace(bigvgan_model, optimize=True, example_inputs=(cmf_res_rand,))
718
+ bigvgan_model_.save("onnx/ad/bigvgan_model.pt")
719
+ wav_gen = bigvgan_model(cmf_res)
720
+ print("wav_gen:", wav_gen.shape, wav_gen.dtype)
721
+ audio = wav_gen[0][0].cpu().detach().numpy()
722
+
723
+ sr = 24000
724
+ soundfile.write("out.export.wav", (audio * 32768).astype(np.int16), sr)
725
+
726
+
727
+ from datetime import datetime
728
+
729
+
730
+ def test_export(
731
+ todo_text,
732
+ gpt_sovits_v3_half,
733
+ cfm,
734
+ bigvgan,
735
+ output,
736
+ ):
737
+ # hps = sovits.hps
738
+ ref_wav_path = "onnx/ad/ref.wav"
739
+ speed = 1.0
740
+ sample_steps = 8
741
+
742
+ dtype = torch.float16 if is_half == True else torch.float32
743
+
744
+ zero_wav = np.zeros(
745
+ int(16000 * 0.3),
746
+ dtype=np.float16 if is_half == True else np.float32,
747
+ )
748
+
749
+ with torch.no_grad():
750
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
751
+ wav16k = torch.from_numpy(wav16k)
752
+ zero_wav_torch = torch.from_numpy(zero_wav)
753
+
754
+ if is_half == True:
755
+ wav16k = wav16k.half().to(device)
756
+ zero_wav_torch = zero_wav_torch.half().to(device)
757
+ else:
758
+ wav16k = wav16k.to(device)
759
+ zero_wav_torch = zero_wav_torch.to(device)
760
+ wav16k = torch.cat([wav16k, zero_wav_torch])
761
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
762
+
763
+ ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
764
+ ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
765
+
766
+ phones1, bert1, norm_text1 = get_phones_and_bert(
767
+ "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
768
+ )
769
+ phones2, bert2, norm_text2 = get_phones_and_bert(
770
+ todo_text,
771
+ "zh",
772
+ "v3",
773
+ )
774
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
775
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
776
+
777
+ bert1 = bert1.T.to(device)
778
+ bert2 = bert2.T.to(device)
779
+ top_k = torch.LongTensor([15]).to(device)
780
+
781
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
782
+ logger.info("start inference %s", current_time)
783
+ print(
784
+ ssl_content.shape,
785
+ ref_audio_32k.shape,
786
+ phoneme_ids0.shape,
787
+ phoneme_ids1.shape,
788
+ bert1.shape,
789
+ bert2.shape,
790
+ top_k.shape,
791
+ )
792
+ fea_ref, fea_todo, mel2 = gpt_sovits_v3_half(
793
+ ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k
794
+ )
795
+ chunk_len = 934 - fea_ref.shape[2]
796
+ print(fea_ref.shape, fea_todo.shape, mel2.shape)
797
+
798
+ cfm_resss = []
799
+ sample_steps = torch.LongTensor([sample_steps])
800
+ idx = 0
801
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
802
+ logger.info("start cfm %s", current_time)
803
+ wav_gen_length = fea_todo.shape[2] * 256
804
+
805
+ while 1:
806
+ current_time = datetime.now()
807
+ print("idx:", idx, current_time.strftime("%Y-%m-%d %H:%M:%S"))
808
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
809
+ if fea_todo_chunk.shape[-1] == 0:
810
+ break
811
+
812
+ complete_len = chunk_len - fea_todo_chunk.shape[-1]
813
+ if complete_len != 0:
814
+ fea_todo_chunk = torch.cat([fea_todo_chunk, torch.zeros(1, 512, complete_len).to(device).to(dtype)], 2)
815
+
816
+ cfm_res, fea_ref, mel2 = cfm(fea_ref, fea_todo_chunk, mel2, sample_steps)
817
+ # if complete_len > 0 :
818
+ # cfm_res = cfm_res[:, :, :-complete_len]
819
+ # fea_ref = fea_ref[:, :, :-complete_len]
820
+ # mel2 = mel2[:, :, :-complete_len]
821
+
822
+ idx += chunk_len
823
+
824
+ current_time = datetime.now()
825
+ print("cfm end", current_time.strftime("%Y-%m-%d %H:%M:%S"))
826
+ cfm_res = denorm_spec(cfm_res).to(device)
827
+ bigvgan_res = bigvgan(cfm_res)
828
+ cfm_resss.append(bigvgan_res)
829
+
830
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
831
+ logger.info("start bigvgan %s", current_time)
832
+ wav_gen = torch.cat(cfm_resss, 2)
833
+ # cmf_res = denorm_spec(cmf_res)
834
+ # cmf_res = cmf_res.to(device)
835
+ # print("cmf_res:", cmf_res.shape)
836
+
837
+ # cmf_res = torch.cat([cmf_res,torch.zeros([1,100,2000-cmf_res.size(2)],device=device,dtype=cmf_res.dtype)], 2)
838
+
839
+ # wav_gen = bigvgan(cmf_res)
840
+ print("wav_gen:", wav_gen.shape, wav_gen.dtype)
841
+ wav_gen = wav_gen[:, :, :wav_gen_length]
842
+
843
+ audio = wav_gen[0][0].cpu().detach().numpy()
844
+ logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
845
+ sr = 24000
846
+ soundfile.write(output, (audio * 32768).astype(np.int16), sr)
847
+
848
+
849
+ def test_export1(
850
+ todo_text,
851
+ gpt_sovits_v3,
852
+ output,
853
+ ):
854
+ # hps = sovits.hps
855
+ ref_wav_path = "onnx/ad/ref.wav"
856
+ speed = 1.0
857
+ sample_steps = torch.LongTensor([16])
858
+
859
+ dtype = torch.float16 if is_half == True else torch.float32
860
+
861
+ zero_wav = np.zeros(
862
+ int(24000 * 0.3),
863
+ dtype=np.float16 if is_half == True else np.float32,
864
+ )
865
+
866
+ with torch.no_grad():
867
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
868
+ wav16k = torch.from_numpy(wav16k)
869
+ zero_wav_torch = torch.from_numpy(zero_wav)
870
+
871
+ if is_half == True:
872
+ wav16k = wav16k.half().to(device)
873
+ zero_wav_torch = zero_wav_torch.half().to(device)
874
+ else:
875
+ wav16k = wav16k.to(device)
876
+ zero_wav_torch = zero_wav_torch.to(device)
877
+ wav16k = torch.cat([wav16k, zero_wav_torch])
878
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
879
+ print("ssl_content:", ssl_content.shape, ssl_content.dtype)
880
+
881
+ ref_audio_32k, _ = librosa.load(ref_wav_path, sr=32000)
882
+ ref_audio_32k = torch.from_numpy(ref_audio_32k).unsqueeze(0).to(device).float()
883
+
884
+ phones1, bert1, norm_text1 = get_phones_and_bert(
885
+ "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
886
+ )
887
+ phones2, bert2, norm_text2 = get_phones_and_bert(
888
+ todo_text,
889
+ "zh",
890
+ "v3",
891
+ )
892
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
893
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
894
+
895
+ bert1 = bert1.T.to(device)
896
+ bert2 = bert2.T.to(device)
897
+ top_k = torch.LongTensor([15]).to(device)
898
+
899
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
900
+ logger.info("start inference %s", current_time)
901
+ print(
902
+ ssl_content.shape,
903
+ ref_audio_32k.shape,
904
+ phoneme_ids0.shape,
905
+ phoneme_ids1.shape,
906
+ bert1.shape,
907
+ bert2.shape,
908
+ top_k.shape,
909
+ )
910
+ wav_gen = gpt_sovits_v3(ssl_content, ref_audio_32k, phoneme_ids0, phoneme_ids1, bert1, bert2, top_k, sample_steps)
911
+ print("wav_gen:", wav_gen.shape, wav_gen.dtype)
912
+
913
+ wav_gen = torch.cat([wav_gen, zero_wav_torch], 0)
914
+
915
+ audio = wav_gen.cpu().detach().numpy()
916
+ logger.info("end bigvgan %s", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
917
+ sr = 24000
918
+ soundfile.write(output, (audio * 32768).astype(np.int16), sr)
919
+
920
+
921
+ import time
922
+
923
+
924
+ def test_():
925
+ sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
926
+
927
+ # cfm = ExportCFM(sovits.cfm)
928
+ # cfm.cfm.estimator = dit
929
+ sovits.cfm = None
930
+
931
+ cfm = torch.jit.load("onnx/ad/cfm.pt", map_location=device)
932
+ # cfm = torch.jit.optimize_for_inference(cfm)
933
+ cfm = cfm.half().to(device)
934
+
935
+ cfm.eval()
936
+
937
+ logger.info("cfm ok")
938
+
939
+ dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
940
+ # A v2 GPT checkpoint can also be used:
941
+ # dict_s1 = torch.load("GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt")
942
+ raw_t2s = get_raw_t2s_model(dict_s1).to(device)
943
+ print("#### get_raw_t2s_model ####")
944
+ print(raw_t2s.config)
945
+ if is_half:
946
+ raw_t2s = raw_t2s.half().to(device)
947
+ t2s_m = T2SModel(raw_t2s).half().to(device)
948
+ t2s_m.eval()
949
+ t2s_m = torch.jit.script(t2s_m)
950
+ t2s_m.eval()
951
+ # t2s_m.top_k = 15
952
+ logger.info("t2s_m ok")
953
+
954
+ vq_model: torch.jit.ScriptModule = torch.jit.load("onnx/ad/vq_model.pt", map_location=device)
955
+ # vq_model = torch.jit.optimize_for_inference(vq_model)
956
+ # vq_model = vq_model.half().to(device)
957
+ vq_model.eval()
958
+ # vq_model = sovits.vq_model
959
+ logger.info("vq_model ok")
960
+
961
+ # gpt_sovits_v3_half = torch.jit.load("onnx/ad/gpt_sovits_v3_half.pt")
962
+ # gpt_sovits_v3_half = torch.jit.optimize_for_inference(gpt_sovits_v3_half)
963
+ # gpt_sovits_v3_half = gpt_sovits_v3_half.half()
964
+ # gpt_sovits_v3_half = gpt_sovits_v3_half.cuda()
965
+ # gpt_sovits_v3_half.eval()
966
+ gpt_sovits_v3_half = ExportGPTSovitsHalf(sovits.hps, t2s_m, vq_model)
967
+ logger.info("gpt_sovits_v3_half ok")
968
+
969
+ # init_bigvgan()
970
+ # global bigvgan_model
971
+ bigvgan_model = torch.jit.load("onnx/ad/bigvgan_model.pt")
972
+ # bigvgan_model = torch.jit.optimize_for_inference(bigvgan_model)
973
+ bigvgan_model = bigvgan_model.half()
974
+ bigvgan_model = bigvgan_model.cuda()
975
+ bigvgan_model.eval()
976
+
977
+ logger.info("bigvgan ok")
978
+
979
+ gpt_sovits_v3 = GPTSoVITSV3(gpt_sovits_v3_half, cfm, bigvgan_model)
980
+ gpt_sovits_v3 = torch.jit.script(gpt_sovits_v3)
981
+ gpt_sovits_v3.save("onnx/ad/gpt_sovits_v3.pt")
982
+ gpt_sovits_v3 = gpt_sovits_v3.half().to(device)
983
+ gpt_sovits_v3.eval()
984
+ print("save gpt_sovits_v3 ok")
985
+
986
+ time.sleep(5)
987
+ # print("thread:", torch.get_num_threads())
988
+ # print("thread:", torch.get_num_interop_threads())
989
+ # torch.set_num_interop_threads(1)
990
+ # torch.set_num_threads(1)
991
+
992
+ test_export1(
993
+ "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
994
+ gpt_sovits_v3,
995
+ "out.wav",
996
+ )
997
+
998
+ test_export1(
999
+ "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
1000
+ gpt_sovits_v3,
1001
+ "out2.wav",
1002
+ )
1003
+
1004
+ # test_export(
1005
+ # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP. 哈哈哈...",
1006
+ # gpt_sovits_v3_half,
1007
+ # cfm,
1008
+ # bigvgan_model,
1009
+ # "out2.wav",
1010
+ # )
1011
+
1012
+
1013
+ def test_export_gpt_sovits_v3():
1014
+ gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
1015
+ # test_export1(
1016
+ # "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
1017
+ # gpt_sovits_v3,
1018
+ # "out3.wav",
1019
+ # )
1020
+ # test_export1(
1021
+ # "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
1022
+ # gpt_sovits_v3,
1023
+ # "out4.wav",
1024
+ # )
1025
+ test_export1(
1026
+ "风萧萧兮易水寒,壮士一去兮不复还.",
1027
+ gpt_sovits_v3,
1028
+ "out5.wav",
1029
+ )
1030
+
1031
+
1032
+ with torch.no_grad():
1033
+ # export()
1034
+ test_()
1035
+ # test_export_gpt_sovits_v3()
GPT_SoVITS/inference_cli.py ADDED
@@ -0,0 +1,86 @@
1
+ import argparse
2
+ import os
3
+ import soundfile as sf
4
+
5
+ from tools.i18n.i18n import I18nAuto
6
+ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
7
+
8
+ i18n = I18nAuto()
9
+
10
+
11
+ def synthesize(
12
+ GPT_model_path,
13
+ SoVITS_model_path,
14
+ ref_audio_path,
15
+ ref_text_path,
16
+ ref_language,
17
+ target_text_path,
18
+ target_language,
19
+ output_path,
20
+ ):
21
+ # Read reference text
22
+ with open(ref_text_path, "r", encoding="utf-8") as file:
23
+ ref_text = file.read()
24
+
25
+ # Read target text
26
+ with open(target_text_path, "r", encoding="utf-8") as file:
27
+ target_text = file.read()
28
+
29
+ # Change model weights
30
+ change_gpt_weights(gpt_path=GPT_model_path)
31
+ change_sovits_weights(sovits_path=SoVITS_model_path)
32
+
33
+ # Synthesize audio
34
+ synthesis_result = get_tts_wav(
35
+ ref_wav_path=ref_audio_path,
36
+ prompt_text=ref_text,
37
+ prompt_language=i18n(ref_language),
38
+ text=target_text,
39
+ text_language=i18n(target_language),
40
+ top_p=1,
41
+ temperature=1,
42
+ )
43
+
44
+ result_list = list(synthesis_result)
45
+
46
+ if result_list:
47
+ last_sampling_rate, last_audio_data = result_list[-1]
48
+ output_wav_path = os.path.join(output_path, "output.wav")
49
+ sf.write(output_wav_path, last_audio_data, last_sampling_rate)
50
+ print(f"Audio saved to {output_wav_path}")
51
+
52
+
53
+ def main():
54
+ parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
55
+ parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
56
+ parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
57
+ parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
58
+ parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
59
+ parser.add_argument(
60
+ "--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
61
+ )
62
+ parser.add_argument("--target_text", required=True, help="Path to the target text file")
63
+ parser.add_argument(
64
+ "--target_language",
65
+ required=True,
66
+ choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
67
+ help="Language of the target text",
68
+ )
69
+ parser.add_argument("--output_path", required=True, help="Path to the output directory")
70
+
71
+ args = parser.parse_args()
72
+
73
+ synthesize(
74
+ args.gpt_model,
75
+ args.sovits_model,
76
+ args.ref_audio,
77
+ args.ref_text,
78
+ args.ref_language,
79
+ args.target_text,
80
+ args.target_language,
81
+ args.output_path,
82
+ )
83
+
84
+
85
+ if __name__ == "__main__":
86
+ main()
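
As a usage sketch (not part of the commit), the CLI's synthesize() helper can also be called directly from Python; every path below is a placeholder to replace with your own files.

from GPT_SoVITS.inference_cli import synthesize

synthesize(
    GPT_model_path="GPT_weights_v2/my_model.ckpt",       # hypothetical GPT checkpoint
    SoVITS_model_path="SoVITS_weights_v2/my_model.pth",  # hypothetical SoVITS checkpoint
    ref_audio_path="reference/ref.wav",                  # hypothetical 3-10 s reference clip
    ref_text_path="reference/ref.txt",                   # transcript of the reference clip
    ref_language="中文",
    target_text_path="reference/target.txt",             # text to synthesize
    target_language="中英混合",
    output_path="output",                                # output/output.wav is written here
)

The equivalent command-line call passes the same values through --gpt_model, --sovits_model, --ref_audio, --ref_text, --ref_language, --target_text, --target_language and --output_path.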
GPT_SoVITS/inference_gui.py ADDED
@@ -0,0 +1,316 @@
1
+ import os
2
+ import sys
3
+ from PyQt5.QtCore import QEvent
4
+ from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
5
+ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
6
+ import soundfile as sf
7
+
8
+ from tools.i18n.i18n import I18nAuto
9
+
10
+ i18n = I18nAuto()
11
+
12
+ from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
13
+
14
+
15
+ class GPTSoVITSGUI(QMainWindow):
16
+ GPT_Path = gpt_path
17
+ SoVITS_Path = sovits_path
18
+
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ self.setWindowTitle("GPT-SoVITS GUI")
23
+ self.setGeometry(800, 450, 950, 850)
24
+
25
+ self.setStyleSheet("""
26
+ QWidget {
27
+ background-color: #a3d3b1;
28
+ }
29
+
30
+ QTabWidget::pane {
31
+ background-color: #a3d3b1;
32
+ }
33
+
34
+ QTabWidget::tab-bar {
35
+ alignment: left;
36
+ }
37
+
38
+ QTabBar::tab {
39
+ background: #8da4bf;
40
+ color: #ffffff;
41
+ padding: 8px;
42
+ }
43
+
44
+ QTabBar::tab:selected {
45
+ background: #2a3f54;
46
+ }
47
+
48
+ QLabel {
49
+ color: #000000;
50
+ }
51
+
52
+ QPushButton {
53
+ background-color: #4CAF50;
54
+ color: white;
55
+ padding: 8px;
56
+ border: 1px solid #4CAF50;
57
+ border-radius: 4px;
58
+ }
59
+
60
+ QPushButton:hover {
61
+ background-color: #45a049;
62
+ border: 1px solid #45a049;
63
+ box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
64
+ }
65
+ """)
66
+
67
+ license_text = (
68
+ "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
69
+ "如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
70
+ )
71
+ license_label = QLabel(license_text)
72
+ license_label.setWordWrap(True)
73
+
74
+ self.GPT_model_label = QLabel("选择GPT模型:")
75
+ self.GPT_model_input = QLineEdit()
76
+ self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
77
+ self.GPT_model_input.setText(self.GPT_Path)
78
+ self.GPT_model_input.setReadOnly(True)
79
+ self.GPT_model_button = QPushButton("选择GPT模型文件")
80
+ self.GPT_model_button.clicked.connect(self.select_GPT_model)
81
+
82
+ self.SoVITS_model_label = QLabel("选择SoVITS模型:")
83
+ self.SoVITS_model_input = QLineEdit()
84
+ self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
85
+ self.SoVITS_model_input.setText(self.SoVITS_Path)
86
+ self.SoVITS_model_input.setReadOnly(True)
87
+ self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
88
+ self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
89
+
90
+ self.ref_audio_label = QLabel("上传参考音频:")
91
+ self.ref_audio_input = QLineEdit()
92
+ self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
93
+ self.ref_audio_input.setReadOnly(True)
94
+ self.ref_audio_button = QPushButton("选择音频文件")
95
+ self.ref_audio_button.clicked.connect(self.select_ref_audio)
96
+
97
+ self.ref_text_label = QLabel("参考音频文本:")
98
+ self.ref_text_input = QLineEdit()
99
+ self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
100
+ self.ref_text_button = QPushButton("上传文本")
101
+ self.ref_text_button.clicked.connect(self.upload_ref_text)
102
+
103
+ self.ref_language_label = QLabel("参考音频语言:")
104
+ self.ref_language_combobox = QComboBox()
105
+ self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
106
+ self.ref_language_combobox.setCurrentText("多语种混合")
107
+
108
+ self.target_text_label = QLabel("合成目标文本:")
109
+ self.target_text_input = QLineEdit()
110
+ self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
111
+ self.target_text_button = QPushButton("上传文本")
112
+ self.target_text_button.clicked.connect(self.upload_target_text)
113
+
114
+ self.target_language_label = QLabel("合成音频语言:")
115
+ self.target_language_combobox = QComboBox()
116
+ self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
117
+ self.target_language_combobox.setCurrentText("多语种混合")
118
+
119
+ self.output_label = QLabel("输出音频路径:")
120
+ self.output_input = QLineEdit()
121
+ self.output_input.setPlaceholderText("拖拽或选择文件")
122
+ self.output_input.setReadOnly(True)
123
+ self.output_button = QPushButton("选择文件夹")
124
+ self.output_button.clicked.connect(self.select_output_path)
125
+
126
+ self.output_text = QTextEdit()
127
+ self.output_text.setReadOnly(True)
128
+
129
+ self.add_drag_drop_events(
130
+ [
131
+ self.GPT_model_input,
132
+ self.SoVITS_model_input,
133
+ self.ref_audio_input,
134
+ self.ref_text_input,
135
+ self.target_text_input,
136
+ self.output_input,
137
+ ]
138
+ )
139
+
140
+ self.synthesize_button = QPushButton("合成")
141
+ self.synthesize_button.clicked.connect(self.synthesize)
142
+
143
+ self.clear_output_button = QPushButton("清空输出")
144
+ self.clear_output_button.clicked.connect(self.clear_output)
145
+
146
+ self.status_bar = QStatusBar()
147
+
148
+ main_layout = QVBoxLayout()
149
+
150
+ input_layout = QGridLayout()  # no parent here; the layout is attached via main_layout.addLayout() below
151
+ input_layout.setSpacing(10)
152
+
153
+ input_layout.addWidget(license_label, 0, 0, 1, 3)
154
+
155
+ input_layout.addWidget(self.GPT_model_label, 1, 0)
156
+ input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
157
+ input_layout.addWidget(self.GPT_model_button, 2, 2)
158
+
159
+ input_layout.addWidget(self.SoVITS_model_label, 3, 0)
160
+ input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
161
+ input_layout.addWidget(self.SoVITS_model_button, 4, 2)
162
+
163
+ input_layout.addWidget(self.ref_audio_label, 5, 0)
164
+ input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
165
+ input_layout.addWidget(self.ref_audio_button, 6, 2)
166
+
167
+ input_layout.addWidget(self.ref_language_label, 7, 0)
168
+ input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
169
+ input_layout.addWidget(self.ref_text_label, 9, 0)
170
+ input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
171
+ input_layout.addWidget(self.ref_text_button, 10, 2)
172
+
173
+ input_layout.addWidget(self.target_language_label, 11, 0)
174
+ input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
175
+ input_layout.addWidget(self.target_text_label, 13, 0)
176
+ input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
177
+ input_layout.addWidget(self.target_text_button, 14, 2)
178
+
179
+ input_layout.addWidget(self.output_label, 15, 0)
180
+ input_layout.addWidget(self.output_input, 16, 0, 1, 2)
181
+ input_layout.addWidget(self.output_button, 16, 2)
182
+
183
+ main_layout.addLayout(input_layout)
184
+
185
+ output_layout = QVBoxLayout()
186
+ output_layout.addWidget(self.output_text)
187
+ main_layout.addLayout(output_layout)
188
+
189
+ main_layout.addWidget(self.synthesize_button)
190
+
191
+ main_layout.addWidget(self.clear_output_button)
192
+
193
+ main_layout.addWidget(self.status_bar)
194
+
195
+ self.central_widget = QWidget()
196
+ self.central_widget.setLayout(main_layout)
197
+ self.setCentralWidget(self.central_widget)
198
+
199
+ def dragEnterEvent(self, event):
200
+ if event.mimeData().hasUrls():
201
+ event.acceptProposedAction()
202
+
203
+ def dropEvent(self, event):
204
+ if event.mimeData().hasUrls():
205
+ file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
206
+ if len(file_paths) == 1:
207
+ self.update_ref_audio(file_paths[0])
208
+ else:
209
+ self.update_ref_audio(", ".join(file_paths))
210
+
211
+ def add_drag_drop_events(self, widgets):
212
+ for widget in widgets:
213
+ widget.setAcceptDrops(True)
214
+ widget.installEventFilter(self)
215
+
216
+ def eventFilter(self, obj, event):
217
+ if event.type() in (QEvent.DragEnter, QEvent.Drop):
218
+ mime_data = event.mimeData()
219
+ if mime_data.hasUrls():
220
+ event.acceptProposedAction()
221
+
222
+ return super().eventFilter(obj, event)
223
+
224
+ def select_GPT_model(self):
225
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
226
+ if file_path:
227
+ self.GPT_model_input.setText(file_path)
228
+
229
+ def select_SoVITS_model(self):
230
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
231
+ if file_path:
232
+ self.SoVITS_model_input.setText(file_path)
233
+
234
+ def select_ref_audio(self):
235
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
236
+ if file_path:
237
+ self.update_ref_audio(file_path)
238
+
239
+ def upload_ref_text(self):
240
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
241
+ if file_path:
242
+ with open(file_path, "r", encoding="utf-8") as file:
243
+ content = file.read()
244
+ self.ref_text_input.setText(content)
245
+
246
+ def upload_target_text(self):
247
+ file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
248
+ if file_path:
249
+ with open(file_path, "r", encoding="utf-8") as file:
250
+ content = file.read()
251
+ self.target_text_input.setText(content)
252
+
253
+ def select_output_path(self):
254
+ options = QFileDialog.Options()
255
+ options |= QFileDialog.DontUseNativeDialog
256
+ options |= QFileDialog.ShowDirsOnly
257
+
258
+ folder_dialog = QFileDialog()
259
+ folder_dialog.setOptions(options)
260
+ folder_dialog.setFileMode(QFileDialog.Directory)
261
+
262
+ if folder_dialog.exec_():
263
+ folder_path = folder_dialog.selectedFiles()[0]
264
+ self.output_input.setText(folder_path)
265
+
266
+ def update_ref_audio(self, file_path):
267
+ self.ref_audio_input.setText(file_path)
268
+
269
+ def clear_output(self):
270
+ self.output_text.clear()
271
+
272
+ def synthesize(self):
273
+ GPT_model_path = self.GPT_model_input.text()
274
+ SoVITS_model_path = self.SoVITS_model_input.text()
275
+ ref_audio_path = self.ref_audio_input.text()
276
+ language_combobox = self.ref_language_combobox.currentText()
277
+ language_combobox = i18n(language_combobox)
278
+ ref_text = self.ref_text_input.text()
279
+ target_language_combobox = self.target_language_combobox.currentText()
280
+ target_language_combobox = i18n(target_language_combobox)
281
+ target_text = self.target_text_input.text()
282
+ output_path = self.output_input.text()
283
+
284
+ if GPT_model_path != self.GPT_Path:
285
+ change_gpt_weights(gpt_path=GPT_model_path)
286
+ self.GPT_Path = GPT_model_path
287
+ if SoVITS_model_path != self.SoVITS_Path:
288
+ change_sovits_weights(sovits_path=SoVITS_model_path)
289
+ self.SoVITS_Path = SoVITS_model_path
290
+
291
+ synthesis_result = get_tts_wav(
292
+ ref_wav_path=ref_audio_path,
293
+ prompt_text=ref_text,
294
+ prompt_language=language_combobox,
295
+ text=target_text,
296
+ text_language=target_language_combobox,
297
+ )
298
+
299
+ result_list = list(synthesis_result)
300
+
301
+ if result_list:
302
+ last_sampling_rate, last_audio_data = result_list[-1]
303
+ output_wav_path = os.path.join(output_path, "output.wav")
304
+ sf.write(output_wav_path, last_audio_data, last_sampling_rate)
305
+
306
+ result = "Audio saved to " + output_wav_path
307
+
308
+ self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
309
+ self.output_text.append("处理结果:\n" + result)
310
+
311
+
312
+ if __name__ == "__main__":
313
+ app = QApplication(sys.argv)
314
+ mainWin = GPTSoVITSGUI()
315
+ mainWin.show()
316
+ sys.exit(app.exec_())
GPT_SoVITS/inference_webui.py ADDED
@@ -0,0 +1,1280 @@
1
+ """
2
+ 按中英混合识别
3
+ 按日英混合识别
4
+ 多语种启动切分识别语种
5
+ 全部按中文识别
6
+ 全部按英文识别
7
+ 全部按日文识别
8
+ """
9
+
10
+ import logging
11
+ import traceback
12
+ import warnings
13
+
14
+ import torchaudio
15
+
16
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
17
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
18
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
19
+ logging.getLogger("httpx").setLevel(logging.ERROR)
20
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
21
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
22
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
23
+ logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
24
+ warnings.simplefilter(action="ignore", category=FutureWarning)
25
+
26
+ import json
27
+ import os
28
+ import re
29
+ import sys
30
+
31
+ import torch
32
+ from text.LangSegmenter import LangSegmenter
33
+
34
+ try:
35
+ import gradio.analytics as analytics
36
+
37
+ analytics.version_check = lambda: None
38
+ except:
39
+ ...
40
+ version = model_version = os.environ.get("version", "v2")
41
+ path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
42
+ path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
43
+ is_exist_s2gv3 = os.path.exists(path_sovits_v3)
44
+ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
45
+ pretrained_sovits_name = [
46
+ "GPT_SoVITS/pretrained_models/s2G488k.pth",
47
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
48
+ "GPT_SoVITS/pretrained_models/s2Gv3.pth",
49
+ "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
50
+ ]
51
+ pretrained_gpt_name = [
52
+ "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
53
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
54
+ "GPT_SoVITS/pretrained_models/s1v3.ckpt",
55
+ "GPT_SoVITS/pretrained_models/s1v3.ckpt",
56
+ ]
57
+
58
+
59
+ _ = [[], []]
60
+ for i in range(4):
61
+ if os.path.exists(pretrained_gpt_name[i]):
62
+ _[0].append(pretrained_gpt_name[i])
63
+ if os.path.exists(pretrained_sovits_name[i]):
64
+ _[-1].append(pretrained_sovits_name[i])
65
+ pretrained_gpt_name, pretrained_sovits_name = _
66
+
67
+
68
+ if os.path.exists("./weight.json"):
69
+ pass
70
+ else:
71
+ with open("./weight.json", "w", encoding="utf-8") as file:
72
+ json.dump({"GPT": {}, "SoVITS": {}}, file)
73
+
74
+ with open("./weight.json", "r", encoding="utf-8") as file:
75
+ weight_data = file.read()
76
+ weight_data = json.loads(weight_data)
77
+ gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
78
+ sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
79
+ if isinstance(gpt_path, list):
80
+ gpt_path = gpt_path[0]
81
+ if isinstance(sovits_path, list):
82
+ sovits_path = sovits_path[0]
83
+
84
+ # gpt_path = os.environ.get(
85
+ # "gpt_path", pretrained_gpt_name
86
+ # )
87
+ # sovits_path = os.environ.get("sovits_path", pretrained_sovits_name)
88
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", "GPT_SoVITS/pretrained_models/chinese-hubert-base")
89
+ bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
90
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
91
+ infer_ttswebui = int(infer_ttswebui)
92
+ is_share = os.environ.get("is_share", "False")
93
+ is_share = eval(is_share)
94
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
95
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
96
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
97
+ # is_half=False
98
+ punctuation = set(["!", "?", "…", ",", ".", "-", " "])
99
+ import gradio as gr
100
+ import librosa
101
+ import numpy as np
102
+ from feature_extractor import cnhubert
103
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
104
+
105
+ cnhubert.cnhubert_base_path = cnhubert_base_path
106
+
107
+ import random
108
+
109
+ from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
110
+
111
+
112
+ def set_seed(seed):
113
+ if seed == -1:
114
+ seed = random.randint(0, 1000000)
115
+ seed = int(seed)
116
+ random.seed(seed)
117
+ os.environ["PYTHONHASHSEED"] = str(seed)
118
+ np.random.seed(seed)
119
+ torch.manual_seed(seed)
120
+ torch.cuda.manual_seed(seed)
121
+
122
+
123
+ # set_seed(42)
124
+
125
+ from time import time as ttime
126
+
127
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
128
+ from peft import LoraConfig, get_peft_model
129
+ from text import cleaned_text_to_sequence
130
+ from text.cleaner import clean_text
131
+
132
+ from tools.i18n.i18n import I18nAuto, scan_language_list
133
+
134
+ language = os.environ.get("language", "Auto")
135
+ language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
136
+ i18n = I18nAuto(language=language)
137
+
138
+ # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
139
+
140
+ if torch.cuda.is_available():
141
+ device = "cuda"
142
+ else:
143
+ device = "cpu"
144
+
145
+ dict_language_v1 = {
146
+ i18n("中文"): "all_zh", # 全部按中文识别
147
+ i18n("英文"): "en", # 全部按英文识别#######不变
148
+ i18n("日文"): "all_ja", # 全部按日文识别
149
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
150
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
151
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
152
+ }
153
+ dict_language_v2 = {
154
+ i18n("中文"): "all_zh", # 全部按中文识别
155
+ i18n("英文"): "en", # 全部按英文识别#######不变
156
+ i18n("日文"): "all_ja", # 全部按日文识别
157
+ i18n("粤语"): "all_yue", # 全部按中文识别
158
+ i18n("韩文"): "all_ko", # 全部按韩文识别
159
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
160
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
161
+ i18n("粤英混合"): "yue", # 按粤英混合识别####不变
162
+ i18n("韩英混合"): "ko", # 按韩英混合识别####不变
163
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
164
+ i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
165
+ }
166
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
167
+
168
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
169
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
170
+ if is_half == True:
171
+ bert_model = bert_model.half().to(device)
172
+ else:
173
+ bert_model = bert_model.to(device)
174
+
175
+
176
+ def get_bert_feature(text, word2ph):
177
+ with torch.no_grad():
178
+ inputs = tokenizer(text, return_tensors="pt")
179
+ for i in inputs:
180
+ inputs[i] = inputs[i].to(device)
181
+ res = bert_model(**inputs, output_hidden_states=True)
182
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
183
+ assert len(word2ph) == len(text)
184
+ phone_level_feature = []
185
+ for i in range(len(word2ph)):
186
+ repeat_feature = res[i].repeat(word2ph[i], 1)
187
+ phone_level_feature.append(repeat_feature)
188
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
189
+ return phone_level_feature.T
190
+
191
+
192
+ class DictToAttrRecursive(dict):
193
+ def __init__(self, input_dict):
194
+ super().__init__(input_dict)
195
+ for key, value in input_dict.items():
196
+ if isinstance(value, dict):
197
+ value = DictToAttrRecursive(value)
198
+ self[key] = value
199
+ setattr(self, key, value)
200
+
201
+ def __getattr__(self, item):
202
+ try:
203
+ return self[item]
204
+ except KeyError:
205
+ raise AttributeError(f"Attribute {item} not found")
206
+
207
+ def __setattr__(self, key, value):
208
+ if isinstance(value, dict):
209
+ value = DictToAttrRecursive(value)
210
+ super(DictToAttrRecursive, self).__setitem__(key, value)
211
+ super().__setattr__(key, value)
212
+
213
+ def __delattr__(self, item):
214
+ try:
215
+ del self[item]
216
+ except KeyError:
217
+ raise AttributeError(f"Attribute {item} not found")
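
A small illustration of DictToAttrRecursive (assuming the class above is in scope): nested config dicts become attribute-accessible while dict-style access keeps working. The numbers below are made up for the example.

hps_demo = DictToAttrRecursive({"data": {"sampling_rate": 32000}, "train": {"segment_size": 20480}})
assert hps_demo.data.sampling_rate == 32000        # nested dicts are wrapped recursively
assert hps_demo["train"]["segment_size"] == 20480  # plain dict access still works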
218
+
219
+
220
+ ssl_model = cnhubert.get_model()
221
+ if is_half == True:
222
+ ssl_model = ssl_model.half().to(device)
223
+ else:
224
+ ssl_model = ssl_model.to(device)
225
+
226
+ resample_transform_dict = {}
227
+
228
+
229
+ def resample(audio_tensor, sr0, sr1):
230
+ global resample_transform_dict
231
+ key = "%s-%s" % (sr0, sr1)
232
+ if key not in resample_transform_dict:
233
+ resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
234
+ return resample_transform_dict[key](audio_tensor)
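
A short sketch of the cached resampler above (it assumes the module state, i.e. resample and device, is already loaded): the first call for a given (sr0, sr1) pair builds a torchaudio.transforms.Resample transform, later calls reuse it.

import torch

wav = torch.randn(1, 32000).to(device)        # one second of fake 32 kHz audio
wav_24k = resample(wav, 32000, 24000)         # builds and caches the 32000->24000 transform
wav_24k_again = resample(wav, 32000, 24000)   # hits the cache, no new transform is built
print(wav_24k.shape)                          # roughly (1, 24000)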
235
+
236
+
237
+ ###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
238
+ # symbol_version-model_version-if_lora_v3
239
+ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
240
+
241
+ v3v4set = {"v3", "v4"}
242
+ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
243
+ global vq_model, hps, version, model_version, dict_language, if_lora_v3
244
+ version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
245
+ print(sovits_path, version, model_version, if_lora_v3)
246
+ is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
247
+ if if_lora_v3 == True and is_exist == False:
248
+ info = (path_sovits_v3 if model_version == "v3" else path_sovits_v4) + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
249
+ gr.Warning(info)
250
+ raise FileExistsError(info)
251
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
252
+ if prompt_language is not None and text_language is not None:
253
+ if prompt_language in list(dict_language.keys()):
254
+ prompt_text_update, prompt_language_update = (
255
+ {"__type__": "update"},
256
+ {"__type__": "update", "value": prompt_language},
257
+ )
258
+ else:
259
+ prompt_text_update = {"__type__": "update", "value": ""}
260
+ prompt_language_update = {"__type__": "update", "value": i18n("中文")}
261
+ if text_language in list(dict_language.keys()):
262
+ text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
263
+ else:
264
+ text_update = {"__type__": "update", "value": ""}
265
+ text_language_update = {"__type__": "update", "value": i18n("中文")}
266
+ if model_version in v3v4set:
267
+ visible_sample_steps = True
268
+ visible_inp_refs = False
269
+ else:
270
+ visible_sample_steps = False
271
+ visible_inp_refs = True
272
+ yield (
273
+ {"__type__": "update", "choices": list(dict_language.keys())},
274
+ {"__type__": "update", "choices": list(dict_language.keys())},
275
+ prompt_text_update,
276
+ prompt_language_update,
277
+ text_update,
278
+ text_language_update,
279
+ {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
280
+ {"__type__": "update", "visible": visible_inp_refs},
281
+ {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
282
+ {"__type__": "update", "visible": True if model_version =="v3" else False},
283
+ {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
284
+ )
285
+
286
+ dict_s2 = load_sovits_new(sovits_path)
287
+ hps = dict_s2["config"]
288
+ hps = DictToAttrRecursive(hps)
289
+ hps.model.semantic_frame_rate = "25hz"
290
+ if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
291
+ hps.model.version = "v2" # v3model,v2sybomls
292
+ elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
293
+ hps.model.version = "v1"
294
+ else:
295
+ hps.model.version = "v2"
296
+ version = hps.model.version
297
+ # print("sovits版本:",hps.model.version)
298
+ if model_version not in v3v4set:
299
+ vq_model = SynthesizerTrn(
300
+ hps.data.filter_length // 2 + 1,
301
+ hps.train.segment_size // hps.data.hop_length,
302
+ n_speakers=hps.data.n_speakers,
303
+ **hps.model,
304
+ )
305
+ model_version = version
306
+ else:
307
+ vq_model = SynthesizerTrnV3(
308
+ hps.data.filter_length // 2 + 1,
309
+ hps.train.segment_size // hps.data.hop_length,
310
+ n_speakers=hps.data.n_speakers,
311
+ **hps.model,
312
+ )
313
+ if "pretrained" not in sovits_path:
314
+ try:
315
+ del vq_model.enc_q
316
+ except:
317
+ pass
318
+ if is_half == True:
319
+ vq_model = vq_model.half().to(device)
320
+ else:
321
+ vq_model = vq_model.to(device)
322
+ vq_model.eval()
323
+ if if_lora_v3 == False:
324
+ print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False))
325
+ else:
326
+ path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
327
+ print(
328
+ "loading sovits_%spretrained_G"%model_version,
329
+ vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False),
330
+ )
331
+ lora_rank = dict_s2["lora_rank"]
332
+ lora_config = LoraConfig(
333
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
334
+ r=lora_rank,
335
+ lora_alpha=lora_rank,
336
+ init_lora_weights=True,
337
+ )
338
+ vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
339
+ print("loading sovits_%s_lora%s" % (model_version,lora_rank))
340
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
341
+ vq_model.cfm = vq_model.cfm.merge_and_unload()
342
+ # torch.save(vq_model.state_dict(),"merge_win.pth")
343
+ vq_model.eval()
344
+
345
+ yield (
346
+ {"__type__": "update", "choices": list(dict_language.keys())},
347
+ {"__type__": "update", "choices": list(dict_language.keys())},
348
+ prompt_text_update,
349
+ prompt_language_update,
350
+ text_update,
351
+ text_language_update,
352
+ {"__type__": "update", "visible": visible_sample_steps, "value":32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
353
+ {"__type__": "update", "visible": visible_inp_refs},
354
+ {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
355
+ {"__type__": "update", "visible": True if model_version =="v3" else False},
356
+ {"__type__": "update", "value": i18n("合成语音"), "interactive": True},
357
+ )
358
+ with open("./weight.json") as f:
359
+ data = f.read()
360
+ data = json.loads(data)
361
+ data["SoVITS"][version] = sovits_path
362
+ with open("./weight.json", "w") as f:
363
+ f.write(json.dumps(data))
364
+
365
+
366
+ try:
367
+ next(change_sovits_weights(sovits_path))
368
+ except:
369
+ pass
370
+
371
+
372
+ def change_gpt_weights(gpt_path):
373
+ global hz, max_sec, t2s_model, config
374
+ hz = 50
375
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
376
+ config = dict_s1["config"]
377
+ max_sec = config["data"]["max_sec"]
378
+ t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
379
+ t2s_model.load_state_dict(dict_s1["weight"])
380
+ if is_half == True:
381
+ t2s_model = t2s_model.half()
382
+ t2s_model = t2s_model.to(device)
383
+ t2s_model.eval()
384
+ # total = sum([param.nelement() for param in t2s_model.parameters()])
385
+ # print("Number of parameter: %.2fM" % (total / 1e6))
386
+ with open("./weight.json") as f:
387
+ data = f.read()
388
+ data = json.loads(data)
389
+ data["GPT"][version] = gpt_path
390
+ with open("./weight.json", "w") as f:
391
+ f.write(json.dumps(data))
392
+
393
+
394
+ change_gpt_weights(gpt_path)
395
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
396
+ import torch
397
+
398
+ now_dir = os.getcwd()
399
+
400
+
401
+ def init_bigvgan():
402
+ global bigvgan_model,hifigan_model
403
+ from BigVGAN import bigvgan
404
+
405
+ bigvgan_model = bigvgan.BigVGAN.from_pretrained(
406
+ "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
407
+ use_cuda_kernel=False,
408
+ ) # if True, RuntimeError: Ninja is required to load C++ extensions
409
+ # remove weight norm in the model and set to eval mode
410
+ bigvgan_model.remove_weight_norm()
411
+ bigvgan_model = bigvgan_model.eval()
412
+ if hifigan_model:
413
+ hifigan_model=hifigan_model.cpu()
414
+ hifigan_model=None
415
+ try:torch.cuda.empty_cache()
416
+ except:pass
417
+ if is_half == True:
418
+ bigvgan_model = bigvgan_model.half().to(device)
419
+ else:
420
+ bigvgan_model = bigvgan_model.to(device)
421
+
422
+ def init_hifigan():
423
+ global hifigan_model,bigvgan_model
424
+ hifigan_model = Generator(
425
+ initial_channel=100,
426
+ resblock="1",
427
+ resblock_kernel_sizes=[3, 7, 11],
428
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
429
+ upsample_rates=[10, 6, 2, 2, 2],
430
+ upsample_initial_channel=512,
431
+ upsample_kernel_sizes=[20, 12, 4, 4, 4],
432
+ gin_channels=0, is_bias=True
433
+ )
434
+ hifigan_model.eval()
435
+ hifigan_model.remove_weight_norm()
436
+ state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
437
+ print("loading vocoder",hifigan_model.load_state_dict(state_dict_g))
438
+ if bigvgan_model:
439
+ bigvgan_model=bigvgan_model.cpu()
440
+ bigvgan_model=None
441
+ try:torch.cuda.empty_cache()
442
+ except:pass
443
+ if is_half == True:
444
+ hifigan_model = hifigan_model.half().to(device)
445
+ else:
446
+ hifigan_model = hifigan_model.to(device)
447
+
448
+ bigvgan_model = hifigan_model = None
448
+ if model_version == "v3":
449
+ init_bigvgan()
450
+ if model_version == "v4":
452
+ init_hifigan()
453
+
454
+
455
+ def get_spepc(hps, filename):
456
+ # audio = load_audio(filename, int(hps.data.sampling_rate))
457
+ audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
458
+ audio = torch.FloatTensor(audio)
459
+ maxx = audio.abs().max()
460
+ if maxx > 1:
461
+ audio /= min(2, maxx)
462
+ audio_norm = audio
463
+ audio_norm = audio_norm.unsqueeze(0)
464
+ spec = spectrogram_torch(
465
+ audio_norm,
466
+ hps.data.filter_length,
467
+ hps.data.sampling_rate,
468
+ hps.data.hop_length,
469
+ hps.data.win_length,
470
+ center=False,
471
+ )
472
+ return spec
473
+
474
+
475
+ def clean_text_inf(text, language, version):
476
+ language = language.replace("all_", "")
477
+ phones, word2ph, norm_text = clean_text(text, language, version)
478
+ phones = cleaned_text_to_sequence(phones, version)
479
+ return phones, word2ph, norm_text
480
+
481
+
482
+ dtype = torch.float16 if is_half == True else torch.float32
483
+
484
+
485
+ def get_bert_inf(phones, word2ph, norm_text, language):
486
+ language = language.replace("all_", "")
487
+ if language == "zh":
488
+ bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
489
+ else:
490
+ bert = torch.zeros(
491
+ (1024, len(phones)),
492
+ dtype=torch.float16 if is_half == True else torch.float32,
493
+ ).to(device)
494
+
495
+ return bert
496
+
497
+
498
+ splits = {
499
+ ",",
500
+ "。",
501
+ "?",
502
+ "!",
503
+ ",",
504
+ ".",
505
+ "?",
506
+ "!",
507
+ "~",
508
+ ":",
509
+ ":",
510
+ "—",
511
+ "…",
512
+ }
513
+
514
+
515
+ def get_first(text):
516
+ pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
517
+ text = re.split(pattern, text)[0].strip()
518
+ return text
519
+
520
+
521
+ from text import chinese
522
+
523
+
524
+ def get_phones_and_bert(text, language, version, final=False):
525
+ if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
526
+ formattext = text
527
+ while "  " in formattext:
529
+ formattext = formattext.replace("  ", " ")
529
+ if language == "all_zh":
530
+ if re.search(r"[A-Za-z]", formattext):
531
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
532
+ formattext = chinese.mix_text_normalize(formattext)
533
+ return get_phones_and_bert(formattext, "zh", version)
534
+ else:
535
+ phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
536
+ bert = get_bert_feature(norm_text, word2ph).to(device)
537
+ elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
538
+ formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
539
+ formattext = chinese.mix_text_normalize(formattext)
540
+ return get_phones_and_bert(formattext, "yue", version)
541
+ else:
542
+ phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
543
+ bert = torch.zeros(
544
+ (1024, len(phones)),
545
+ dtype=torch.float16 if is_half == True else torch.float32,
546
+ ).to(device)
547
+ elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
548
+ textlist = []
549
+ langlist = []
550
+ if language == "auto":
551
+ for tmp in LangSegmenter.getTexts(text):
552
+ langlist.append(tmp["lang"])
553
+ textlist.append(tmp["text"])
554
+ elif language == "auto_yue":
555
+ for tmp in LangSegmenter.getTexts(text):
556
+ if tmp["lang"] == "zh":
557
+ tmp["lang"] = "yue"
558
+ langlist.append(tmp["lang"])
559
+ textlist.append(tmp["text"])
560
+ else:
561
+ for tmp in LangSegmenter.getTexts(text):
562
+ if tmp["lang"] == "en":
563
+ langlist.append(tmp["lang"])
564
+ else:
565
+ # 因无法区别中日韩文汉字,以用户输入为准
566
+ langlist.append(language)
567
+ textlist.append(tmp["text"])
568
+ print(textlist)
569
+ print(langlist)
570
+ phones_list = []
571
+ bert_list = []
572
+ norm_text_list = []
573
+ for i in range(len(textlist)):
574
+ lang = langlist[i]
575
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
576
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
577
+ phones_list.append(phones)
578
+ norm_text_list.append(norm_text)
579
+ bert_list.append(bert)
580
+ bert = torch.cat(bert_list, dim=1)
581
+ phones = sum(phones_list, [])
582
+ norm_text = "".join(norm_text_list)
583
+
584
+ if not final and len(phones) < 6:
585
+ return get_phones_and_bert("." + text, language, version, final=True)
586
+
587
+ return phones, bert.to(dtype), norm_text
588
+
589
+
590
+ from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
591
+
592
+ spec_min = -12
593
+ spec_max = 2
594
+
595
+
596
+ def norm_spec(x):
597
+ return (x - spec_min) / (spec_max - spec_min) * 2 - 1
598
+
599
+
600
+ def denorm_spec(x):
601
+ return (x + 1) / 2 * (spec_max - spec_min) + spec_min
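
A quick sanity check (assuming the two helpers above are in scope): norm_spec maps the [spec_min, spec_max] = [-12, 2] mel range onto [-1, 1], and denorm_spec inverts it.

import torch

mel = torch.tensor([-12.0, -5.0, 2.0])
print(norm_spec(mel))                                   # tensor([-1., 0., 1.])
assert torch.allclose(denorm_spec(norm_spec(mel)), mel)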
602
+
603
+
604
+ mel_fn = lambda x: mel_spectrogram_torch(
605
+ x,
606
+ **{
607
+ "n_fft": 1024,
608
+ "win_size": 1024,
609
+ "hop_size": 256,
610
+ "num_mels": 100,
611
+ "sampling_rate": 24000,
612
+ "fmin": 0,
613
+ "fmax": None,
614
+ "center": False,
615
+ },
616
+ )
617
+ mel_fn_v4 = lambda x: mel_spectrogram_torch(
618
+ x,
619
+ **{
620
+ "n_fft": 1280,
621
+ "win_size": 1280,
622
+ "hop_size": 320,
623
+ "num_mels": 100,
624
+ "sampling_rate": 32000,
625
+ "fmin": 0,
626
+ "fmax": None,
627
+ "center": False,
628
+ },
629
+ )
630
+
631
+
632
+ def merge_short_text_in_array(texts, threshold):
633
+ if (len(texts)) < 2:
634
+ return texts
635
+ result = []
636
+ text = ""
637
+ for ele in texts:
638
+ text += ele
639
+ if len(text) >= threshold:
640
+ result.append(text)
641
+ text = ""
642
+ if len(text) > 0:
643
+ if len(result) == 0:
644
+ result.append(text)
645
+ else:
646
+ result[len(result) - 1] += text
647
+ return result
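
A worked example of merge_short_text_in_array (definitions above assumed in scope): pieces are accumulated until the running length reaches the threshold, and a short tail is folded into the last entry.

pieces = ["你好,", "我是小明。", "嗯。"]
print(merge_short_text_in_array(pieces, 5))
# ['你好,我是小明。嗯。'] — the first two pieces together reach the threshold,
# and the 2-character tail "嗯。" is appended onto the previous result.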
648
+
649
+
650
+ sr_model = None
651
+
652
+
653
+ def audio_sr(audio, sr):
654
+ global sr_model
655
+ if sr_model == None:
656
+ from tools.audio_sr import AP_BWE
657
+
658
+ try:
659
+ sr_model = AP_BWE(device, DictToAttrRecursive)
660
+ except FileNotFoundError:
661
+ gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
662
+ return audio.cpu().detach().numpy(), sr
663
+ return sr_model(audio, sr)
664
+
665
+
666
+ ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
667
+ # cache_tokens={}#暂未实现清理机制
668
+ cache = {}
669
+
670
+
671
+ def get_tts_wav(
672
+ ref_wav_path,
673
+ prompt_text,
674
+ prompt_language,
675
+ text,
676
+ text_language,
677
+ how_to_cut=i18n("不切"),
678
+ top_k=20,
679
+ top_p=0.6,
680
+ temperature=0.6,
681
+ ref_free=False,
682
+ speed=1,
683
+ if_freeze=False,
684
+ inp_refs=None,
685
+ sample_steps=8,
686
+ if_sr=False,
687
+ pause_second=0.3,
688
+ ):
689
+ global cache
690
+ if ref_wav_path:
691
+ pass
692
+ else:
693
+ gr.Warning(i18n("请上传参考音频"))
694
+ if text:
695
+ pass
696
+ else:
697
+ gr.Warning(i18n("请填入推理文本"))
698
+ t = []
699
+ if prompt_text is None or len(prompt_text) == 0:
700
+ ref_free = True
701
+ if model_version in v3v4set:
702
+ ref_free = False # s2v3暂不支持ref_free
703
+ else:
704
+ if_sr = False
705
+ t0 = ttime()
706
+ prompt_language = dict_language[prompt_language]
707
+ text_language = dict_language[text_language]
708
+
709
+ if not ref_free:
710
+ prompt_text = prompt_text.strip("\n")
711
+ if prompt_text[-1] not in splits:
712
+ prompt_text += "。" if prompt_language != "en" else "."
713
+ print(i18n("实际输入的参考文本:"), prompt_text)
714
+ text = text.strip("\n")
715
+ # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
716
+
717
+ print(i18n("实际输入的目标文本:"), text)
718
+ zero_wav = np.zeros(
719
+ int(hps.data.sampling_rate * pause_second),
720
+ dtype=np.float16 if is_half == True else np.float32,
721
+ )
722
+ zero_wav_torch = torch.from_numpy(zero_wav)
723
+ if is_half == True:
724
+ zero_wav_torch = zero_wav_torch.half().to(device)
725
+ else:
726
+ zero_wav_torch = zero_wav_torch.to(device)
727
+ if not ref_free:
728
+ with torch.no_grad():
729
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000)
730
+ if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
731
+ gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
732
+ raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
733
+ wav16k = torch.from_numpy(wav16k)
734
+ if is_half == True:
735
+ wav16k = wav16k.half().to(device)
736
+ else:
737
+ wav16k = wav16k.to(device)
738
+ wav16k = torch.cat([wav16k, zero_wav_torch])
739
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
740
+ codes = vq_model.extract_latent(ssl_content)
741
+ prompt_semantic = codes[0, 0]
742
+ prompt = prompt_semantic.unsqueeze(0).to(device)
743
+
744
+ t1 = ttime()
745
+ t.append(t1 - t0)
746
+
747
+ if how_to_cut == i18n("凑四句一切"):
748
+ text = cut1(text)
749
+ elif how_to_cut == i18n("凑50字一切"):
750
+ text = cut2(text)
751
+ elif how_to_cut == i18n("按中文句号。切"):
752
+ text = cut3(text)
753
+ elif how_to_cut == i18n("按英文句号.切"):
754
+ text = cut4(text)
755
+ elif how_to_cut == i18n("按标点符号切"):
756
+ text = cut5(text)
757
+ while "\n\n" in text:
758
+ text = text.replace("\n\n", "\n")
759
+ print(i18n("实际输入的目标文本(切句后):"), text)
760
+ texts = text.split("\n")
761
+ texts = process_text(texts)
762
+ texts = merge_short_text_in_array(texts, 5)
763
+ audio_opt = []
764
+ ###s2v3暂不支持ref_free
765
+ if not ref_free:
766
+ phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
767
+
768
+ for i_text, text in enumerate(texts):
769
+ # 解决输入目标文本的空行导致报错的问题
770
+ if len(text.strip()) == 0:
771
+ continue
772
+ if text[-1] not in splits:
773
+ text += "。" if text_language != "en" else "."
774
+ print(i18n("实际输入的目标文本(每句):"), text)
775
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
776
+ print(i18n("前端处理后的文本(每句):"), norm_text2)
777
+ if not ref_free:
778
+ bert = torch.cat([bert1, bert2], 1)
779
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
780
+ else:
781
+ bert = bert2
782
+ all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
783
+
784
+ bert = bert.to(device).unsqueeze(0)
785
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
786
+
787
+ t2 = ttime()
788
+ # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
789
+ # print(cache.keys(),if_freeze)
790
+ if i_text in cache and if_freeze == True:
791
+ pred_semantic = cache[i_text]
792
+ else:
793
+ with torch.no_grad():
794
+ pred_semantic, idx = t2s_model.model.infer_panel(
795
+ all_phoneme_ids,
796
+ all_phoneme_len,
797
+ None if ref_free else prompt,
798
+ bert,
799
+ # prompt_phone_len=ph_offset,
800
+ top_k=top_k,
801
+ top_p=top_p,
802
+ temperature=temperature,
803
+ early_stop_num=hz * max_sec,
804
+ )
805
+ pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
806
+ cache[i_text] = pred_semantic
807
+ t3 = ttime()
808
+ ###v3不存在以下逻辑和inp_refs
809
+ if model_version not in v3v4set:
810
+ refers = []
811
+ if inp_refs:
812
+ for path in inp_refs:
813
+ try:
814
+ refer = get_spepc(hps, path.name).to(dtype).to(device)
815
+ refers.append(refer)
816
+ except:
817
+ traceback.print_exc()
818
+ if len(refers) == 0:
819
+ refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
820
+ audio = vq_model.decode(
821
+ pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
822
+ )[0][0] # .cpu().detach().numpy()
823
+ else:
824
+ refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
825
+ phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
826
+ phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
827
+ # print(11111111, phoneme_ids0, phoneme_ids1)
828
+ fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
829
+ ref_audio, sr = torchaudio.load(ref_wav_path)
830
+ ref_audio = ref_audio.to(device).float()
831
+ if ref_audio.shape[0] == 2:
832
+ ref_audio = ref_audio.mean(0).unsqueeze(0)
833
+ tgt_sr = 24000 if model_version == "v3" else 32000
834
+ if sr != tgt_sr:
835
+ ref_audio = resample(ref_audio, sr, tgt_sr)
836
+ # print("ref_audio", ref_audio.abs().mean())
837
+ mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
838
+ mel2 = norm_spec(mel2)
839
+ T_min = min(mel2.shape[2], fea_ref.shape[2])
840
+ mel2 = mel2[:, :, :T_min]
841
+ fea_ref = fea_ref[:, :, :T_min]
842
+ Tref = 468 if model_version == "v3" else 500
843
+ Tchunk = 934 if model_version == "v3" else 1000
844
+ if T_min > Tref:
845
+ mel2 = mel2[:, :, -Tref:]
846
+ fea_ref = fea_ref[:, :, -Tref:]
847
+ T_min = Tref
848
+ chunk_len = Tchunk - T_min
849
+ mel2 = mel2.to(dtype)
850
+ fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
851
+ cfm_resss = []
852
+ idx = 0
853
+ while 1:
854
+ fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len]
855
+ if fea_todo_chunk.shape[-1] == 0:
856
+ break
857
+ idx += chunk_len
858
+ fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
859
+ cfm_res = vq_model.cfm.inference(
860
+ fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
861
+ )
862
+ cfm_res = cfm_res[:, :, mel2.shape[2] :]
863
+ mel2 = cfm_res[:, :, -T_min:]
864
+ fea_ref = fea_todo_chunk[:, :, -T_min:]
865
+ cfm_resss.append(cfm_res)
866
+ cfm_res = torch.cat(cfm_resss, 2)
867
+ cfm_res = denorm_spec(cfm_res)
868
+ if model_version == "v3":
869
+ if bigvgan_model == None:
870
+ init_bigvgan()
871
+ else:  # v4
872
+ if hifigan_model == None:
873
+ init_hifigan()
874
+ vocoder_model = bigvgan_model if model_version == "v3" else hifigan_model
875
+ with torch.inference_mode():
876
+ wav_gen = vocoder_model(cfm_res)
877
+ audio = wav_gen[0][0] # .cpu().detach().numpy()
878
+ max_audio = torch.abs(audio).max() # 简单防止16bit爆音
879
+ if max_audio > 1:
880
+ audio = audio / max_audio
881
+ audio_opt.append(audio)
882
+ audio_opt.append(zero_wav_torch) # zero_wav
883
+ t4 = ttime()
884
+ t.extend([t2 - t1, t3 - t2, t4 - t3])
885
+ t1 = ttime()
886
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
887
+ audio_opt = torch.cat(audio_opt, 0) # np.concatenate
888
+ if model_version in {"v1", "v2"}: opt_sr = 32000
888
+ elif model_version == "v3": opt_sr = 24000
889
+ else: opt_sr = 48000  # v4
891
+ if if_sr == True and opt_sr == 24000:
892
+ print(i18n("音频超分中"))
893
+ audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
894
+ max_audio = np.abs(audio_opt).max()
895
+ if max_audio > 1:
896
+ audio_opt /= max_audio
897
+ else:
898
+ audio_opt = audio_opt.cpu().detach().numpy()
899
+ yield opt_sr, (audio_opt * 32767).astype(np.int16)
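
A consumption sketch (not part of the commit): get_tts_wav is a generator that yields (sample_rate, int16_audio) tuples, so a caller drains it and keeps the last chunk, as inference_cli.py does. Paths and texts below are placeholders, and the GPT/SoVITS models must already be loaded as done at module import.

import soundfile as sf

chunks = list(get_tts_wav(
    ref_wav_path="reference/ref.wav",   # hypothetical 3-10 s reference clip
    prompt_text="参考音频的文本",
    prompt_language=i18n("中文"),
    text="这是需要合成的目标文本。",
    text_language=i18n("中文"),
))
sr, audio_int16 = chunks[-1]
sf.write("output.wav", audio_int16, sr)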
900
+
901
+
902
+ def split(todo_text):
903
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
904
+ if todo_text[-1] not in splits:
905
+ todo_text += "。"
906
+ i_split_head = i_split_tail = 0
907
+ len_text = len(todo_text)
908
+ todo_texts = []
909
+ while 1:
910
+ if i_split_head >= len_text:
911
+ break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
912
+ if todo_text[i_split_head] in splits:
913
+ i_split_head += 1
914
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
915
+ i_split_tail = i_split_head
916
+ else:
917
+ i_split_head += 1
918
+ return todo_texts
919
+
920
+
921
+ def cut1(inp):
922
+ inp = inp.strip("\n")
923
+ inps = split(inp)
924
+ split_idx = list(range(0, len(inps), 4))
925
+ split_idx[-1] = None
926
+ if len(split_idx) > 1:
927
+ opts = []
928
+ for idx in range(len(split_idx) - 1):
929
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
930
+ else:
931
+ opts = [inp]
932
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
933
+ return "\n".join(opts)
934
+
935
+
936
+ def cut2(inp):
937
+ inp = inp.strip("\n")
938
+ inps = split(inp)
939
+ if len(inps) < 2:
940
+ return inp
941
+ opts = []
942
+ summ = 0
943
+ tmp_str = ""
944
+ for i in range(len(inps)):
945
+ summ += len(inps[i])
946
+ tmp_str += inps[i]
947
+ if summ > 50:
948
+ summ = 0
949
+ opts.append(tmp_str)
950
+ tmp_str = ""
951
+ if tmp_str != "":
952
+ opts.append(tmp_str)
953
+ # print(opts)
954
+ if len(opts) > 1 and len(opts[-1]) < 50: ##如果最后一个太短了,和前一个合一起
955
+ opts[-2] = opts[-2] + opts[-1]
956
+ opts = opts[:-1]
957
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
958
+ return "\n".join(opts)
959
+
960
+
961
+ def cut3(inp):
962
+ inp = inp.strip("\n")
963
+ opts = ["%s" % item for item in inp.strip("。").split("。")]
964
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
965
+ return "\n".join(opts)
966
+
967
+
968
+ def cut4(inp):
969
+ inp = inp.strip("\n")
970
+ opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
971
+ opts = [item for item in opts if not set(item).issubset(punctuation)]
972
+ return "\n".join(opts)
973
+
974
+
975
+ # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
976
+ def cut5(inp):
977
+ inp = inp.strip("\n")
978
+ punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
979
+ mergeitems = []
980
+ items = []
981
+
982
+ for i, char in enumerate(inp):
983
+ if char in punds:
984
+ if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
985
+ items.append(char)
986
+ else:
987
+ items.append(char)
988
+ mergeitems.append("".join(items))
989
+ items = []
990
+ else:
991
+ items.append(char)
992
+
993
+ if items:
994
+ mergeitems.append("".join(items))
995
+
996
+ opt = [item for item in mergeitems if not set(item).issubset(punds)]
997
+ return "\n".join(opt)
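
A small example of cut5 (definitions above in scope): it cuts at punctuation but keeps a dot that sits between two digits, so decimal numbers survive the split.

print(cut5("版本3.14发布了。快来试试吧!"))
# 版本3.14发布了。
# 快来试试吧!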
998
+
999
+
1000
+ def custom_sort_key(s):
1001
+ # 使用正则表达式提取字符串中的数字部分和非数字部分
1002
+ parts = re.split(r"(\d+)", s)
1003
+ # 将数字部分转换为整数,非数字部分保持不变
1004
+ parts = [int(part) if part.isdigit() else part for part in parts]
1005
+ return parts
1006
+
1007
+
1008
+ def process_text(texts):
1009
+ _text = []
1010
+ if all(text in [None, " ", "\n", ""] for text in texts):
1011
+ raise ValueError(i18n("请输入有效文本"))
1012
+ for text in texts:
1013
+ if text in [None, " ", ""]:
1014
+ pass
1015
+ else:
1016
+ _text.append(text)
1017
+ return _text
1018
+
1019
+
1020
+ def change_choices():
1021
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
1022
+ return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
1023
+ "choices": sorted(GPT_names, key=custom_sort_key),
1024
+ "__type__": "update",
1025
+ }
1026
+
1027
+
1028
+ SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
1029
+ GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
1030
+ for path in SoVITS_weight_root + GPT_weight_root:
1031
+ os.makedirs(path, exist_ok=True)
1032
+
1033
+
1034
+ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
1035
+ SoVITS_names = [i for i in pretrained_sovits_name]
1036
+ for path in SoVITS_weight_root:
1037
+ for name in os.listdir(path):
1038
+ if name.endswith(".pth"):
1039
+ SoVITS_names.append("%s/%s" % (path, name))
1040
+ GPT_names = [i for i in pretrained_gpt_name]
1041
+ for path in GPT_weight_root:
1042
+ for name in os.listdir(path):
1043
+ if name.endswith(".ckpt"):
1044
+ GPT_names.append("%s/%s" % (path, name))
1045
+ return SoVITS_names, GPT_names
1046
+
1047
+
1048
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
1049
+
1050
+
1051
+ def html_center(text, label="p"):
1052
+ return f"""<div style="text-align: center; margin: 100; padding: 50;">
1053
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
1054
+ </div>"""
1055
+
1056
+
1057
+ def html_left(text, label="p"):
1058
+ return f"""<div style="text-align: left; margin: 0; padding: 0;">
1059
+ <{label} style="margin: 0; padding: 0;">{text}</{label}>
1060
+ </div>"""
1061
+
1062
+
1063
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
1064
+ gr.Markdown(
1065
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
1066
+ + "<br>"
1067
+ + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
1068
+ )
1069
+ with gr.Group():
1070
+ gr.Markdown(html_center(i18n("模型切换"), "h3"))
1071
+ with gr.Row():
1072
+ GPT_dropdown = gr.Dropdown(
1073
+ label=i18n("GPT模型列表"),
1074
+ choices=sorted(GPT_names, key=custom_sort_key),
1075
+ value=gpt_path,
1076
+ interactive=True,
1077
+ scale=14,
1078
+ )
1079
+ SoVITS_dropdown = gr.Dropdown(
1080
+ label=i18n("SoVITS模型列表"),
1081
+ choices=sorted(SoVITS_names, key=custom_sort_key),
1082
+ value=sovits_path,
1083
+ interactive=True,
1084
+ scale=14,
1085
+ )
1086
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14)
1087
+ refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
1088
+ gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
1089
+ with gr.Row():
1090
+ inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath", scale=13)
1091
+ with gr.Column(scale=13):
1092
+ ref_text_free = gr.Checkbox(
1093
+ label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")
1094
+ + i18n("v3暂不支持该模式,使用了会报错。"),
1095
+ value=False,
1096
+ interactive=True if model_version not in v3v4set else False,
1097
+ show_label=True,
1098
+ scale=1,
1099
+ )
1100
+ gr.Markdown(
1101
+ html_left(
1102
+ i18n("使用无参考文本模式时建议使用微调的GPT")
1103
+ + "<br>"
1104
+ + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
1105
+ )
1106
+ )
1107
+ prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=5, max_lines=5, scale=1)
1108
+ with gr.Column(scale=14):
1109
+ prompt_language = gr.Dropdown(
1110
+ label=i18n("参考音频的语种"),
1111
+ choices=list(dict_language.keys()),
1112
+ value=i18n("中文"),
1113
+ )
1114
+ inp_refs = (
1115
+ gr.File(
1116
+ label=i18n(
1117
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
1118
+ ),
1119
+ file_count="multiple",
1120
+ )
1121
+ if model_version not in v3v4set
1122
+ else gr.File(
1123
+ label=i18n(
1124
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
1125
+ ),
1126
+ file_count="multiple",
1127
+ visible=False,
1128
+ )
1129
+ )
1130
+ sample_steps = (
1131
+ gr.Radio(
1132
+ label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
1133
+ value=32 if model_version == "v3" else 8,
1134
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32, 64, 128],
1135
+ visible=True,
1136
+ )
1137
+ if model_version in v3v4set
1138
+ else gr.Radio(
1139
+ label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
1140
+ choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32, 64, 128],
1141
+ visible=False,
1142
+ value=32 if model_version == "v3" else 8,
1143
+ )
1144
+ )
1145
+ if_sr_Checkbox = gr.Checkbox(
1146
+ label=i18n("v3输出如果觉得闷可以试试开超分"),
1147
+ value=False,
1148
+ interactive=True,
1149
+ show_label=True,
1150
+ visible=False if model_version != "v3" else True,
1151
+ )
1152
+ gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
1153
+ with gr.Row():
1154
+ with gr.Column(scale=13):
1155
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
1156
+ with gr.Column(scale=7):
1157
+ text_language = gr.Dropdown(
1158
+ label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
1159
+ choices=list(dict_language.keys()),
1160
+ value=i18n("中文"),
1161
+ scale=1,
1162
+ )
1163
+ how_to_cut = gr.Dropdown(
1164
+ label=i18n("怎么切"),
1165
+ choices=[
1166
+ i18n("不切"),
1167
+ i18n("凑四句一切"),
1168
+ i18n("凑50字一切"),
1169
+ i18n("按中文句号。切"),
1170
+ i18n("按英文句号.切"),
1171
+ i18n("按标点符号切"),
1172
+ ],
1173
+ value=i18n("凑四句一切"),
1174
+ interactive=True,
1175
+ scale=1,
1176
+ )
1177
+ gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
1178
+ if_freeze = gr.Checkbox(
1179
+ label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"),
1180
+ value=False,
1181
+ interactive=True,
1182
+ show_label=True,
1183
+ scale=1,
1184
+ )
1185
+ with gr.Row():
1186
+ speed = gr.Slider(
1187
+ minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True, scale=1
1188
+ )
1189
+ pause_second_slider = gr.Slider(
1190
+ minimum=0.1,
1191
+ maximum=0.5,
1192
+ step=0.01,
1193
+ label=i18n("句间停顿秒数"),
1194
+ value=0.3,
1195
+ interactive=True,
1196
+ scale=1,
1197
+ )
1198
+ gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
1199
+ top_k = gr.Slider(
1200
+ minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True, scale=1
1201
+ )
1202
+ top_p = gr.Slider(
1203
+ minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True, scale=1
1204
+ )
1205
+ temperature = gr.Slider(
1206
+ minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True, scale=1
1207
+ )
1208
+ # with gr.Column():
1209
+ # gr.Markdown(value=i18n("手工调整音素。当音素框不为空时使用手工音素输入推理,无视目标文本框。"))
1210
+ # phoneme=gr.Textbox(label=i18n("音素框"), value="")
1211
+ # get_phoneme_button = gr.Button(i18n("目标文本转音素"), variant="primary")
1212
+ with gr.Row():
1213
+ inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg", scale=25)
1214
+ output = gr.Audio(label=i18n("输出的语音"), scale=14)
1215
+
1216
+ inference_button.click(
1217
+ get_tts_wav,
1218
+ [
1219
+ inp_ref,
1220
+ prompt_text,
1221
+ prompt_language,
1222
+ text,
1223
+ text_language,
1224
+ how_to_cut,
1225
+ top_k,
1226
+ top_p,
1227
+ temperature,
1228
+ ref_text_free,
1229
+ speed,
1230
+ if_freeze,
1231
+ inp_refs,
1232
+ sample_steps,
1233
+ if_sr_Checkbox,
1234
+ pause_second_slider,
1235
+ ],
1236
+ [output],
1237
+ )
1238
+ SoVITS_dropdown.change(
1239
+ change_sovits_weights,
1240
+ [SoVITS_dropdown, prompt_language, text_language],
1241
+ [
1242
+ prompt_language,
1243
+ text_language,
1244
+ prompt_text,
1245
+ prompt_language,
1246
+ text,
1247
+ text_language,
1248
+ sample_steps,
1249
+ inp_refs,
1250
+ ref_text_free,
1251
+ if_sr_Checkbox,
1252
+ inference_button,
1253
+ ],
1254
+ )
1255
+ GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
1256
+
1257
+ # gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
1258
+ # with gr.Row():
1259
+ # text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="")
1260
+ # button1 = gr.Button(i18n("凑四句一切"), variant="primary")
1261
+ # button2 = gr.Button(i18n("凑50字一切"), variant="primary")
1262
+ # button3 = gr.Button(i18n("按中文句号。切"), variant="primary")
1263
+ # button4 = gr.Button(i18n("按英文句号.切"), variant="primary")
1264
+ # button5 = gr.Button(i18n("按标点符号切"), variant="primary")
1265
+ # text_opt = gr.Textbox(label=i18n("切分后文本"), value="")
1266
+ # button1.click(cut1, [text_inp], [text_opt])
1267
+ # button2.click(cut2, [text_inp], [text_opt])
1268
+ # button3.click(cut3, [text_inp], [text_opt])
1269
+ # button4.click(cut4, [text_inp], [text_opt])
1270
+ # button5.click(cut5, [text_inp], [text_opt])
1271
+ # gr.Markdown(html_center(i18n("后续将支持转音素、手工修改音素、语音合成分步执行。")))
1272
+
1273
+ if __name__ == "__main__":
1274
+ app.queue().launch( # concurrency_count=511, max_size=1022
1275
+ server_name="0.0.0.0",
1276
+ inbrowser=True,
1277
+ share=True,
1278
+ server_port=infer_ttswebui,
1279
+ quiet=True,
1280
+ )
GPT_SoVITS/inference_webui_fast.py ADDED
@@ -0,0 +1,540 @@
1
+ """
2
+ 按中英混合识别
3
+ 按日英混合识别
4
+ 多语种启动切分识别语种
5
+ 全部按中文识别
6
+ 全部按英文识别
7
+ 全部按日文识别
8
+ """
9
+
10
+ import json
11
+ import logging
12
+ import os
13
+ import random
14
+ import re
15
+ import sys
16
+
17
+ now_dir = os.getcwd()
18
+ sys.path.append(now_dir)
19
+ sys.path.append("%s/GPT_SoVITS" % (now_dir))
20
+
21
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
22
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
23
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
24
+ logging.getLogger("httpx").setLevel(logging.ERROR)
25
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
26
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
27
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
28
+ import torch
29
+
30
+ try:
31
+ import gradio.analytics as analytics
32
+
33
+ analytics.version_check = lambda: None
34
+ except:
35
+ ...
36
+
37
+
38
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
39
+ infer_ttswebui = int(infer_ttswebui)
40
+ is_share = os.environ.get("is_share", "False")
41
+ is_share = eval(is_share)
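+ # The UI port and share flag come from environment variables (typically set by the top-level launcher); eval() just turns the "True"/"False" string into a bool.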
42
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
43
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
44
+
45
+ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
46
+ gpt_path = os.environ.get("gpt_path", None)
47
+ sovits_path = os.environ.get("sovits_path", None)
48
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
49
+ bert_path = os.environ.get("bert_path", None)
50
+ version = model_version = os.environ.get("version", "v2")
51
+
52
+ import gradio as gr
53
+ from TTS_infer_pack.text_segmentation_method import get_method
54
+ from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
55
+
56
+ from tools.i18n.i18n import I18nAuto, scan_language_list
57
+
58
+ language = os.environ.get("language", "Auto")
59
+ language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
60
+ i18n = I18nAuto(language=language)
61
+
62
+
63
+ # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
64
+
65
+ if torch.cuda.is_available():
66
+ device = "cuda"
67
+ # elif torch.backends.mps.is_available():
68
+ # device = "mps"
69
+ else:
70
+ device = "cpu"
71
+
72
+ # is_half = False
73
+ # device = "cpu"
74
+
75
+ dict_language_v1 = {
+ i18n("中文"): "all_zh",  # treat all text as Chinese
+ i18n("英文"): "en",  # treat all text as English (unchanged)
+ i18n("日文"): "all_ja",  # treat all text as Japanese
+ i18n("中英混合"): "zh",  # mixed Chinese/English recognition (unchanged)
+ i18n("日英混合"): "ja",  # mixed Japanese/English recognition (unchanged)
+ i18n("多语种混合"): "auto",  # multilingual: split and detect the language of each segment
+ }
+ dict_language_v2 = {
+ i18n("中文"): "all_zh",  # treat all text as Chinese
+ i18n("英文"): "en",  # treat all text as English (unchanged)
+ i18n("日文"): "all_ja",  # treat all text as Japanese
+ i18n("粤语"): "all_yue",  # treat all text as Cantonese
+ i18n("韩文"): "all_ko",  # treat all text as Korean
+ i18n("中英混合"): "zh",  # mixed Chinese/English recognition (unchanged)
+ i18n("日英混合"): "ja",  # mixed Japanese/English recognition (unchanged)
+ i18n("粤英混合"): "yue",  # mixed Cantonese/English recognition (unchanged)
+ i18n("韩英混合"): "ko",  # mixed Korean/English recognition (unchanged)
+ i18n("多语种混合"): "auto",  # multilingual: split and detect the language of each segment
+ i18n("多语种混合(粤语)"): "auto_yue",  # multilingual with Cantonese: split and detect per segment
+ }
96
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
97
+
98
+ cut_method = {
99
+ i18n("不切"): "cut0",
100
+ i18n("凑四句一切"): "cut1",
101
+ i18n("凑50字一切"): "cut2",
102
+ i18n("按中文句号。切"): "cut3",
103
+ i18n("按英文句号.切"): "cut4",
104
+ i18n("按标点符号切"): "cut5",
105
+ }
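+ # cut_method maps the localized cut-mode labels shown in the UI to the text_split_method identifiers used by TTS_infer_pack.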
106
+
107
+ tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
108
+ tts_config.device = device
109
+ tts_config.is_half = is_half
110
+ tts_config.version = version
111
+ if gpt_path is not None:
112
+ tts_config.t2s_weights_path = gpt_path
113
+ if sovits_path is not None:
114
+ tts_config.vits_weights_path = sovits_path
115
+ if cnhubert_base_path is not None:
116
+ tts_config.cnhuhbert_base_path = cnhubert_base_path
117
+ if bert_path is not None:
118
+ tts_config.bert_base_path = bert_path
119
+
120
+ print(tts_config)
121
+ tts_pipeline = TTS(tts_config)
122
+ gpt_path = tts_config.t2s_weights_path
123
+ sovits_path = tts_config.vits_weights_path
124
+ version = tts_config.version
125
+
126
+
127
+ def inference(
128
+ text,
129
+ text_lang,
130
+ ref_audio_path,
131
+ aux_ref_audio_paths,
132
+ prompt_text,
133
+ prompt_lang,
134
+ top_k,
135
+ top_p,
136
+ temperature,
137
+ text_split_method,
138
+ batch_size,
139
+ speed_factor,
140
+ ref_text_free,
141
+ split_bucket,
142
+ fragment_interval,
143
+ seed,
144
+ keep_random,
145
+ parallel_infer,
146
+ repetition_penalty,
147
+ sample_steps,
148
+ super_sampling,
149
+ ):
150
+ seed = -1 if keep_random else seed
151
+ actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
152
+ inputs = {
153
+ "text": text,
154
+ "text_lang": dict_language[text_lang],
155
+ "ref_audio_path": ref_audio_path,
156
+ "aux_ref_audio_paths": [item.name for item in aux_ref_audio_paths] if aux_ref_audio_paths is not None else [],
157
+ "prompt_text": prompt_text if not ref_text_free else "",
158
+ "prompt_lang": dict_language[prompt_lang],
159
+ "top_k": top_k,
160
+ "top_p": top_p,
161
+ "temperature": temperature,
162
+ "text_split_method": cut_method[text_split_method],
163
+ "batch_size": int(batch_size),
164
+ "speed_factor": float(speed_factor),
165
+ "split_bucket": split_bucket,
166
+ "return_fragment": False,
167
+ "fragment_interval": fragment_interval,
168
+ "seed": actual_seed,
169
+ "parallel_infer": parallel_infer,
170
+ "repetition_penalty": repetition_penalty,
171
+ "sample_steps": int(sample_steps),
172
+ "super_sampling": super_sampling,
173
+ }
174
+ try:
175
+ for item in tts_pipeline.run(inputs):
176
+ yield item, actual_seed
177
+ except NO_PROMPT_ERROR:
178
+ gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
179
+
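+ # inference() streams each result produced by tts_pipeline.run together with the seed that was actually used, so the UI can display and reuse it.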
180
+
181
+ def custom_sort_key(s):
182
+ # Split the string into numeric and non-numeric parts with a regex (raw string avoids the invalid "\d" escape warning)
+ parts = re.split(r"(\d+)", s)
+ # Convert the numeric parts to integers so they sort numerically; keep the rest as strings
+ parts = [int(part) if part.isdigit() else part for part in parts]
186
+ return parts
187
+
188
+
189
+ def change_choices():
190
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
191
+ return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
192
+ "choices": sorted(GPT_names, key=custom_sort_key),
193
+ "__type__": "update",
194
+ }
195
+
196
+
197
+ path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
198
+ pretrained_sovits_name = [
199
+ "GPT_SoVITS/pretrained_models/s2G488k.pth",
200
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
201
+ path_sovits_v3,
202
+ ]
203
+ pretrained_gpt_name = [
204
+ "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
205
+ "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
206
+ "GPT_SoVITS/pretrained_models/s1v3.ckpt",
207
+ ]
208
+
209
+ _ = [[], []]
210
+ for i in range(3):
211
+ if os.path.exists(pretrained_gpt_name[i]):
212
+ _[0].append(pretrained_gpt_name[i])
213
+ if os.path.exists(pretrained_sovits_name[i]):
214
+ _[-1].append(pretrained_sovits_name[i])
215
+ pretrained_gpt_name, pretrained_sovits_name = _
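+ # Keep only the pretrained checkpoints that actually exist on disk so the dropdowns do not offer missing files.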
216
+
217
+
218
+ if os.path.exists("./weight.json"):
219
+ pass
220
+ else:
221
+ with open("./weight.json", "w", encoding="utf-8") as file:
222
+ json.dump({"GPT": {}, "SoVITS": {}}, file)
223
+
224
+ with open("./weight.json", "r", encoding="utf-8") as file:
225
+ weight_data = file.read()
226
+ weight_data = json.loads(weight_data)
227
+ gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
228
+ sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
229
+ if isinstance(gpt_path, list):
230
+ gpt_path = gpt_path[0]
231
+ if isinstance(sovits_path, list):
232
+ sovits_path = sovits_path[0]
233
+
234
+
235
+ SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
236
+ GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
237
+ for path in SoVITS_weight_root + GPT_weight_root:
238
+ os.makedirs(path, exist_ok=True)
239
+
240
+
241
+ def get_weights_names(GPT_weight_root, SoVITS_weight_root):
242
+ SoVITS_names = [i for i in pretrained_sovits_name]
243
+ for path in SoVITS_weight_root:
244
+ for name in os.listdir(path):
245
+ if name.endswith(".pth"):
246
+ SoVITS_names.append("%s/%s" % (path, name))
247
+ GPT_names = [i for i in pretrained_gpt_name]
248
+ for path in GPT_weight_root:
249
+ for name in os.listdir(path):
250
+ if name.endswith(".ckpt"):
251
+ GPT_names.append("%s/%s" % (path, name))
252
+ return SoVITS_names, GPT_names
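+ # Available weights = existing pretrained checkpoints plus any fine-tuned .pth/.ckpt files found under the *_weights* folders.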
253
+
254
+
255
+ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
256
+
257
+
258
+ from process_ckpt import get_sovits_version_from_path_fast
259
+
260
+
261
+ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
262
+ global version, model_version, dict_language, if_lora_v3
263
+ version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
264
+ # print(sovits_path,version, model_version, if_lora_v3)
265
+ if if_lora_v3 and not os.path.exists(path_sovits_v3):
266
+ info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
267
+ gr.Warning(info)
268
+ raise FileExistsError(info)
269
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
270
+ if prompt_language is not None and text_language is not None:
271
+ if prompt_language in list(dict_language.keys()):
272
+ prompt_text_update, prompt_language_update = (
273
+ {"__type__": "update"},
274
+ {"__type__": "update", "value": prompt_language},
275
+ )
276
+ else:
277
+ prompt_text_update = {"__type__": "update", "value": ""}
278
+ prompt_language_update = {"__type__": "update", "value": i18n("中文")}
279
+ if text_language in list(dict_language.keys()):
280
+ text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
281
+ else:
282
+ text_update = {"__type__": "update", "value": ""}
283
+ text_language_update = {"__type__": "update", "value": i18n("中文")}
284
+ if model_version == "v3":
285
+ visible_sample_steps = True
286
+ visible_inp_refs = False
287
+ else:
288
+ visible_sample_steps = False
289
+ visible_inp_refs = True
290
+ # prompt_language,text_language,prompt_text,prompt_language,text,text_language,inp_refs,ref_text_free,
291
+ yield (
292
+ {"__type__": "update", "choices": list(dict_language.keys())},
293
+ {"__type__": "update", "choices": list(dict_language.keys())},
294
+ prompt_text_update,
295
+ prompt_language_update,
296
+ text_update,
297
+ text_language_update,
298
+ {"__type__": "update", "interactive": visible_sample_steps, "value": 32},
299
+ {"__type__": "update", "visible": visible_inp_refs},
300
+ {"__type__": "update", "interactive": True if model_version != "v3" else False},
301
+ {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
302
+ )
303
+
304
+ tts_pipeline.init_vits_weights(sovits_path)
305
+ yield (
306
+ {"__type__": "update", "choices": list(dict_language.keys())},
307
+ {"__type__": "update", "choices": list(dict_language.keys())},
308
+ prompt_text_update,
309
+ prompt_language_update,
310
+ text_update,
311
+ text_language_update,
312
+ {"__type__": "update", "interactive": visible_sample_steps, "value": 32},
313
+ {"__type__": "update", "visible": visible_inp_refs},
314
+ {"__type__": "update", "interactive": True if model_version != "v3" else False},
315
+ {"__type__": "update", "value": i18n("合成语音"), "interactive": True},
316
+ )
317
+ with open("./weight.json") as f:
318
+ data = f.read()
319
+ data = json.loads(data)
320
+ data["SoVITS"][version] = sovits_path
321
+ with open("./weight.json", "w") as f:
322
+ f.write(json.dumps(data))
323
+
324
+
325
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
326
+ gr.Markdown(
327
+ value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
328
+ + "<br>"
329
+ + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
330
+ )
331
+
332
+ with gr.Column():
333
+ # with gr.Group():
334
+ gr.Markdown(value=i18n("模型切换"))
335
+ with gr.Row():
336
+ GPT_dropdown = gr.Dropdown(
337
+ label=i18n("GPT模型列表"),
338
+ choices=sorted(GPT_names, key=custom_sort_key),
339
+ value=gpt_path,
340
+ interactive=True,
341
+ )
342
+ SoVITS_dropdown = gr.Dropdown(
343
+ label=i18n("SoVITS模型列表"),
344
+ choices=sorted(SoVITS_names, key=custom_sort_key),
345
+ value=sovits_path,
346
+ interactive=True,
347
+ )
348
+ refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
349
+ refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
350
+
351
+ with gr.Row():
352
+ with gr.Column():
353
+ gr.Markdown(value=i18n("*请上传并填写参考信息"))
354
+ with gr.Row():
355
+ inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
356
+ inp_refs = gr.File(
357
+ label=i18n("辅参考音频(可选多个,或不选)"),
358
+ file_count="multiple",
359
+ visible=True if model_version != "v3" else False,
360
+ )
361
+ prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
362
+ with gr.Row():
363
+ prompt_language = gr.Dropdown(
364
+ label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
365
+ )
366
+ with gr.Column():
367
+ ref_text_free = gr.Checkbox(
368
+ label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
369
+ value=False,
370
+ interactive=True if model_version != "v3" else False,
371
+ show_label=True,
372
+ )
373
+ gr.Markdown(
374
+ i18n("使用无参考文本模式时建议使用微调的GPT")
375
+ + "<br>"
376
+ + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
377
+ )
378
+
379
+ with gr.Column():
380
+ gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
381
+ text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
382
+ text_language = gr.Dropdown(
383
+ label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
384
+ )
385
+
386
+ with gr.Group():
387
+ gr.Markdown(value=i18n("推理设置"))
388
+ with gr.Row():
389
+ with gr.Column():
390
+ with gr.Row():
391
+ batch_size = gr.Slider(
392
+ minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
393
+ )
394
+ sample_steps = gr.Radio(
395
+ label=i18n("采样步数(仅对V3生效)"), value=32, choices=[4, 8, 16, 32], visible=True
396
+ )
397
+ with gr.Row():
398
+ fragment_interval = gr.Slider(
399
+ minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
400
+ )
401
+ speed_factor = gr.Slider(
402
+ minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
403
+ )
404
+ with gr.Row():
405
+ top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
406
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
407
+ with gr.Row():
408
+ temperature = gr.Slider(
409
+ minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
410
+ )
411
+ repetition_penalty = gr.Slider(
412
+ minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
413
+ )
414
+
415
+ with gr.Column():
416
+ with gr.Row():
417
+ how_to_cut = gr.Dropdown(
418
+ label=i18n("怎么切"),
419
+ choices=[
420
+ i18n("不切"),
421
+ i18n("凑四句一切"),
422
+ i18n("凑50字一切"),
423
+ i18n("按中文句号。切"),
424
+ i18n("按英文句号.切"),
425
+ i18n("按标点符号切"),
426
+ ],
427
+ value=i18n("凑四句一切"),
428
+ interactive=True,
429
+ scale=1,
430
+ )
431
+ super_sampling = gr.Checkbox(
432
+ label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
433
+ )
434
+
435
+ with gr.Row():
436
+ parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
437
+ split_bucket = gr.Checkbox(
438
+ label=i18n("数据分桶(并行推理时会降低一点计算量)"),
439
+ value=True,
440
+ interactive=True,
441
+ show_label=True,
442
+ )
443
+
444
+ with gr.Row():
445
+ seed = gr.Number(label=i18n("随机种子"), value=-1)
446
+ keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
447
+
448
+ output = gr.Audio(label=i18n("输出的语音"))
449
+ with gr.Row():
450
+ inference_button = gr.Button(i18n("合成语音"), variant="primary")
451
+ stop_infer = gr.Button(i18n("终止合成"), variant="primary")
452
+
453
+ inference_button.click(
454
+ inference,
455
+ [
456
+ text,
457
+ text_language,
458
+ inp_ref,
459
+ inp_refs,
460
+ prompt_text,
461
+ prompt_language,
462
+ top_k,
463
+ top_p,
464
+ temperature,
465
+ how_to_cut,
466
+ batch_size,
467
+ speed_factor,
468
+ ref_text_free,
469
+ split_bucket,
470
+ fragment_interval,
471
+ seed,
472
+ keep_random,
473
+ parallel_infer,
474
+ repetition_penalty,
475
+ sample_steps,
476
+ super_sampling,
477
+ ],
478
+ [output, seed],
479
+ )
480
+ stop_infer.click(tts_pipeline.stop, [], [])
481
+ SoVITS_dropdown.change(
482
+ change_sovits_weights,
483
+ [SoVITS_dropdown, prompt_language, text_language],
484
+ [
485
+ prompt_language,
486
+ text_language,
487
+ prompt_text,
488
+ prompt_language,
489
+ text,
490
+ text_language,
491
+ sample_steps,
492
+ inp_refs,
493
+ ref_text_free,
494
+ inference_button,
495
+ ],
496
+ ) #
497
+ GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
498
+
499
+ with gr.Group():
500
+ gr.Markdown(
501
+ value=i18n(
502
+ "文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
503
+ )
504
+ )
505
+ with gr.Row():
506
+ text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
507
+ with gr.Column():
508
+ _how_to_cut = gr.Radio(
509
+ label=i18n("怎么切"),
510
+ choices=[
511
+ i18n("不切"),
512
+ i18n("凑四句一切"),
513
+ i18n("凑50字一切"),
514
+ i18n("按中文句号。切"),
515
+ i18n("按英文句号.切"),
516
+ i18n("按标点符号切"),
517
+ ],
518
+ value=i18n("凑四句一切"),
519
+ interactive=True,
520
+ )
521
+ cut_text = gr.Button(i18n("切分"), variant="primary")
522
+
523
+ def to_cut(text_inp, how_to_cut):
524
+ if len(text_inp.strip()) == 0 or text_inp == []:
525
+ return ""
526
+ method = get_method(cut_method[how_to_cut])
527
+ return method(text_inp)
528
+
529
+ text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
530
+ cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
531
+ gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
532
+
533
+ if __name__ == "__main__":
534
+ app.queue().launch( # concurrency_count=511, max_size=1022
535
+ server_name="0.0.0.0",
536
+ inbrowser=True,
537
+ share=is_share,
538
+ server_port=infer_ttswebui,
539
+ quiet=True,
540
+ )
GPT_SoVITS/onnx_export.py ADDED
@@ -0,0 +1,398 @@
1
+ import torch
2
+ import torchaudio
3
+ from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
4
+ from feature_extractor import cnhubert
5
+ from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
6
+ from torch import nn
7
+
8
+ cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
9
+ cnhubert.cnhubert_base_path = cnhubert_base_path
10
+ ssl_model = cnhubert.get_model()
11
+ import json
12
+ import os
13
+
14
+ import soundfile
15
+ from text import cleaned_text_to_sequence
16
+
17
+
18
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
19
+ hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
20
+ y = torch.nn.functional.pad(
21
+ y.unsqueeze(1),
22
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
23
+ mode="reflect",
24
+ )
25
+ y = y.squeeze(1)
26
+ spec = torch.stft(
27
+ y,
28
+ n_fft,
29
+ hop_length=hop_size,
30
+ win_length=win_size,
31
+ window=hann_window,
32
+ center=center,
33
+ pad_mode="reflect",
34
+ normalized=False,
35
+ onesided=True,
36
+ return_complex=False,
37
+ )
38
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
39
+ return spec
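+ # Linear-magnitude spectrogram of the reference audio; VitsModel.forward feeds it to the decoder as the timbre reference.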
40
+
41
+
42
+ class DictToAttrRecursive(dict):
43
+ def __init__(self, input_dict):
44
+ super().__init__(input_dict)
45
+ for key, value in input_dict.items():
46
+ if isinstance(value, dict):
47
+ value = DictToAttrRecursive(value)
48
+ self[key] = value
49
+ setattr(self, key, value)
50
+
51
+ def __getattr__(self, item):
52
+ try:
53
+ return self[item]
54
+ except KeyError:
55
+ raise AttributeError(f"Attribute {item} not found")
56
+
57
+ def __setattr__(self, key, value):
58
+ if isinstance(value, dict):
59
+ value = DictToAttrRecursive(value)
60
+ super(DictToAttrRecursive, self).__setitem__(key, value)
61
+ super().__setattr__(key, value)
62
+
63
+ def __delattr__(self, item):
64
+ try:
65
+ del self[item]
66
+ except KeyError:
67
+ raise AttributeError(f"Attribute {item} not found")
68
+
69
+
70
+ class T2SEncoder(nn.Module):
71
+ def __init__(self, t2s, vits):
72
+ super().__init__()
73
+ self.encoder = t2s.onnx_encoder
74
+ self.vits = vits
75
+
76
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
77
+ codes = self.vits.extract_latent(ssl_content)
78
+ prompt_semantic = codes[0, 0]
79
+ bert = torch.cat([ref_bert.transpose(0, 1), text_bert.transpose(0, 1)], 1)
80
+ all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
81
+ bert = bert.unsqueeze(0)
82
+ prompt = prompt_semantic.unsqueeze(0)
83
+ return self.encoder(all_phoneme_ids, bert), prompt
84
+
85
+
86
+ class T2SModel(nn.Module):
87
+ def __init__(self, t2s_path, vits_model):
88
+ super().__init__()
89
+ dict_s1 = torch.load(t2s_path, map_location="cpu")
90
+ self.config = dict_s1["config"]
91
+ self.t2s_model = Text2SemanticLightningModule(self.config, "ojbk", is_train=False)
92
+ self.t2s_model.load_state_dict(dict_s1["weight"])
93
+ self.t2s_model.eval()
94
+ self.vits_model = vits_model.vq_model
95
+ self.hz = 50
96
+ self.max_sec = self.config["data"]["max_sec"]
97
+ self.t2s_model.model.top_k = torch.LongTensor([self.config["inference"]["top_k"]])
98
+ self.t2s_model.model.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
99
+ self.t2s_model = self.t2s_model.model
100
+ self.t2s_model.init_onnx()
101
+ self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
102
+ self.first_stage_decoder = self.t2s_model.first_stage_decoder
103
+ self.stage_decoder = self.t2s_model.stage_decoder
104
+ # self.t2s_model = torch.jit.script(self.t2s_model)
105
+
106
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
107
+ early_stop_num = self.t2s_model.early_stop_num
108
+
109
+ # [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
110
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
111
+
112
+ prefix_len = prompts.shape[1]
113
+
114
+ # [1,N,512] [1,N]
115
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
116
+
117
+ stop = False
118
+ for idx in range(1, 1500):
119
+ # [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
120
+ enco = self.stage_decoder(y, k, v, y_emb, x_example)
121
+ y, k, v, y_emb, logits, samples = enco
122
+ if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
123
+ stop = True
124
+ if torch.argmax(logits, dim=-1)[0] == self.t2s_model.EOS or samples[0, 0] == self.t2s_model.EOS:
125
+ stop = True
126
+ if stop:
127
+ break
128
+ y[0, -1] = 0
129
+
130
+ return y[:, -idx:].unsqueeze(0)
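+ # Autoregressive decoding: first_stage_decoder ingests the prompt, then stage_decoder is stepped until EOS or the early-stop budget (hz * max_sec tokens) is reached.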
131
+
132
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
133
+ # self.onnx_encoder = torch.jit.script(self.onnx_encoder)
134
+ if dynamo:
135
+ export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
136
+ onnx_encoder_export_output = torch.onnx.dynamo_export(
137
+ self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
138
+ )
139
+ onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
140
+ return
141
+
142
+ torch.onnx.export(
143
+ self.onnx_encoder,
144
+ (ref_seq, text_seq, ref_bert, text_bert, ssl_content),
145
+ f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
146
+ input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
147
+ output_names=["x", "prompts"],
148
+ dynamic_axes={
149
+ "ref_seq": {1: "ref_length"},
150
+ "text_seq": {1: "text_length"},
151
+ "ref_bert": {0: "ref_length"},
152
+ "text_bert": {0: "text_length"},
153
+ "ssl_content": {2: "ssl_length"},
154
+ },
155
+ opset_version=16,
156
+ )
157
+ x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
158
+
159
+ torch.onnx.export(
160
+ self.first_stage_decoder,
161
+ (x, prompts),
162
+ f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
163
+ input_names=["x", "prompts"],
164
+ output_names=["y", "k", "v", "y_emb", "x_example"],
165
+ dynamic_axes={
166
+ "x": {1: "x_length"},
167
+ "prompts": {1: "prompts_length"},
168
+ },
169
+ verbose=False,
170
+ opset_version=16,
171
+ )
172
+ y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
173
+
174
+ torch.onnx.export(
175
+ self.stage_decoder,
176
+ (y, k, v, y_emb, x_example),
177
+ f"onnx/{project_name}/{project_name}_t2s_sdec.onnx",
178
+ input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
179
+ output_names=["y", "k", "v", "y_emb", "logits", "samples"],
180
+ dynamic_axes={
181
+ "iy": {1: "iy_length"},
182
+ "ik": {1: "ik_length"},
183
+ "iv": {1: "iv_length"},
184
+ "iy_emb": {1: "iy_emb_length"},
185
+ "ix_example": {1: "ix_example_length"},
186
+ },
187
+ verbose=False,
188
+ opset_version=16,
189
+ )
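+ # The text-to-semantic model is exported as three ONNX graphs: the encoder, the first-stage decoder (prompt ingestion) and the per-step stage decoder driven in a loop at inference time.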
190
+
191
+
192
+ class VitsModel(nn.Module):
193
+ def __init__(self, vits_path):
194
+ super().__init__()
195
+ dict_s2 = torch.load(vits_path, map_location="cpu")
196
+ self.hps = dict_s2["config"]
197
+ if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
198
+ self.hps["model"]["version"] = "v1"
199
+ else:
200
+ self.hps["model"]["version"] = "v2"
201
+
202
+ self.hps = DictToAttrRecursive(self.hps)
203
+ self.hps.model.semantic_frame_rate = "25hz"
204
+ self.vq_model = SynthesizerTrn(
205
+ self.hps.data.filter_length // 2 + 1,
206
+ self.hps.train.segment_size // self.hps.data.hop_length,
207
+ n_speakers=self.hps.data.n_speakers,
208
+ **self.hps.model,
209
+ )
210
+ self.vq_model.eval()
211
+ self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
212
+
213
+ def forward(self, text_seq, pred_semantic, ref_audio):
214
+ refer = spectrogram_torch(
215
+ ref_audio,
216
+ self.hps.data.filter_length,
217
+ self.hps.data.sampling_rate,
218
+ self.hps.data.hop_length,
219
+ self.hps.data.win_length,
220
+ center=False,
221
+ )
222
+ return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
223
+
224
+
225
+ class GptSoVits(nn.Module):
226
+ def __init__(self, vits, t2s):
227
+ super().__init__()
228
+ self.vits = vits
229
+ self.t2s = t2s
230
+
231
+ def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
232
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
233
+ audio = self.vits(text_seq, pred_semantic, ref_audio)
234
+ if debug:
235
+ import onnxruntime
236
+
237
+ sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPUExecutionProvider"])  # onnxruntime expects the full provider name
238
+ audio1 = sess.run(
239
+ None,
240
+ {
241
+ "text_seq": text_seq.detach().cpu().numpy(),
242
+ "pred_semantic": pred_semantic.detach().cpu().numpy(),
243
+ "ref_audio": ref_audio.detach().cpu().numpy(),
244
+ },
245
+ )
246
+ return audio, audio1
247
+ return audio
248
+
249
+ def export(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, project_name):
250
+ self.t2s.export(ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name)
251
+ pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
252
+ torch.onnx.export(
253
+ self.vits,
254
+ (text_seq, pred_semantic, ref_audio),
255
+ f"onnx/{project_name}/{project_name}_vits.onnx",
256
+ input_names=["text_seq", "pred_semantic", "ref_audio"],
257
+ output_names=["audio"],
258
+ dynamic_axes={
259
+ "text_seq": {1: "text_length"},
260
+ "pred_semantic": {2: "pred_length"},
261
+ "ref_audio": {1: "audio_length"},
262
+ },
263
+ opset_version=17,
264
+ verbose=False,
265
+ )
266
+
267
+
268
+ class SSLModel(nn.Module):
269
+ def __init__(self):
270
+ super().__init__()
271
+ self.ssl = ssl_model
272
+
273
+ def forward(self, ref_audio_16k):
274
+ return self.ssl.model(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
275
+
276
+
277
+ def export(vits_path, gpt_path, project_name, vits_model="v2"):
278
+ vits = VitsModel(vits_path)
279
+ gpt = T2SModel(gpt_path, vits)
280
+ gpt_sovits = GptSoVits(vits, gpt)
281
+ ssl = SSLModel()
282
+ ref_seq = torch.LongTensor(
283
+ [
284
+ cleaned_text_to_sequence(
285
+ [
286
+ "n",
287
+ "i2",
288
+ "h",
289
+ "ao3",
290
+ ",",
291
+ "w",
292
+ "o3",
293
+ "sh",
294
+ "i4",
295
+ "b",
296
+ "ai2",
297
+ "y",
298
+ "e4",
299
+ ],
300
+ version=vits_model,
301
+ )
302
+ ]
303
+ )
304
+ text_seq = torch.LongTensor(
305
+ [
306
+ cleaned_text_to_sequence(
307
+ [
308
+ "w",
309
+ "o3",
310
+ "sh",
311
+ "i4",
312
+ "b",
313
+ "ai2",
314
+ "y",
315
+ "e4",
316
+ "w",
317
+ "o3",
318
+ "sh",
319
+ "i4",
320
+ "b",
321
+ "ai2",
322
+ "y",
323
+ "e4",
324
+ "w",
325
+ "o3",
326
+ "sh",
327
+ "i4",
328
+ "b",
329
+ "ai2",
330
+ "y",
331
+ "e4",
332
+ ],
333
+ version=vits_model,
334
+ )
335
+ ]
336
+ )
337
+ ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
338
+ text_bert = torch.randn((text_seq.shape[1], 1024)).float()
339
+ ref_audio = torch.randn((1, 48000 * 5)).float()
340
+ # ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
341
+ ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
342
+ ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
343
+
344
+ try:
345
+ os.mkdir(f"onnx/{project_name}")
346
+ except:
347
+ pass
348
+
349
+ ssl_content = ssl(ref_audio_16k).float()
350
+
351
+ # debug = False
352
+ debug = True
353
+
354
+ # gpt_sovits.export(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, project_name)
355
+
356
+ if debug:
357
+ a, b = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content, debug=debug)
358
+ soundfile.write("out1.wav", a.cpu().detach().numpy(), vits.hps.data.sampling_rate)
359
+ soundfile.write("out2.wav", b[0], vits.hps.data.sampling_rate)
360
+ else:
361
+ a = gpt_sovits(ref_seq, text_seq, ref_bert, text_bert, ref_audio_sr, ssl_content).detach().cpu().numpy()
362
+ soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
363
+
364
+ if vits_model == "v1":
365
+ symbols = symbols_v1
366
+ else:
367
+ symbols = symbols_v2
368
+
369
+ MoeVSConf = {
370
+ "Folder": f"{project_name}",
371
+ "Name": f"{project_name}",
372
+ "Type": "GPT-SoVits",
373
+ "Rate": vits.hps.data.sampling_rate,
374
+ "NumLayers": gpt.t2s_model.num_layers,
375
+ "EmbeddingDim": gpt.t2s_model.embedding_dim,
376
+ "Dict": "BasicDict",
377
+ "BertPath": "chinese-roberta-wwm-ext-large",
378
+ # "Symbol": symbols,
379
+ "AddBlank": False,
380
+ }
381
+
382
+ MoeVSConfJson = json.dumps(MoeVSConf)
383
+ with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
384
+ json.dump(MoeVSConf, MoeVsConfFile, indent=4)
385
+
386
+
387
+ if __name__ == "__main__":
388
+ try:
389
+ os.mkdir("onnx")
390
+ except:
391
+ pass
392
+
393
+ gpt_path = "GPT_weights/nahida-e25.ckpt"
394
+ vits_path = "SoVITS_weights/nahida_e30_s3930.pth"
395
+ exp_path = "nahida"
396
+ export(vits_path, gpt_path, exp_path)
397
+
398
+ # soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
GPT_SoVITS/process_ckpt.py ADDED
@@ -0,0 +1,124 @@
1
+ import traceback
2
+ from collections import OrderedDict
3
+ from time import time as ttime
4
+ import shutil
5
+ import os
6
+ import torch
7
+ from tools.i18n.i18n import I18nAuto
8
+
9
+ i18n = I18nAuto()
10
+
11
+
12
+ def my_save(fea, path):  # fix: torch.save can fail on paths containing Chinese characters, so save to a temp file and then move it into place
13
+ dir = os.path.dirname(path)
14
+ name = os.path.basename(path)
15
+ tmp_path = "%s.pth" % (ttime())
16
+ torch.save(fea, tmp_path)
17
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
18
+
19
+
20
+ """
21
+ 00:v1
22
+ 01:v2
23
+ 02:v3
24
+ 03:v3lora
25
+ 04:v4lora
26
+
27
+ """
28
+ from io import BytesIO
29
+
30
+
31
+ def my_save2(fea, path, cfm_version):
32
+ bio = BytesIO()
33
+ torch.save(fea, bio)
34
+ bio.seek(0)
35
+ data = bio.getvalue()
36
+ byte=b"03" if cfm_version=="v3"else b"04"
37
+ data = byte + data[2:]
38
+ with open(path, "wb") as f:
39
+ f.write(data)
40
+
41
+
42
+ def savee(ckpt, name, epoch, steps, hps, cfm_version=None, lora_rank=None):
43
+ try:
44
+ opt = OrderedDict()
45
+ opt["weight"] = {}
46
+ for key in ckpt.keys():
47
+ if "enc_q" in key:
48
+ continue
49
+ opt["weight"][key] = ckpt[key].half()
50
+ opt["config"] = hps
51
+ opt["info"] = "%sepoch_%siteration" % (epoch, steps)
52
+ if lora_rank:
53
+ opt["lora_rank"] = lora_rank
54
+ my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), cfm_version)
55
+ else:
56
+ my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
57
+ return "Success."
58
+ except:
59
+ return traceback.format_exc()
60
+
61
+
62
+ head2version = {
63
+ b"00": ["v1", "v1", False],
64
+ b"01": ["v2", "v2", False],
65
+ b"02": ["v2", "v3", False],
66
+ b"03": ["v2", "v3", True],
67
+ b"04": ["v2", "v4", True],
68
+ }
69
+ hash_pretrained_dict = {
70
+ "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
71
+ "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
72
+ "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
73
+ "4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
74
+ }
75
+ import hashlib
76
+
77
+
78
+ def get_hash_from_file(sovits_path):
79
+ with open(sovits_path, "rb") as f:
80
+ data = f.read(8192)
81
+ hash_md5 = hashlib.md5()
82
+ hash_md5.update(data)
83
+ return hash_md5.hexdigest()
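+ # Only the first 8192 bytes are hashed, which is enough to identify the known pretrained checkpoints without reading whole files.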
84
+
85
+
86
+ def get_sovits_version_from_path_fast(sovits_path):
87
+ ###1-if it is pretrained sovits models, by hash
88
+ hash = get_hash_from_file(sovits_path)
89
+ if hash in hash_pretrained_dict:
90
+ return hash_pretrained_dict[hash]
91
+ ###2-new weights, by head
92
+ with open(sovits_path, "rb") as f:
93
+ version = f.read(2)
94
+ if version != b"PK":
95
+ return head2version[version]
96
+ ###3-old weights, by file size
97
+ if_lora_v3 = False
98
+ size = os.path.getsize(sovits_path)
99
+ """
100
+ v1weights:about 82942KB
101
+ half thr:82978KB
102
+ v2weights:about 83014KB
103
+ v3weights:about 750MB
104
+ """
105
+ if size < 82978 * 1024:
106
+ model_version = version = "v1"
107
+ elif size < 700 * 1024 * 1024:
108
+ model_version = version = "v2"
109
+ else:
110
+ version = "v2"
111
+ model_version = "v3"
112
+ return version, model_version, if_lora_v3
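+ # Returns (text-frontend version, model architecture version, whether the file is a v3/v4 LoRA checkpoint that needs the matching pretrained base s2G model).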
113
+
114
+
115
+ def load_sovits_new(sovits_path):
+ with open(sovits_path, "rb") as f:
+ meta = f.read(2)
+ if meta != b"PK":  # custom 2-byte version header (see head2version) instead of a zip ("PK") archive
+ data = b"PK" + f.read()
+ bio = BytesIO()
+ bio.write(data)
+ bio.seek(0)
+ return torch.load(bio, map_location="cpu", weights_only=False)
+ return torch.load(sovits_path, map_location="cpu", weights_only=False)
GPT_SoVITS/s1_train.py ADDED
@@ -0,0 +1,171 @@
1
+ # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
2
+ import os
3
+
4
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
6
+ import argparse
7
+ import logging
8
+ import platform
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from AR.data.data_module import Text2SemanticDataModule
13
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
14
+ from AR.utils.io import load_yaml_config
15
+ from pytorch_lightning import Trainer, seed_everything
16
+ from pytorch_lightning.callbacks import ModelCheckpoint
17
+ from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
18
+ from pytorch_lightning.strategies import DDPStrategy
19
+
20
+ logging.getLogger("numba").setLevel(logging.WARNING)
21
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
22
+ torch.set_float32_matmul_precision("high")
23
+ from collections import OrderedDict
24
+
25
+ from AR.utils import get_newest_ckpt
26
+ from process_ckpt import my_save
27
+
28
+
29
+ class my_model_ckpt(ModelCheckpoint):
30
+ def __init__(
31
+ self,
32
+ config,
33
+ if_save_latest,
34
+ if_save_every_weights,
35
+ half_weights_save_dir,
36
+ exp_name,
37
+ **kwargs,
38
+ ):
39
+ super().__init__(**kwargs)
40
+ self.if_save_latest = if_save_latest
41
+ self.if_save_every_weights = if_save_every_weights
42
+ self.half_weights_save_dir = half_weights_save_dir
43
+ self.exp_name = exp_name
44
+ self.config = config
45
+
46
+ def on_train_epoch_end(self, trainer, pl_module):
47
+ # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
48
+ if self._should_save_on_train_epoch_end(trainer):
49
+ monitor_candidates = self._monitor_candidates(trainer)
50
+ if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
51
+ if (
52
+ self.if_save_latest == True
53
+ ):  # if only the latest ckpt should be kept, remove all previous ckpts after the new one is saved
54
+ to_clean = list(os.listdir(self.dirpath))
55
+ self._save_topk_checkpoint(trainer, monitor_candidates)
56
+ if self.if_save_latest == True:
57
+ for name in to_clean:
58
+ try:
59
+ os.remove("%s/%s" % (self.dirpath, name))
60
+ except:
61
+ pass
62
+ if self.if_save_every_weights == True:
63
+ to_save_od = OrderedDict()
64
+ to_save_od["weight"] = OrderedDict()
65
+ dictt = trainer.strategy._lightning_module.state_dict()
66
+ for key in dictt:
67
+ to_save_od["weight"][key] = dictt[key].half()
68
+ to_save_od["config"] = self.config
69
+ to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
70
+ # torch.save(
71
+ # print(os.environ)
72
+ if os.environ.get("LOCAL_RANK", "0") == "0":
73
+ my_save(
74
+ to_save_od,
75
+ "%s/%s-e%s.ckpt"
76
+ % (
77
+ self.half_weights_save_dir,
78
+ self.exp_name,
79
+ trainer.current_epoch + 1,
80
+ ),
81
+ )
82
+ self._save_last_checkpoint(trainer, monitor_candidates)
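+ # Besides the regular Lightning checkpoints, this callback saves a half-precision weight-only GPT checkpoint every epoch and, with if_save_latest, deletes the older full checkpoints.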
83
+
84
+
85
+ def main(args):
86
+ config = load_yaml_config(args.config_file)
87
+
88
+ output_dir = Path(config["output_dir"])
89
+ output_dir.mkdir(parents=True, exist_ok=True)
90
+
91
+ ckpt_dir = output_dir / "ckpt"
92
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
93
+
94
+ seed_everything(config["train"]["seed"], workers=True)
95
+ ckpt_callback: ModelCheckpoint = my_model_ckpt(
96
+ config=config,
97
+ if_save_latest=config["train"]["if_save_latest"],
98
+ if_save_every_weights=config["train"]["if_save_every_weights"],
99
+ half_weights_save_dir=config["train"]["half_weights_save_dir"],
100
+ exp_name=config["train"]["exp_name"],
101
+ save_top_k=-1,
102
+ monitor="top_3_acc",
103
+ mode="max",
104
+ save_on_train_epoch_end=True,
105
+ every_n_epochs=config["train"]["save_every_n_epoch"],
106
+ dirpath=ckpt_dir,
107
+ )
108
+ logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
109
+ os.environ["MASTER_ADDR"] = "localhost"
110
+ os.environ["USE_LIBUV"] = "0"
111
+ trainer: Trainer = Trainer(
112
+ max_epochs=config["train"]["epochs"],
113
+ accelerator="gpu" if torch.cuda.is_available() else "cpu",
114
+ # val_check_interval=9999999999999999999999,###不要验证
115
+ # check_val_every_n_epoch=None,
116
+ limit_val_batches=0,
117
+ devices=-1 if torch.cuda.is_available() else 1,
118
+ benchmark=False,
119
+ fast_dev_run=False,
120
+ strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
121
+ if torch.cuda.is_available()
122
+ else "auto",
123
+ precision=config["train"]["precision"],
124
+ logger=logger,
125
+ num_sanity_val_steps=0,
126
+ callbacks=[ckpt_callback],
127
+ use_distributed_sampler=False,  # a very small change, but it fixes the inconsistent step count when using the custom bucket_sampler!
128
+ )
129
+
130
+ model: Text2SemanticLightningModule = Text2SemanticLightningModule(config, output_dir)
131
+
132
+ data_module: Text2SemanticDataModule = Text2SemanticDataModule(
133
+ config,
134
+ train_semantic_path=config["train_semantic_path"],
135
+ train_phoneme_path=config["train_phoneme_path"],
136
+ # dev_semantic_path=args.dev_semantic_path,
137
+ # dev_phoneme_path=args.dev_phoneme_path
138
+ )
139
+
140
+ try:
141
+ # match the numeric part of the checkpoint filenames and sort numerically to find the newest one
142
+ newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
143
+ ckpt_path = ckpt_dir / newest_ckpt_name
144
+ except Exception:
145
+ ckpt_path = None
146
+ print("ckpt_path:", ckpt_path)
147
+ trainer.fit(model, data_module, ckpt_path=ckpt_path)
148
+
149
+
150
+ # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
151
+ if __name__ == "__main__":
152
+ parser = argparse.ArgumentParser()
153
+ parser.add_argument(
154
+ "-c",
155
+ "--config_file",
156
+ type=str,
157
+ default="configs/s1longer.yaml",
158
+ help="path of config file",
159
+ )
160
+ # args for dataset
161
+ # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
162
+ # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
163
+
164
+ # parser.add_argument('--dev_semantic_path', type=str, default='dump_mix/semantic_dev.tsv')
165
+ # parser.add_argument('--dev_phoneme_path', type=str, default='dump_mix/phoneme_dev.npy')
166
+ # parser.add_argument('--output_dir',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/logs_s1',help='directory to save the results')
167
+ # parser.add_argument('--output_dir',type=str,default='/liujing04/gpt_logs/s1/xuangou_ft',help='directory to save the results')
168
+
169
+ args = parser.parse_args()
170
+ logging.info(str(args))
171
+ main(args)
GPT_SoVITS/s2_train.py ADDED
@@ -0,0 +1,680 @@
1
+ import warnings
2
+
3
+ warnings.filterwarnings("ignore")
4
+ import os
5
+
6
+ import utils
7
+
8
+ hps = utils.get_hparams(stage=2)
9
+ os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
10
+ import logging
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.multiprocessing as mp
15
+ from torch.cuda.amp import GradScaler, autocast
16
+ from torch.nn import functional as F
17
+ from torch.nn.parallel import DistributedDataParallel as DDP
18
+ from torch.utils.data import DataLoader
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ from tqdm import tqdm
21
+
22
+ logging.getLogger("matplotlib").setLevel(logging.INFO)
23
+ logging.getLogger("h5py").setLevel(logging.INFO)
24
+ logging.getLogger("numba").setLevel(logging.INFO)
25
+ from random import randint
26
+
27
+ from module import commons
28
+ from module.data_utils import (
29
+ DistributedBucketSampler,
30
+ TextAudioSpeakerCollate,
31
+ TextAudioSpeakerLoader,
32
+ )
33
+ from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
34
+ from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
35
+ from module.models import (
36
+ MultiPeriodDiscriminator,
37
+ SynthesizerTrn,
38
+ )
39
+ from process_ckpt import savee
40
+
41
+ torch.backends.cudnn.benchmark = False
42
+ torch.backends.cudnn.deterministic = False
43
+ ### fp32 is already fast enough on A100, so enable TF32 as well
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); no noticeable effect on results
47
+ # from config import pretrained_s2G,pretrained_s2D
48
+ global_step = 0
49
+
50
+ device = "cpu" # cuda以外的设备,等mps优化后加入
51
+
52
+
53
+ def main():
54
+ if torch.cuda.is_available():
55
+ n_gpus = torch.cuda.device_count()
56
+ else:
57
+ n_gpus = 1
58
+ os.environ["MASTER_ADDR"] = "localhost"
59
+ os.environ["MASTER_PORT"] = str(randint(20000, 55555))
60
+
61
+ mp.spawn(
62
+ run,
63
+ nprocs=n_gpus,
64
+ args=(
65
+ n_gpus,
66
+ hps,
67
+ ),
68
+ )
69
+
70
+
71
+ def run(rank, n_gpus, hps):
72
+ global global_step
73
+ if rank == 0:
74
+ logger = utils.get_logger(hps.data.exp_dir)
75
+ logger.info(hps)
76
+ # utils.check_git_hash(hps.s2_ckpt_dir)
77
+ writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
78
+ writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
79
+
80
+ dist.init_process_group(
81
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
82
+ init_method="env://?use_libuv=False",
83
+ world_size=n_gpus,
84
+ rank=rank,
85
+ )
86
+ torch.manual_seed(hps.train.seed)
87
+ if torch.cuda.is_available():
88
+ torch.cuda.set_device(rank)
89
+
90
+ train_dataset = TextAudioSpeakerLoader(hps.data) ########
91
+ train_sampler = DistributedBucketSampler(
92
+ train_dataset,
93
+ hps.train.batch_size,
94
+ [
95
+ 32,
96
+ 300,
97
+ 400,
98
+ 500,
99
+ 600,
100
+ 700,
101
+ 800,
102
+ 900,
103
+ 1000,
104
+ 1100,
105
+ 1200,
106
+ 1300,
107
+ 1400,
108
+ 1500,
109
+ 1600,
110
+ 1700,
111
+ 1800,
112
+ 1900,
113
+ ],
114
+ num_replicas=n_gpus,
115
+ rank=rank,
116
+ shuffle=True,
117
+ )
118
+ collate_fn = TextAudioSpeakerCollate()
119
+ train_loader = DataLoader(
120
+ train_dataset,
121
+ num_workers=6,
122
+ shuffle=False,
123
+ pin_memory=True,
124
+ collate_fn=collate_fn,
125
+ batch_sampler=train_sampler,
126
+ persistent_workers=True,
127
+ prefetch_factor=4,
128
+ )
129
+ # if rank == 0:
130
+ # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
131
+ # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
132
+ # batch_size=1, pin_memory=True,
133
+ # drop_last=False, collate_fn=collate_fn)
134
+
135
+ net_g = (
136
+ SynthesizerTrn(
137
+ hps.data.filter_length // 2 + 1,
138
+ hps.train.segment_size // hps.data.hop_length,
139
+ n_speakers=hps.data.n_speakers,
140
+ **hps.model,
141
+ ).cuda(rank)
142
+ if torch.cuda.is_available()
143
+ else SynthesizerTrn(
144
+ hps.data.filter_length // 2 + 1,
145
+ hps.train.segment_size // hps.data.hop_length,
146
+ n_speakers=hps.data.n_speakers,
147
+ **hps.model,
148
+ ).to(device)
149
+ )
150
+
151
+ net_d = (
152
+ MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
153
+ if torch.cuda.is_available()
154
+ else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
155
+ )
156
+ for name, param in net_g.named_parameters():
157
+ if not param.requires_grad:
158
+ print(name, "not requires_grad")
159
+
160
+ te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
161
+ et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
162
+ mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
163
+ base_params = filter(
164
+ lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
165
+ net_g.parameters(),
166
+ )
167
+
168
+ # te_p=net_g.enc_p.text_embedding.parameters()
169
+ # et_p=net_g.enc_p.encoder_text.parameters()
170
+ # mrte_p=net_g.enc_p.mrte.parameters()
171
+
172
+ optim_g = torch.optim.AdamW(
173
+ # filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
174
+ [
175
+ {"params": base_params, "lr": hps.train.learning_rate},
176
+ {
177
+ "params": net_g.enc_p.text_embedding.parameters(),
178
+ "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
179
+ },
180
+ {
181
+ "params": net_g.enc_p.encoder_text.parameters(),
182
+ "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
183
+ },
184
+ {
185
+ "params": net_g.enc_p.mrte.parameters(),
186
+ "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
187
+ },
188
+ ],
189
+ hps.train.learning_rate,
190
+ betas=hps.train.betas,
191
+ eps=hps.train.eps,
192
+ )
193
+ optim_d = torch.optim.AdamW(
194
+ net_d.parameters(),
195
+ hps.train.learning_rate,
196
+ betas=hps.train.betas,
197
+ eps=hps.train.eps,
198
+ )
199
+ if torch.cuda.is_available():
200
+ net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
201
+ net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
202
+ else:
203
+ net_g = net_g.to(device)
204
+ net_d = net_d.to(device)
205
+
206
+ try:  # auto-resume: load the latest checkpoint if one exists
207
+ _, _, _, epoch_str = utils.load_checkpoint(
208
+ utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
209
+ net_d,
210
+ optim_d,
211
+ )  # loading D rarely causes problems
212
+ if rank == 0:
213
+ logger.info("loaded D")
214
+ # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
215
+ _, _, _, epoch_str = utils.load_checkpoint(
216
+ utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
217
+ net_g,
218
+ optim_g,
219
+ )
220
+ epoch_str += 1
221
+ global_step = (epoch_str - 1) * len(train_loader)
222
+ # epoch_str = 1
223
+ # global_step = 0
224
+ except:  # nothing to resume on the first run, so load the pretrained weights instead
225
+ # traceback.print_exc()
226
+ epoch_str = 1
227
+ global_step = 0
228
+ if (
229
+ hps.train.pretrained_s2G != ""
230
+ and hps.train.pretrained_s2G != None
231
+ and os.path.exists(hps.train.pretrained_s2G)
232
+ ):
233
+ if rank == 0:
234
+ logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
235
+ print(
236
+ "loaded pretrained %s" % hps.train.pretrained_s2G,
237
+ net_g.module.load_state_dict(
238
+ torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
239
+ strict=False,
240
+ )
241
+ if torch.cuda.is_available()
242
+ else net_g.load_state_dict(
243
+ torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
244
+ strict=False,
245
+ ),
246
+ )  ## experiment: do not load the optimizer
247
+ if (
248
+ hps.train.pretrained_s2D != ""
249
+ and hps.train.pretrained_s2D is not None
250
+ and os.path.exists(hps.train.pretrained_s2D)
251
+ ):
252
+ if rank == 0:
253
+ logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
254
+ print(
255
+ "loaded pretrained %s" % hps.train.pretrained_s2D,
256
+ net_d.module.load_state_dict(
257
+ torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
258
+ )
259
+ if torch.cuda.is_available()
260
+ else net_d.load_state_dict(
261
+ torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
262
+ ),
263
+ )
264
+
265
+ # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
266
+ # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
267
+
268
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
269
+ optim_g,
270
+ gamma=hps.train.lr_decay,
271
+ last_epoch=-1,
272
+ )
273
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
274
+ optim_d,
275
+ gamma=hps.train.lr_decay,
276
+ last_epoch=-1,
277
+ )
278
+ for _ in range(epoch_str):
279
+ scheduler_g.step()
280
+ scheduler_d.step()
281
+
282
+ scaler = GradScaler(enabled=hps.train.fp16_run)
283
+
284
+ print("start training from epoch %s" % epoch_str)
285
+ for epoch in range(epoch_str, hps.train.epochs + 1):
286
+ if rank == 0:
287
+ train_and_evaluate(
288
+ rank,
289
+ epoch,
290
+ hps,
291
+ [net_g, net_d],
292
+ [optim_g, optim_d],
293
+ [scheduler_g, scheduler_d],
294
+ scaler,
295
+ # [train_loader, eval_loader], logger, [writer, writer_eval])
296
+ [train_loader, None],
297
+ logger,
298
+ [writer, writer_eval],
299
+ )
300
+ else:
301
+ train_and_evaluate(
302
+ rank,
303
+ epoch,
304
+ hps,
305
+ [net_g, net_d],
306
+ [optim_g, optim_d],
307
+ [scheduler_g, scheduler_d],
308
+ scaler,
309
+ [train_loader, None],
310
+ None,
311
+ None,
312
+ )
313
+ scheduler_g.step()
314
+ scheduler_d.step()
315
+ print("training done")
316
+
317
+
318
+ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
319
+ net_g, net_d = nets
320
+ optim_g, optim_d = optims
321
+ # scheduler_g, scheduler_d = schedulers
322
+ train_loader, eval_loader = loaders
323
+ if writers is not None:
324
+ writer, writer_eval = writers
325
+
326
+ train_loader.batch_sampler.set_epoch(epoch)
327
+ global global_step
328
+
329
+ net_g.train()
330
+ net_d.train()
331
+ for batch_idx, (
332
+ ssl,
333
+ ssl_lengths,
334
+ spec,
335
+ spec_lengths,
336
+ y,
337
+ y_lengths,
338
+ text,
339
+ text_lengths,
340
+ ) in enumerate(tqdm(train_loader)):
341
+ if torch.cuda.is_available():
342
+ spec, spec_lengths = (
343
+ spec.cuda(
344
+ rank,
345
+ non_blocking=True,
346
+ ),
347
+ spec_lengths.cuda(
348
+ rank,
349
+ non_blocking=True,
350
+ ),
351
+ )
352
+ y, y_lengths = (
353
+ y.cuda(
354
+ rank,
355
+ non_blocking=True,
356
+ ),
357
+ y_lengths.cuda(
358
+ rank,
359
+ non_blocking=True,
360
+ ),
361
+ )
362
+ ssl = ssl.cuda(rank, non_blocking=True)
363
+ ssl.requires_grad = False
364
+ # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
365
+ text, text_lengths = (
366
+ text.cuda(
367
+ rank,
368
+ non_blocking=True,
369
+ ),
370
+ text_lengths.cuda(
371
+ rank,
372
+ non_blocking=True,
373
+ ),
374
+ )
375
+ else:
376
+ spec, spec_lengths = spec.to(device), spec_lengths.to(device)
377
+ y, y_lengths = y.to(device), y_lengths.to(device)
378
+ ssl = ssl.to(device)
379
+ ssl.requires_grad = False
380
+ # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
381
+ text, text_lengths = text.to(device), text_lengths.to(device)
382
+
383
+ with autocast(enabled=hps.train.fp16_run):
384
+ (
385
+ y_hat,
386
+ kl_ssl,
387
+ ids_slice,
388
+ x_mask,
389
+ z_mask,
390
+ (z, z_p, m_p, logs_p, m_q, logs_q),
391
+ stats_ssl,
392
+ ) = net_g(ssl, spec, spec_lengths, text, text_lengths)
393
+
394
+ mel = spec_to_mel_torch(
395
+ spec,
396
+ hps.data.filter_length,
397
+ hps.data.n_mel_channels,
398
+ hps.data.sampling_rate,
399
+ hps.data.mel_fmin,
400
+ hps.data.mel_fmax,
401
+ )
402
+ y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length)
403
+ y_hat_mel = mel_spectrogram_torch(
404
+ y_hat.squeeze(1),
405
+ hps.data.filter_length,
406
+ hps.data.n_mel_channels,
407
+ hps.data.sampling_rate,
408
+ hps.data.hop_length,
409
+ hps.data.win_length,
410
+ hps.data.mel_fmin,
411
+ hps.data.mel_fmax,
412
+ )
413
+
414
+ y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice
415
+
416
+ # Discriminator
417
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
418
+ with autocast(enabled=False):
419
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
420
+ y_d_hat_r,
421
+ y_d_hat_g,
422
+ )
423
+ loss_disc_all = loss_disc
424
+ optim_d.zero_grad()
425
+ scaler.scale(loss_disc_all).backward()
426
+ scaler.unscale_(optim_d)
427
+ grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
428
+ scaler.step(optim_d)
429
+
430
+ with autocast(enabled=hps.train.fp16_run):
431
+ # Generator
432
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
433
+ with autocast(enabled=False):
434
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
435
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
436
+
437
+ loss_fm = feature_loss(fmap_r, fmap_g)
438
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
439
+ loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
440
+
441
+ optim_g.zero_grad()
442
+ scaler.scale(loss_gen_all).backward()
443
+ scaler.unscale_(optim_g)
444
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
445
+ scaler.step(optim_g)
446
+ scaler.update()
447
+
448
+ if rank == 0:
449
+ if global_step % hps.train.log_interval == 0:
450
+ lr = optim_g.param_groups[0]["lr"]
451
+ losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
452
+ logger.info(
453
+ "Train Epoch: {} [{:.0f}%]".format(
454
+ epoch,
455
+ 100.0 * batch_idx / len(train_loader),
456
+ )
457
+ )
458
+ logger.info([x.item() for x in losses] + [global_step, lr])
459
+
460
+ scalar_dict = {
461
+ "loss/g/total": loss_gen_all,
462
+ "loss/d/total": loss_disc_all,
463
+ "learning_rate": lr,
464
+ "grad_norm_d": grad_norm_d,
465
+ "grad_norm_g": grad_norm_g,
466
+ }
467
+ scalar_dict.update(
468
+ {
469
+ "loss/g/fm": loss_fm,
470
+ "loss/g/mel": loss_mel,
471
+ "loss/g/kl_ssl": kl_ssl,
472
+ "loss/g/kl": loss_kl,
473
+ }
474
+ )
475
+
476
+ # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
477
+ # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
478
+ # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
479
+ image_dict = None
480
+ try:  ### some users have an incompatible matplotlib version installed
481
+ image_dict = {
482
+ "slice/mel_org": utils.plot_spectrogram_to_numpy(
483
+ y_mel[0].data.cpu().numpy(),
484
+ ),
485
+ "slice/mel_gen": utils.plot_spectrogram_to_numpy(
486
+ y_hat_mel[0].data.cpu().numpy(),
487
+ ),
488
+ "all/mel": utils.plot_spectrogram_to_numpy(
489
+ mel[0].data.cpu().numpy(),
490
+ ),
491
+ "all/stats_ssl": utils.plot_spectrogram_to_numpy(
492
+ stats_ssl[0].data.cpu().numpy(),
493
+ ),
494
+ }
495
+ except:
496
+ pass
497
+ if image_dict:
498
+ utils.summarize(
499
+ writer=writer,
500
+ global_step=global_step,
501
+ images=image_dict,
502
+ scalars=scalar_dict,
503
+ )
504
+ else:
505
+ utils.summarize(
506
+ writer=writer,
507
+ global_step=global_step,
508
+ scalars=scalar_dict,
509
+ )
510
+ global_step += 1
511
+ if epoch % hps.train.save_every_epoch == 0 and rank == 0:
512
+ if hps.train.if_save_latest == 0:
513
+ utils.save_checkpoint(
514
+ net_g,
515
+ optim_g,
516
+ hps.train.learning_rate,
517
+ epoch,
518
+ os.path.join(
519
+ "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
520
+ "G_{}.pth".format(global_step),
521
+ ),
522
+ )
523
+ utils.save_checkpoint(
524
+ net_d,
525
+ optim_d,
526
+ hps.train.learning_rate,
527
+ epoch,
528
+ os.path.join(
529
+ "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
530
+ "D_{}.pth".format(global_step),
531
+ ),
532
+ )
533
+ else:
534
+ utils.save_checkpoint(
535
+ net_g,
536
+ optim_g,
537
+ hps.train.learning_rate,
538
+ epoch,
539
+ os.path.join(
540
+ "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
541
+ "G_{}.pth".format(233333333333),
542
+ ),
543
+ )
544
+ utils.save_checkpoint(
545
+ net_d,
546
+ optim_d,
547
+ hps.train.learning_rate,
548
+ epoch,
549
+ os.path.join(
550
+ "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
551
+ "D_{}.pth".format(233333333333),
552
+ ),
553
+ )
554
+ if rank == 0 and hps.train.if_save_every_weights == True:
555
+ if hasattr(net_g, "module"):
556
+ ckpt = net_g.module.state_dict()
557
+ else:
558
+ ckpt = net_g.state_dict()
559
+ logger.info(
560
+ "saving ckpt %s_e%s:%s"
561
+ % (
562
+ hps.name,
563
+ epoch,
564
+ savee(
565
+ ckpt,
566
+ hps.name + "_e%s_s%s" % (epoch, global_step),
567
+ epoch,
568
+ global_step,
569
+ hps,
570
+ ),
571
+ )
572
+ )
573
+
574
+ if rank == 0:
575
+ logger.info("====> Epoch: {}".format(epoch))
576
+
577
+
578
+ def evaluate(hps, generator, eval_loader, writer_eval):
579
+ generator.eval()
580
+ image_dict = {}
581
+ audio_dict = {}
582
+ print("Evaluating ...")
583
+ with torch.no_grad():
584
+ for batch_idx, (
585
+ ssl,
586
+ ssl_lengths,
587
+ spec,
588
+ spec_lengths,
589
+ y,
590
+ y_lengths,
591
+ text,
592
+ text_lengths,
593
+ ) in enumerate(eval_loader):
594
+ print(111)
595
+ if torch.cuda.is_available():
596
+ spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
597
+ y, y_lengths = y.cuda(), y_lengths.cuda()
598
+ ssl = ssl.cuda()
599
+ text, text_lengths = text.cuda(), text_lengths.cuda()
600
+ else:
601
+ spec, spec_lengths = spec.to(device), spec_lengths.to(device)
602
+ y, y_lengths = y.to(device), y_lengths.to(device)
603
+ ssl = ssl.to(device)
604
+ text, text_lengths = text.to(device), text_lengths.to(device)
605
+ for test in [0, 1]:
606
+ y_hat, mask, *_ = (
607
+ generator.module.infer(
608
+ ssl,
609
+ spec,
610
+ spec_lengths,
611
+ text,
612
+ text_lengths,
613
+ test=test,
614
+ )
615
+ if torch.cuda.is_available()
616
+ else generator.infer(
617
+ ssl,
618
+ spec,
619
+ spec_lengths,
620
+ text,
621
+ text_lengths,
622
+ test=test,
623
+ )
624
+ )
625
+ y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
626
+
627
+ mel = spec_to_mel_torch(
628
+ spec,
629
+ hps.data.filter_length,
630
+ hps.data.n_mel_channels,
631
+ hps.data.sampling_rate,
632
+ hps.data.mel_fmin,
633
+ hps.data.mel_fmax,
634
+ )
635
+ y_hat_mel = mel_spectrogram_torch(
636
+ y_hat.squeeze(1).float(),
637
+ hps.data.filter_length,
638
+ hps.data.n_mel_channels,
639
+ hps.data.sampling_rate,
640
+ hps.data.hop_length,
641
+ hps.data.win_length,
642
+ hps.data.mel_fmin,
643
+ hps.data.mel_fmax,
644
+ )
645
+ image_dict.update(
646
+ {
647
+ f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
648
+ y_hat_mel[0].cpu().numpy(),
649
+ ),
650
+ }
651
+ )
652
+ audio_dict.update(
653
+ {
654
+ f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]],
655
+ },
656
+ )
657
+ image_dict.update(
658
+ {
659
+ f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
660
+ },
661
+ )
662
+ audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
663
+
664
+ # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
665
+ # audio_dict.update({
666
+ # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
667
+ # })
668
+
669
+ utils.summarize(
670
+ writer=writer_eval,
671
+ global_step=global_step,
672
+ images=image_dict,
673
+ audios=audio_dict,
674
+ audio_sampling_rate=hps.data.sampling_rate,
675
+ )
676
+ generator.train()
677
+
678
+
679
+ if __name__ == "__main__":
680
+ main()
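
Note on the file above: train_and_evaluate performs the standard two-step adversarial update with one GradScaler shared by both optimizers: the discriminator is stepped first on the detached generator output, then the generator is stepped through a fresh discriminator forward. A minimal sketch of that ordering, under stated assumptions (toy Linear modules and least-squares losses stand in for SynthesizerTrn, MultiPeriodDiscriminator and the repo's discriminator_loss/generator_loss; the clip value is arbitrary), is:

import torch
from torch.cuda.amp import GradScaler, autocast

net_g = torch.nn.Linear(8, 8)                      # placeholder generator
net_d = torch.nn.Linear(8, 1)                      # placeholder discriminator
optim_g = torch.optim.AdamW(net_g.parameters(), 1e-4)
optim_d = torch.optim.AdamW(net_d.parameters(), 1e-4)
fp16_run = False                                   # hps.train.fp16_run in the real script
scaler = GradScaler(enabled=fp16_run)

y = torch.randn(4, 8)                              # "real" waveform slice
with autocast(enabled=fp16_run):
    y_hat = net_g(torch.randn(4, 8))               # "generated" slice
    # 1) discriminator step: y_hat is detached so only net_d receives gradients
    loss_disc = ((net_d(y) - 1) ** 2).mean() + (net_d(y_hat.detach()) ** 2).mean()
optim_d.zero_grad()
scaler.scale(loss_disc).backward()
scaler.unscale_(optim_d)                           # unscale before the grad-norm/clip step
torch.nn.utils.clip_grad_value_(net_d.parameters(), 100)
scaler.step(optim_d)

with autocast(enabled=fp16_run):
    # 2) generator step: the just-updated discriminator is reused, this time without detach
    loss_gen_all = ((net_d(y_hat) - 1) ** 2).mean()
optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
scaler.unscale_(optim_g)
torch.nn.utils.clip_grad_value_(net_g.parameters(), 100)
scaler.step(optim_g)
scaler.update()                                    # one scaler update per iteration

The detail worth copying is that scaler.update() runs once per iteration, after both optimizer steps, exactly as in the loop above.
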
GPT_SoVITS/s2_train_v3.py ADDED
@@ -0,0 +1,467 @@
1
+ import warnings
2
+
3
+ warnings.filterwarnings("ignore")
4
+ import os
5
+
6
+ import utils
7
+
8
+ hps = utils.get_hparams(stage=2)
9
+ os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
10
+ import logging
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.multiprocessing as mp
15
+ from torch.cuda.amp import GradScaler, autocast
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from torch.utils.data import DataLoader
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ from tqdm import tqdm
20
+
21
+ logging.getLogger("matplotlib").setLevel(logging.INFO)
22
+ logging.getLogger("h5py").setLevel(logging.INFO)
23
+ logging.getLogger("numba").setLevel(logging.INFO)
24
+ from random import randint
25
+
26
+ from module import commons
27
+ from module.data_utils import (
28
+ DistributedBucketSampler,
29
+ )
30
+ from module.data_utils import (
31
+ TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
32
+ )
33
+ from module.data_utils import (
34
+ TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
35
+ )
36
+ from module.models import (
37
+ SynthesizerTrnV3 as SynthesizerTrn,
38
+ )
39
+ from process_ckpt import savee
40
+
41
+ torch.backends.cudnn.benchmark = False
42
+ torch.backends.cudnn.deterministic = False
43
+ ### fp32 is already fast on A100 anyway, so enable tf32 and see
44
+ torch.backends.cuda.matmul.allow_tf32 = True
45
+ torch.backends.cudnn.allow_tf32 = True
46
+ torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally), with no effect on the results
47
+ # from config import pretrained_s2G,pretrained_s2D
48
+ global_step = 0
49
+
50
+ device = "cpu" # cuda以外的设备,等mps优化后加入
51
+
52
+
53
+ def main():
54
+ if torch.cuda.is_available():
55
+ n_gpus = torch.cuda.device_count()
56
+ else:
57
+ n_gpus = 1
58
+ os.environ["MASTER_ADDR"] = "localhost"
59
+ os.environ["MASTER_PORT"] = str(randint(20000, 55555))
60
+
61
+ mp.spawn(
62
+ run,
63
+ nprocs=n_gpus,
64
+ args=(
65
+ n_gpus,
66
+ hps,
67
+ ),
68
+ )
69
+
70
+
71
+ def run(rank, n_gpus, hps):
72
+ global global_step
73
+ if rank == 0:
74
+ logger = utils.get_logger(hps.data.exp_dir)
75
+ logger.info(hps)
76
+ # utils.check_git_hash(hps.s2_ckpt_dir)
77
+ writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
78
+ writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
79
+
80
+ dist.init_process_group(
81
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
82
+ init_method="env://?use_libuv=False",
83
+ world_size=n_gpus,
84
+ rank=rank,
85
+ )
86
+ torch.manual_seed(hps.train.seed)
87
+ if torch.cuda.is_available():
88
+ torch.cuda.set_device(rank)
89
+
90
+ train_dataset = TextAudioSpeakerLoader(hps.data) ########
91
+ train_sampler = DistributedBucketSampler(
92
+ train_dataset,
93
+ hps.train.batch_size,
94
+ [
95
+ 32,
96
+ 300,
97
+ 400,
98
+ 500,
99
+ 600,
100
+ 700,
101
+ 800,
102
+ 900,
103
+ 1000,
104
+ # 1100,
105
+ # 1200,
106
+ # 1300,
107
+ # 1400,
108
+ # 1500,
109
+ # 1600,
110
+ # 1700,
111
+ # 1800,
112
+ # 1900,
113
+ ],
114
+ num_replicas=n_gpus,
115
+ rank=rank,
116
+ shuffle=True,
117
+ )
118
+ collate_fn = TextAudioSpeakerCollate()
119
+ train_loader = DataLoader(
120
+ train_dataset,
121
+ num_workers=6,
122
+ shuffle=False,
123
+ pin_memory=True,
124
+ collate_fn=collate_fn,
125
+ batch_sampler=train_sampler,
126
+ persistent_workers=True,
127
+ prefetch_factor=4,
128
+ )
129
+ # if rank == 0:
130
+ # eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
131
+ # eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
132
+ # batch_size=1, pin_memory=True,
133
+ # drop_last=False, collate_fn=collate_fn)
134
+
135
+ net_g = (
136
+ SynthesizerTrn(
137
+ hps.data.filter_length // 2 + 1,
138
+ hps.train.segment_size // hps.data.hop_length,
139
+ n_speakers=hps.data.n_speakers,
140
+ **hps.model,
141
+ ).cuda(rank)
142
+ if torch.cuda.is_available()
143
+ else SynthesizerTrn(
144
+ hps.data.filter_length // 2 + 1,
145
+ hps.train.segment_size // hps.data.hop_length,
146
+ n_speakers=hps.data.n_speakers,
147
+ **hps.model,
148
+ ).to(device)
149
+ )
150
+
151
+ # net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
152
+ # for name, param in net_g.named_parameters():
153
+ # if not param.requires_grad:
154
+ # print(name, "not requires_grad")
155
+
156
+ optim_g = torch.optim.AdamW(
157
+ filter(lambda p: p.requires_grad, net_g.parameters()),  ### default: all layers share the same lr
158
+ hps.train.learning_rate,
159
+ betas=hps.train.betas,
160
+ eps=hps.train.eps,
161
+ )
162
+ # optim_d = torch.optim.AdamW(
163
+ # net_d.parameters(),
164
+ # hps.train.learning_rate,
165
+ # betas=hps.train.betas,
166
+ # eps=hps.train.eps,
167
+ # )
168
+ if torch.cuda.is_available():
169
+ net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
170
+ # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
171
+ else:
172
+ net_g = net_g.to(device)
173
+ # net_d = net_d.to(device)
174
+
175
+ try:  # auto-resume if a checkpoint can be loaded
176
+ # _, _, _, epoch_str = utils.load_checkpoint(
177
+ # utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
178
+ # net_d,
179
+ # optim_d,
180
+ # ) # D多半加载没事
181
+ # if rank == 0:
182
+ # logger.info("loaded D")
183
+ # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
184
+ _, _, _, epoch_str = utils.load_checkpoint(
185
+ utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
186
+ net_g,
187
+ optim_g,
188
+ )
189
+ epoch_str += 1
190
+ global_step = (epoch_str - 1) * len(train_loader)
191
+ # epoch_str = 1
192
+ # global_step = 0
193
+ except:  # nothing to resume on the first run, so fall back to the pretrained weights
194
+ # traceback.print_exc()
195
+ epoch_str = 1
196
+ global_step = 0
197
+ if (
198
+ hps.train.pretrained_s2G != ""
199
+ and hps.train.pretrained_s2G is not None
200
+ and os.path.exists(hps.train.pretrained_s2G)
201
+ ):
202
+ if rank == 0:
203
+ logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
204
+ print(
205
+ "loaded pretrained %s" % hps.train.pretrained_s2G,
206
+ net_g.module.load_state_dict(
207
+ torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
208
+ strict=False,
209
+ )
210
+ if torch.cuda.is_available()
211
+ else net_g.load_state_dict(
212
+ torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
213
+ strict=False,
214
+ ),
215
+ )  ## experiment: do not load the optimizer
216
+ # if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
217
+ # if rank == 0:
218
+ # logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
219
+ # print(
220
+ # net_d.module.load_state_dict(
221
+ # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
222
+ # ) if torch.cuda.is_available() else net_d.load_state_dict(
223
+ # torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
224
+ # )
225
+ # )
226
+
227
+ # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
228
+ # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
229
+
230
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
231
+ # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
232
+ # optim_d, gamma=hps.train.lr_decay, last_epoch=-1
233
+ # )
234
+ for _ in range(epoch_str):
235
+ scheduler_g.step()
236
+ # scheduler_d.step()
237
+
238
+ scaler = GradScaler(enabled=hps.train.fp16_run)
239
+
240
+ net_d = optim_d = scheduler_d = None
241
+ print("start training from epoch %s" % epoch_str)
242
+ for epoch in range(epoch_str, hps.train.epochs + 1):
243
+ if rank == 0:
244
+ train_and_evaluate(
245
+ rank,
246
+ epoch,
247
+ hps,
248
+ [net_g, net_d],
249
+ [optim_g, optim_d],
250
+ [scheduler_g, scheduler_d],
251
+ scaler,
252
+ # [train_loader, eval_loader], logger, [writer, writer_eval])
253
+ [train_loader, None],
254
+ logger,
255
+ [writer, writer_eval],
256
+ )
257
+ else:
258
+ train_and_evaluate(
259
+ rank,
260
+ epoch,
261
+ hps,
262
+ [net_g, net_d],
263
+ [optim_g, optim_d],
264
+ [scheduler_g, scheduler_d],
265
+ scaler,
266
+ [train_loader, None],
267
+ None,
268
+ None,
269
+ )
270
+ scheduler_g.step()
271
+ # scheduler_d.step()
272
+ print("training done")
273
+
274
+
275
+ def train_and_evaluate(
276
+ rank,
277
+ epoch,
278
+ hps,
279
+ nets,
280
+ optims,
281
+ schedulers,
282
+ scaler,
283
+ loaders,
284
+ logger,
285
+ writers,
286
+ ):
287
+ net_g, net_d = nets
288
+ optim_g, optim_d = optims
289
+ # scheduler_g, scheduler_d = schedulers
290
+ train_loader, eval_loader = loaders
291
+ if writers is not None:
292
+ writer, writer_eval = writers
293
+
294
+ train_loader.batch_sampler.set_epoch(epoch)
295
+ global global_step
296
+
297
+ net_g.train()
298
+ # net_d.train()
299
+ # for batch_idx, (
300
+ # ssl,
301
+ # ssl_lengths,
302
+ # spec,
303
+ # spec_lengths,
304
+ # y,
305
+ # y_lengths,
306
+ # text,
307
+ # text_lengths,
308
+ # ) in enumerate(tqdm(train_loader)):
309
+ for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
310
+ tqdm(train_loader)
311
+ ):
312
+ if torch.cuda.is_available():
313
+ spec, spec_lengths = (
314
+ spec.cuda(
315
+ rank,
316
+ non_blocking=True,
317
+ ),
318
+ spec_lengths.cuda(
319
+ rank,
320
+ non_blocking=True,
321
+ ),
322
+ )
323
+ mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
324
+ ssl = ssl.cuda(rank, non_blocking=True)
325
+ ssl.requires_grad = False
326
+ # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
327
+ text, text_lengths = (
328
+ text.cuda(
329
+ rank,
330
+ non_blocking=True,
331
+ ),
332
+ text_lengths.cuda(
333
+ rank,
334
+ non_blocking=True,
335
+ ),
336
+ )
337
+ else:
338
+ spec, spec_lengths = spec.to(device), spec_lengths.to(device)
339
+ mel, mel_lengths = mel.to(device), mel_lengths.to(device)
340
+ ssl = ssl.to(device)
341
+ ssl.requires_grad = False
342
+ # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
343
+ text, text_lengths = text.to(device), text_lengths.to(device)
344
+
345
+ with autocast(enabled=hps.train.fp16_run):
346
+ cfm_loss = net_g(
347
+ ssl,
348
+ spec,
349
+ mel,
350
+ ssl_lengths,
351
+ spec_lengths,
352
+ text,
353
+ text_lengths,
354
+ mel_lengths,
355
+ use_grad_ckpt=hps.train.grad_ckpt,
356
+ )
357
+ loss_gen_all = cfm_loss
358
+ optim_g.zero_grad()
359
+ scaler.scale(loss_gen_all).backward()
360
+ scaler.unscale_(optim_g)
361
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
362
+ scaler.step(optim_g)
363
+ scaler.update()
364
+
365
+ if rank == 0:
366
+ if global_step % hps.train.log_interval == 0:
367
+ lr = optim_g.param_groups[0]["lr"]
368
+ # losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
369
+ losses = [cfm_loss]
370
+ logger.info(
371
+ "Train Epoch: {} [{:.0f}%]".format(
372
+ epoch,
373
+ 100.0 * batch_idx / len(train_loader),
374
+ )
375
+ )
376
+ logger.info([x.item() for x in losses] + [global_step, lr])
377
+
378
+ scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
379
+ # image_dict = {
380
+ # "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
381
+ # "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
382
+ # "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
383
+ # "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
384
+ # }
385
+ utils.summarize(
386
+ writer=writer,
387
+ global_step=global_step,
388
+ # images=image_dict,
389
+ scalars=scalar_dict,
390
+ )
391
+
392
+ # if global_step % hps.train.eval_interval == 0:
393
+ # # evaluate(hps, net_g, eval_loader, writer_eval)
394
+ # utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "G_{}.pth".format(global_step)),scaler)
395
+ # # utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "D_{}.pth".format(global_step)),scaler)
396
+ # # keep_ckpts = getattr(hps.train, 'keep_ckpts', 3)
397
+ # # if keep_ckpts > 0:
398
+ # # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
399
+
400
+ global_step += 1
401
+ if epoch % hps.train.save_every_epoch == 0 and rank == 0:
402
+ if hps.train.if_save_latest == 0:
403
+ utils.save_checkpoint(
404
+ net_g,
405
+ optim_g,
406
+ hps.train.learning_rate,
407
+ epoch,
408
+ os.path.join(
409
+ "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
410
+ "G_{}.pth".format(global_step),
411
+ ),
412
+ )
413
+ # utils.save_checkpoint(
414
+ # net_d,
415
+ # optim_d,
416
+ # hps.train.learning_rate,
417
+ # epoch,
418
+ # os.path.join(
419
+ # "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
420
+ # ),
421
+ # )
422
+ else:
423
+ utils.save_checkpoint(
424
+ net_g,
425
+ optim_g,
426
+ hps.train.learning_rate,
427
+ epoch,
428
+ os.path.join(
429
+ "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
430
+ "G_{}.pth".format(233333333333),
431
+ ),
432
+ )
433
+ # utils.save_checkpoint(
434
+ # net_d,
435
+ # optim_d,
436
+ # hps.train.learning_rate,
437
+ # epoch,
438
+ # os.path.join(
439
+ # "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
440
+ # ),
441
+ # )
442
+ if rank == 0 and hps.train.if_save_every_weights == True:
443
+ if hasattr(net_g, "module"):
444
+ ckpt = net_g.module.state_dict()
445
+ else:
446
+ ckpt = net_g.state_dict()
447
+ logger.info(
448
+ "saving ckpt %s_e%s:%s"
449
+ % (
450
+ hps.name,
451
+ epoch,
452
+ savee(
453
+ ckpt,
454
+ hps.name + "_e%s_s%s" % (epoch, global_step),
455
+ epoch,
456
+ global_step,
457
+ hps,
458
+ ),
459
+ )
460
+ )
461
+
462
+ if rank == 0:
463
+ logger.info("====> Epoch: {}".format(epoch))
464
+
465
+
466
+ if __name__ == "__main__":
467
+ main()
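
Note on s2_train_v3.py: the discriminator path is commented out and the model's forward returns a single cfm_loss, so only optim_g is stepped. The resume-or-pretrain flow in the try/except is shared by all three training scripts; a rough, self-contained sketch of it (toy model, hypothetical paths, illustrative gamma in place of hps.train.lr_decay) is:

import glob
import os
import torch

def latest_g(ckpt_dir):
    # stand-in for utils.latest_checkpoint_path: newest G_*.pth, or None
    files = glob.glob(os.path.join(ckpt_dir, "G_*.pth"))
    return max(files, key=os.path.getmtime) if files else None

model = torch.nn.Linear(4, 4)                   # placeholder for SynthesizerTrnV3
optim = torch.optim.AdamW(model.parameters(), 1e-4)
ckpt_dir = "logs_s2_v3"                         # e.g. "%s/logs_s2_%s" % (exp_dir, version)
pretrained_s2G = "pretrained/s2G_v3.pth"        # hypothetical path

resume = latest_g(ckpt_dir)
if resume is not None:                          # auto-resume branch
    state = torch.load(resume, map_location="cpu")
    model.load_state_dict(state["model"])
    optim.load_state_dict(state["optimizer"])
    epoch_str = state["iteration"] + 1          # utils.load_checkpoint returns this iteration
else:                                           # first run: start from the pretrained s2G
    epoch_str = 1
    if pretrained_s2G and os.path.exists(pretrained_s2G):
        # strict=False tolerates missing/extra keys when only part of the model is reused
        model.load_state_dict(
            torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
        )

# fast-forward the scheduler so the lr matches the resumed epoch, as the scripts do
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999875, last_epoch=-1)
for _ in range(epoch_str):
    scheduler.step()
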
GPT_SoVITS/s2_train_v3_lora.py ADDED
@@ -0,0 +1,379 @@
1
+ import warnings
2
+
3
+ warnings.filterwarnings("ignore")
4
+ import os
5
+
6
+ import utils
7
+
8
+ hps = utils.get_hparams(stage=2)
9
+ os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
10
+ import logging
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.multiprocessing as mp
15
+ from torch.cuda.amp import GradScaler, autocast
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from torch.utils.data import DataLoader
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ from tqdm import tqdm
20
+
21
+ logging.getLogger("matplotlib").setLevel(logging.INFO)
22
+ logging.getLogger("h5py").setLevel(logging.INFO)
23
+ logging.getLogger("numba").setLevel(logging.INFO)
24
+ from collections import OrderedDict as od
25
+ from random import randint
26
+
27
+ from module import commons
28
+ from module.data_utils import (
29
+ DistributedBucketSampler,
30
+ TextAudioSpeakerCollateV3,
31
+ TextAudioSpeakerLoaderV3,
32
+ TextAudioSpeakerCollateV4,
33
+ TextAudioSpeakerLoaderV4,
34
+
35
+ )
36
+ from module.models import (
37
+ SynthesizerTrnV3 as SynthesizerTrn,
38
+ )
39
+ from peft import LoraConfig, get_peft_model
40
+ from process_ckpt import savee
41
+
42
+ torch.backends.cudnn.benchmark = False
43
+ torch.backends.cudnn.deterministic = False
44
+ ### fp32 is already fast on A100 anyway, so enable tf32 and see
45
+ torch.backends.cuda.matmul.allow_tf32 = True
46
+ torch.backends.cudnn.allow_tf32 = True
47
+ torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally), with no effect on the results
48
+ # from config import pretrained_s2G,pretrained_s2D
49
+ global_step = 0
50
+
51
+ device = "cpu" # cuda以外的设备,等mps优化后加入
52
+
53
+
54
+ def main():
55
+ if torch.cuda.is_available():
56
+ n_gpus = torch.cuda.device_count()
57
+ else:
58
+ n_gpus = 1
59
+ os.environ["MASTER_ADDR"] = "localhost"
60
+ os.environ["MASTER_PORT"] = str(randint(20000, 55555))
61
+
62
+ mp.spawn(
63
+ run,
64
+ nprocs=n_gpus,
65
+ args=(
66
+ n_gpus,
67
+ hps,
68
+ ),
69
+ )
70
+
71
+
72
+ def run(rank, n_gpus, hps):
73
+ global global_step, no_grad_names, save_root, lora_rank
74
+ if rank == 0:
75
+ logger = utils.get_logger(hps.data.exp_dir)
76
+ logger.info(hps)
77
+ # utils.check_git_hash(hps.s2_ckpt_dir)
78
+ writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
79
+ writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
80
+
81
+ dist.init_process_group(
82
+ backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
83
+ init_method="env://?use_libuv=False",
84
+ world_size=n_gpus,
85
+ rank=rank,
86
+ )
87
+ torch.manual_seed(hps.train.seed)
88
+ if torch.cuda.is_available():
89
+ torch.cuda.set_device(rank)
90
+
91
+ TextAudioSpeakerLoader = TextAudioSpeakerLoaderV3 if hps.model.version == "v3" else TextAudioSpeakerLoaderV4
92
+ TextAudioSpeakerCollate = TextAudioSpeakerCollateV3 if hps.model.version == "v3" else TextAudioSpeakerCollateV4
93
+ train_dataset = TextAudioSpeakerLoader(hps.data) ########
94
+ train_sampler = DistributedBucketSampler(
95
+ train_dataset,
96
+ hps.train.batch_size,
97
+ [
98
+ 32,
99
+ 300,
100
+ 400,
101
+ 500,
102
+ 600,
103
+ 700,
104
+ 800,
105
+ 900,
106
+ 1000,
107
+ # 1100,
108
+ # 1200,
109
+ # 1300,
110
+ # 1400,
111
+ # 1500,
112
+ # 1600,
113
+ # 1700,
114
+ # 1800,
115
+ # 1900,
116
+ ],
117
+ num_replicas=n_gpus,
118
+ rank=rank,
119
+ shuffle=True,
120
+ )
121
+ collate_fn = TextAudioSpeakerCollate()
122
+ train_loader = DataLoader(
123
+ train_dataset,
124
+ num_workers=6,
125
+ shuffle=False,
126
+ pin_memory=True,
127
+ collate_fn=collate_fn,
128
+ batch_sampler=train_sampler,
129
+ persistent_workers=True,
130
+ prefetch_factor=4,
131
+ )
132
+ save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
133
+ os.makedirs(save_root, exist_ok=True)
134
+ lora_rank = int(hps.train.lora_rank)
135
+ lora_config = LoraConfig(
136
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
137
+ r=lora_rank,
138
+ lora_alpha=lora_rank,
139
+ init_lora_weights=True,
140
+ )
141
+
142
+ def get_model(hps):
143
+ return SynthesizerTrn(
144
+ hps.data.filter_length // 2 + 1,
145
+ hps.train.segment_size // hps.data.hop_length,
146
+ n_speakers=hps.data.n_speakers,
147
+ **hps.model,
148
+ )
149
+
150
+ def get_optim(net_g):
151
+ return torch.optim.AdamW(
152
+ filter(lambda p: p.requires_grad, net_g.parameters()),  ### default: all layers share the same lr
153
+ hps.train.learning_rate,
154
+ betas=hps.train.betas,
155
+ eps=hps.train.eps,
156
+ )
157
+
158
+ def model2cuda(net_g, rank):
159
+ if torch.cuda.is_available():
160
+ net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
161
+ else:
162
+ net_g = net_g.to(device)
163
+ return net_g
164
+
165
+ try:  # auto-resume if a checkpoint can be loaded
166
+ net_g = get_model(hps)
167
+ net_g.cfm = get_peft_model(net_g.cfm, lora_config)
168
+ net_g = model2cuda(net_g, rank)
169
+ optim_g = get_optim(net_g)
170
+ # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
171
+ _, _, _, epoch_str = utils.load_checkpoint(
172
+ utils.latest_checkpoint_path(save_root, "G_*.pth"),
173
+ net_g,
174
+ optim_g,
175
+ )
176
+ epoch_str += 1
177
+ global_step = (epoch_str - 1) * len(train_loader)
178
+ except:  # nothing to resume on the first run, so fall back to the pretrained weights
179
+ # traceback.print_exc()
180
+ epoch_str = 1
181
+ global_step = 0
182
+ net_g = get_model(hps)
183
+ if (
184
+ hps.train.pretrained_s2G != ""
185
+ and hps.train.pretrained_s2G is not None
186
+ and os.path.exists(hps.train.pretrained_s2G)
187
+ ):
188
+ if rank == 0:
189
+ logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
190
+ print(
191
+ "loaded pretrained %s" % hps.train.pretrained_s2G,
192
+ net_g.load_state_dict(
193
+ torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
194
+ strict=False,
195
+ ),
196
+ )
197
+ net_g.cfm = get_peft_model(net_g.cfm, lora_config)
198
+ net_g = model2cuda(net_g, rank)
199
+ optim_g = get_optim(net_g)
200
+
201
+ no_grad_names = set()
202
+ for name, param in net_g.named_parameters():
203
+ if not param.requires_grad:
204
+ no_grad_names.add(name.replace("module.", ""))
205
+ # print(name, "not requires_grad")
206
+ # print(no_grad_names)
207
+ # os._exit(233333)
208
+
209
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
210
+ for _ in range(epoch_str):
211
+ scheduler_g.step()
212
+
213
+ scaler = GradScaler(enabled=hps.train.fp16_run)
214
+
215
+ net_d = optim_d = scheduler_d = None
216
+ print("start training from epoch %s" % epoch_str)
217
+ for epoch in range(epoch_str, hps.train.epochs + 1):
218
+ if rank == 0:
219
+ train_and_evaluate(
220
+ rank,
221
+ epoch,
222
+ hps,
223
+ [net_g, net_d],
224
+ [optim_g, optim_d],
225
+ [scheduler_g, scheduler_d],
226
+ scaler,
227
+ # [train_loader, eval_loader], logger, [writer, writer_eval])
228
+ [train_loader, None],
229
+ logger,
230
+ [writer, writer_eval],
231
+ )
232
+ else:
233
+ train_and_evaluate(
234
+ rank,
235
+ epoch,
236
+ hps,
237
+ [net_g, net_d],
238
+ [optim_g, optim_d],
239
+ [scheduler_g, scheduler_d],
240
+ scaler,
241
+ [train_loader, None],
242
+ None,
243
+ None,
244
+ )
245
+ scheduler_g.step()
246
+ print("training done")
247
+
248
+
249
+ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
250
+ net_g, net_d = nets
251
+ optim_g, optim_d = optims
252
+ # scheduler_g, scheduler_d = schedulers
253
+ train_loader, eval_loader = loaders
254
+ if writers is not None:
255
+ writer, writer_eval = writers
256
+
257
+ train_loader.batch_sampler.set_epoch(epoch)
258
+ global global_step
259
+
260
+ net_g.train()
261
+ for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
262
+ tqdm(train_loader)
263
+ ):
264
+ if torch.cuda.is_available():
265
+ spec, spec_lengths = (
266
+ spec.cuda(
267
+ rank,
268
+ non_blocking=True,
269
+ ),
270
+ spec_lengths.cuda(
271
+ rank,
272
+ non_blocking=True,
273
+ ),
274
+ )
275
+ mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
276
+ ssl = ssl.cuda(rank, non_blocking=True)
277
+ ssl.requires_grad = False
278
+ text, text_lengths = (
279
+ text.cuda(
280
+ rank,
281
+ non_blocking=True,
282
+ ),
283
+ text_lengths.cuda(
284
+ rank,
285
+ non_blocking=True,
286
+ ),
287
+ )
288
+ else:
289
+ spec, spec_lengths = spec.to(device), spec_lengths.to(device)
290
+ mel, mel_lengths = mel.to(device), mel_lengths.to(device)
291
+ ssl = ssl.to(device)
292
+ ssl.requires_grad = False
293
+ text, text_lengths = text.to(device), text_lengths.to(device)
294
+
295
+ with autocast(enabled=hps.train.fp16_run):
296
+ cfm_loss = net_g(
297
+ ssl,
298
+ spec,
299
+ mel,
300
+ ssl_lengths,
301
+ spec_lengths,
302
+ text,
303
+ text_lengths,
304
+ mel_lengths,
305
+ use_grad_ckpt=hps.train.grad_ckpt,
306
+ )
307
+ loss_gen_all = cfm_loss
308
+ optim_g.zero_grad()
309
+ scaler.scale(loss_gen_all).backward()
310
+ scaler.unscale_(optim_g)
311
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
312
+ scaler.step(optim_g)
313
+ scaler.update()
314
+
315
+ if rank == 0:
316
+ if global_step % hps.train.log_interval == 0:
317
+ lr = optim_g.param_groups[0]["lr"]
318
+ losses = [cfm_loss]
319
+ logger.info("Train Epoch: {} [{:.0f}%]".format(epoch, 100.0 * batch_idx / len(train_loader)))
320
+ logger.info([x.item() for x in losses] + [global_step, lr])
321
+
322
+ scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
323
+ utils.summarize(
324
+ writer=writer,
325
+ global_step=global_step,
326
+ scalars=scalar_dict,
327
+ )
328
+
329
+ global_step += 1
330
+ if epoch % hps.train.save_every_epoch == 0 and rank == 0:
331
+ if hps.train.if_save_latest == 0:
332
+ utils.save_checkpoint(
333
+ net_g,
334
+ optim_g,
335
+ hps.train.learning_rate,
336
+ epoch,
337
+ os.path.join(save_root, "G_{}.pth".format(global_step)),
338
+ )
339
+ else:
340
+ utils.save_checkpoint(
341
+ net_g,
342
+ optim_g,
343
+ hps.train.learning_rate,
344
+ epoch,
345
+ os.path.join(save_root, "G_{}.pth".format(233333333333)),
346
+ )
347
+ if rank == 0 and hps.train.if_save_every_weights == True:
348
+ if hasattr(net_g, "module"):
349
+ ckpt = net_g.module.state_dict()
350
+ else:
351
+ ckpt = net_g.state_dict()
352
+ sim_ckpt = od()
353
+ for key in ckpt:
354
+ # if "cfm"not in key:
355
+ # print(key)
356
+ if key not in no_grad_names:
357
+ sim_ckpt[key] = ckpt[key].half().cpu()
358
+ logger.info(
359
+ "saving ckpt %s_e%s:%s"
360
+ % (
361
+ hps.name,
362
+ epoch,
363
+ savee(
364
+ sim_ckpt,
365
+ hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
366
+ epoch,
367
+ global_step,
368
+ hps, cfm_version=hps.model.version,
369
+ lora_rank=lora_rank,
370
+ ),
371
+ )
372
+ )
373
+
374
+ if rank == 0:
375
+ logger.info("====> Epoch: {}".format(epoch))
376
+
377
+
378
+ if __name__ == "__main__":
379
+ main()
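
Note on s2_train_v3_lora.py: LoRA adapters are injected into the CFM attention projections (to_q/to_k/to_v/to_out.0) with peft, the frozen base weights are remembered via no_grad_names, and only the trainable tensors are exported in fp16 through savee. A minimal sketch of that wrapping and export, using a toy attention module as a stand-in for net_g.cfm (LoraConfig/get_peft_model are the same peft calls used above), is:

from collections import OrderedDict
import torch
from peft import LoraConfig, get_peft_model

class TinyAttention(torch.nn.Module):
    # placeholder module whose submodule names match the target_modules above
    def __init__(self, dim=16):
        super().__init__()
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim, dim)])

lora_rank = 32
cfm = get_peft_model(
    TinyAttention(),
    LoraConfig(
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        r=lora_rank,
        lora_alpha=lora_rank,
        init_lora_weights=True,
    ),
)

# peft freezes the base weights; remember their names, as the script does
no_grad_names = {n for n, p in cfm.named_parameters() if not p.requires_grad}

# export only what was trained (the LoRA adapters), cast to fp16 on CPU
sim_ckpt = OrderedDict(
    (k, v.half().cpu()) for k, v in cfm.state_dict().items() if k not in no_grad_names
)
print(len(no_grad_names), "frozen tensors;", len(sim_ckpt), "saved tensors")
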
GPT_SoVITS/utils.py ADDED
@@ -0,0 +1,361 @@
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import traceback
9
+
10
+ import librosa
11
+ import numpy as np
12
+ import torch
13
+
14
+ logging.getLogger("numba").setLevel(logging.ERROR)
15
+ logging.getLogger("matplotlib").setLevel(logging.ERROR)
16
+
17
+ MATPLOTLIB_FLAG = False
18
+
19
+ logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
20
+ logger = logging
21
+
22
+
23
+ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
24
+ assert os.path.isfile(checkpoint_path)
25
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
26
+ iteration = checkpoint_dict["iteration"]
27
+ learning_rate = checkpoint_dict["learning_rate"]
28
+ if optimizer is not None and not skip_optimizer and checkpoint_dict["optimizer"] is not None:
29
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
30
+ saved_state_dict = checkpoint_dict["model"]
31
+ if hasattr(model, "module"):
32
+ state_dict = model.module.state_dict()
33
+ else:
34
+ state_dict = model.state_dict()
35
+ new_state_dict = {}
36
+ for k, v in state_dict.items():
37
+ try:
38
+ # assert "quantizer" not in k
39
+ # print("load", k)
40
+ new_state_dict[k] = saved_state_dict[k]
41
+ assert saved_state_dict[k].shape == v.shape, (
42
+ saved_state_dict[k].shape,
43
+ v.shape,
44
+ )
45
+ except:
46
+ traceback.print_exc()
47
+ print("error, %s is not in the checkpoint" % k) # shape不对也会,比如text_embedding当cleaner修改时
48
+ new_state_dict[k] = v
49
+ if hasattr(model, "module"):
50
+ model.module.load_state_dict(new_state_dict)
51
+ else:
52
+ model.load_state_dict(new_state_dict)
53
+ print("load ")
54
+ logger.info(
55
+ "Loaded checkpoint '{}' (iteration {})".format(
56
+ checkpoint_path,
57
+ iteration,
58
+ )
59
+ )
60
+ return model, optimizer, learning_rate, iteration
61
+
62
+
63
+ import shutil
64
+ from time import time as ttime
65
+
66
+
67
+ def my_save(fea, path):  ##### workaround: torch.save can fail on paths containing Chinese characters
68
+ dir = os.path.dirname(path)
69
+ name = os.path.basename(path)
70
+ tmp_path = "%s.pth" % (ttime())
71
+ torch.save(fea, tmp_path)
72
+ shutil.move(tmp_path, "%s/%s" % (dir, name))
73
+
74
+
75
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
76
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(iteration, checkpoint_path))
77
+ if hasattr(model, "module"):
78
+ state_dict = model.module.state_dict()
79
+ else:
80
+ state_dict = model.state_dict()
81
+ # torch.save(
82
+ my_save(
83
+ {
84
+ "model": state_dict,
85
+ "iteration": iteration,
86
+ "optimizer": optimizer.state_dict(),
87
+ "learning_rate": learning_rate,
88
+ },
89
+ checkpoint_path,
90
+ )
91
+
92
+
93
+ def summarize(
94
+ writer,
95
+ global_step,
96
+ scalars={},
97
+ histograms={},
98
+ images={},
99
+ audios={},
100
+ audio_sampling_rate=22050,
101
+ ):
102
+ for k, v in scalars.items():
103
+ writer.add_scalar(k, v, global_step)
104
+ for k, v in histograms.items():
105
+ writer.add_histogram(k, v, global_step)
106
+ for k, v in images.items():
107
+ writer.add_image(k, v, global_step, dataformats="HWC")
108
+ for k, v in audios.items():
109
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
110
+
111
+
112
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
113
+ f_list = glob.glob(os.path.join(dir_path, regex))
114
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
115
+ x = f_list[-1]
116
+ print(x)
117
+ return x
118
+
119
+
120
+ def plot_spectrogram_to_numpy(spectrogram):
121
+ global MATPLOTLIB_FLAG
122
+ if not MATPLOTLIB_FLAG:
123
+ import matplotlib
124
+
125
+ matplotlib.use("Agg")
126
+ MATPLOTLIB_FLAG = True
127
+ mpl_logger = logging.getLogger("matplotlib")
128
+ mpl_logger.setLevel(logging.WARNING)
129
+ import matplotlib.pylab as plt
130
+
131
+ fig, ax = plt.subplots(figsize=(10, 2))
132
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
133
+ plt.colorbar(im, ax=ax)
134
+ plt.xlabel("Frames")
135
+ plt.ylabel("Channels")
136
+ plt.tight_layout()
137
+
138
+ fig.canvas.draw()
139
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
140
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
141
+ plt.close()
142
+ return data
143
+
144
+
145
+ def plot_alignment_to_numpy(alignment, info=None):
146
+ global MATPLOTLIB_FLAG
147
+ if not MATPLOTLIB_FLAG:
148
+ import matplotlib
149
+
150
+ matplotlib.use("Agg")
151
+ MATPLOTLIB_FLAG = True
152
+ mpl_logger = logging.getLogger("matplotlib")
153
+ mpl_logger.setLevel(logging.WARNING)
154
+ import matplotlib.pylab as plt
155
+
156
+ fig, ax = plt.subplots(figsize=(6, 4))
157
+ im = ax.imshow(
158
+ alignment.transpose(),
159
+ aspect="auto",
160
+ origin="lower",
161
+ interpolation="none",
162
+ )
163
+ fig.colorbar(im, ax=ax)
164
+ xlabel = "Decoder timestep"
165
+ if info is not None:
166
+ xlabel += "\n\n" + info
167
+ plt.xlabel(xlabel)
168
+ plt.ylabel("Encoder timestep")
169
+ plt.tight_layout()
170
+
171
+ fig.canvas.draw()
172
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
173
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
174
+ plt.close()
175
+ return data
176
+
177
+
178
+ def load_wav_to_torch(full_path):
179
+ data, sampling_rate = librosa.load(full_path, sr=None)
180
+ return torch.FloatTensor(data), sampling_rate
181
+
182
+
183
+ def load_filepaths_and_text(filename, split="|"):
184
+ with open(filename, encoding="utf-8") as f:
185
+ filepaths_and_text = [line.strip().split(split) for line in f]
186
+ return filepaths_and_text
187
+
188
+
189
+ def get_hparams(init=True, stage=1):
190
+ parser = argparse.ArgumentParser()
191
+ parser.add_argument(
192
+ "-c",
193
+ "--config",
194
+ type=str,
195
+ default="./configs/s2.json",
196
+ help="JSON file for configuration",
197
+ )
198
+ parser.add_argument("-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir")
199
+ parser.add_argument(
200
+ "-rs",
201
+ "--resume_step",
202
+ type=int,
203
+ required=False,
204
+ default=None,
205
+ help="resume step",
206
+ )
207
+ # parser.add_argument('-e', '--exp_dir', type=str, required=False,default=None,help='experiment directory')
208
+ # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits gererator weights')
209
+ # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights')
210
+
211
+ args = parser.parse_args()
212
+
213
+ config_path = args.config
214
+ with open(config_path, "r") as f:
215
+ data = f.read()
216
+ config = json.loads(data)
217
+
218
+ hparams = HParams(**config)
219
+ hparams.pretrain = args.pretrain
220
+ hparams.resume_step = args.resume_step
221
+ # hparams.data.exp_dir = args.exp_dir
222
+ if stage == 1:
223
+ model_dir = hparams.s1_ckpt_dir
224
+ else:
225
+ model_dir = hparams.s2_ckpt_dir
226
+ config_save_path = os.path.join(model_dir, "config.json")
227
+
228
+ if not os.path.exists(model_dir):
229
+ os.makedirs(model_dir)
230
+
231
+ with open(config_save_path, "w") as f:
232
+ f.write(data)
233
+ return hparams
234
+
235
+
236
+ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
237
+ """Freeing up space by deleting saved ckpts
238
+
239
+ Arguments:
240
+ path_to_models -- Path to the model directory
241
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
242
+ sort_by_time -- True -> chronologically delete ckpts
243
+ False -> lexicographically delete ckpts
244
+ """
245
+ import re
246
+
247
+ ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
248
+ name_key = lambda _f: int(re.compile(r"._(\d+)\.pth").match(_f).group(1))
249
+ time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
250
+ sort_key = time_key if sort_by_time else name_key
251
+ x_sorted = lambda _x: sorted(
252
+ [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
253
+ key=sort_key,
254
+ )
255
+ to_del = [
256
+ os.path.join(path_to_models, fn) for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
257
+ ]
258
+ del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
259
+ del_routine = lambda x: [os.remove(x), del_info(x)]
260
+ rs = [del_routine(fn) for fn in to_del]
261
+
262
+
263
+ def get_hparams_from_dir(model_dir):
264
+ config_save_path = os.path.join(model_dir, "config.json")
265
+ with open(config_save_path, "r") as f:
266
+ data = f.read()
267
+ config = json.loads(data)
268
+
269
+ hparams = HParams(**config)
270
+ hparams.model_dir = model_dir
271
+ return hparams
272
+
273
+
274
+ def get_hparams_from_file(config_path):
275
+ with open(config_path, "r") as f:
276
+ data = f.read()
277
+ config = json.loads(data)
278
+
279
+ hparams = HParams(**config)
280
+ return hparams
281
+
282
+
283
+ def check_git_hash(model_dir):
284
+ source_dir = os.path.dirname(os.path.realpath(__file__))
285
+ if not os.path.exists(os.path.join(source_dir, ".git")):
286
+ logger.warning(
287
+ "{} is not a git repository, therefore hash value comparison will be ignored.".format(
288
+ source_dir,
289
+ )
290
+ )
291
+ return
292
+
293
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
294
+
295
+ path = os.path.join(model_dir, "githash")
296
+ if os.path.exists(path):
297
+ saved_hash = open(path).read()
298
+ if saved_hash != cur_hash:
299
+ logger.warning(
300
+ "git hash values are different. {}(saved) != {}(current)".format(
301
+ saved_hash[:8],
302
+ cur_hash[:8],
303
+ )
304
+ )
305
+ else:
306
+ open(path, "w").write(cur_hash)
307
+
308
+
309
+ def get_logger(model_dir, filename="train.log"):
310
+ global logger
311
+ logger = logging.getLogger(os.path.basename(model_dir))
312
+ logger.setLevel(logging.ERROR)
313
+
314
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
315
+ if not os.path.exists(model_dir):
316
+ os.makedirs(model_dir)
317
+ h = logging.FileHandler(os.path.join(model_dir, filename))
318
+ h.setLevel(logging.ERROR)
319
+ h.setFormatter(formatter)
320
+ logger.addHandler(h)
321
+ return logger
322
+
323
+
324
+ class HParams:
325
+ def __init__(self, **kwargs):
326
+ for k, v in kwargs.items():
327
+ if type(v) == dict:
328
+ v = HParams(**v)
329
+ self[k] = v
330
+
331
+ def keys(self):
332
+ return self.__dict__.keys()
333
+
334
+ def items(self):
335
+ return self.__dict__.items()
336
+
337
+ def values(self):
338
+ return self.__dict__.values()
339
+
340
+ def __len__(self):
341
+ return len(self.__dict__)
342
+
343
+ def __getitem__(self, key):
344
+ return getattr(self, key)
345
+
346
+ def __setitem__(self, key, value):
347
+ return setattr(self, key, value)
348
+
349
+ def __contains__(self, key):
350
+ return key in self.__dict__
351
+
352
+ def __repr__(self):
353
+ return self.__dict__.__repr__()
354
+
355
+
356
+ if __name__ == "__main__":
357
+ print(
358
+ load_wav_to_torch(
359
+ "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac",
360
+ )
361
+ )
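
Note on utils.py: HParams is what makes the nested JSON config readable as hps.train.*, hps.data.* and hps.model.* throughout the scripts above. A short usage sketch with illustrative values (not the shipped s2 config; it assumes GPT_SoVITS/ is on sys.path, as when the scripts do import utils):

import json
from utils import HParams   # the class defined above

config_text = """
{
  "train": {"learning_rate": 0.0001, "betas": [0.8, 0.99], "fp16_run": true},
  "data":  {"sampling_rate": 32000, "hop_length": 640},
  "model": {"use_spectral_norm": false}
}
"""

# nested dicts become nested HParams objects with attribute access
hps = HParams(**json.loads(config_text))
print(hps.train.learning_rate)   # 0.0001
print(hps.data.hop_length)       # 640
print("model" in hps)            # True; dict-style membership and .keys()/.items() also work
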