Commit 52e4f53

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitignore +164 -0
- LICENSE +180 -0
- README.md +278 -0
- app.py +378 -0
- configs/sts_finetune_stage1.yaml +273 -0
- configs/sts_finetune_stage2.yaml +273 -0
- evaluation/compute-acc-of-contain.py +85 -0
- evaluation/compute-cer.py +559 -0
- evaluation/compute-wer.py +553 -0
- evaluation/evaluate_asr.py +379 -0
- evaluation/evaluate_libritts.py +384 -0
- evaluation/evaluate_seedtts.py +394 -0
- evaluation/evaluate_sqa.py +451 -0
- evaluation/get_chat_template.py +59 -0
- requirements.txt +1 -0
- requirements_ds_gpu.txt +44 -0
- scripts/deepspeed/ds_config_zero1.json +61 -0
- scripts/deepspeed/ds_config_zero2.json +61 -0
- scripts/deepspeed/ds_config_zero2_no_optimizer.json +52 -0
- scripts/deepspeed/ds_config_zero2_offload.json +61 -0
- scripts/deepspeed/ds_config_zero3.json +63 -0
- scripts/deepspeed/ds_config_zero3_offload.json +75 -0
- scripts/deepspeed/evaluate_sts.sh +348 -0
- scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh +137 -0
- scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh +137 -0
- scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh +137 -0
- scripts/deepspeed/sts_qwen25/finetune_glm4voice_stage1.sh +136 -0
- scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage1.sh +138 -0
- scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage2.sh +138 -0
- scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp1_stage1.sh +138 -0
- scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_stage1.sh +137 -0
- scripts/set_env_ds_gpu.sh +53 -0
- setup.py +12 -0
- third_party/GLM-4-Voice/.gitignore +4 -0
- third_party/GLM-4-Voice/.gitmodules +3 -0
- third_party/GLM-4-Voice/LICENSE +201 -0
- third_party/GLM-4-Voice/README.md +159 -0
- third_party/GLM-4-Voice/README_en.md +148 -0
- third_party/GLM-4-Voice/audio_process.py +93 -0
- third_party/GLM-4-Voice/cosyvoice/__init__.py +0 -0
- third_party/GLM-4-Voice/cosyvoice/bin/inference.py +114 -0
- third_party/GLM-4-Voice/cosyvoice/bin/train.py +140 -0
- third_party/GLM-4-Voice/cosyvoice/cli/__init__.py +0 -0
- third_party/GLM-4-Voice/cosyvoice/cli/cosyvoice.py +83 -0
- third_party/GLM-4-Voice/cosyvoice/cli/frontend.py +168 -0
- third_party/GLM-4-Voice/cosyvoice/cli/model.py +95 -0
- third_party/GLM-4-Voice/cosyvoice/dataset/__init__.py +0 -0
- third_party/GLM-4-Voice/cosyvoice/dataset/dataset.py +160 -0
- third_party/GLM-4-Voice/cosyvoice/dataset/processor.py +965 -0
- third_party/GLM-4-Voice/cosyvoice/flow/decoder.py +222 -0
.gitignore
ADDED
@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
#
*.sw*
LICENSE
ADDED
@@ -0,0 +1,180 @@
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.

License Terms of the VITA1.5:
--------------------------------------------------------------------

Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

- You agree to use the VITA1.5 only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.

- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

For avoidance of doubts, "Software" means the VITA1.5 model inference-enabling code, and weights made available under this license.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


Other dependencies and licenses:


Open Source Model Licensed under the Apache License Version 2.0:
The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"), as model weights provided for the VITA1.5 Project hereunder is fine-tuned with the assistance of below model.

All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
--------------------------------------------------------------------
1. Qwen2-7B-Instruct
Copyright 2024 Alibaba Cloud

Terms of the Apache License Version 2.0:
--------------------------------------------------------------------
Apache License

Version 2.0, January 2004

http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

You must give any other recipients of the Work or Derivative Works a copy of this License; and

You must cause any modified files to carry prominent notices stating that You changed the files; and

You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS


Open Source Model/Software Licensed under the Apache License Version 2.0:
The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
--------------------------------------------------------------------
1. ModelLink
Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.

A copy of the Apache License Version 2.0 is included in this file.


Open Source Model/Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
--------------------------------------------------------------------
1. opencv
Copyright (C) 2000-2022, Intel Corporation, all rights reserved.
Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved.
Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved.
Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved.
Copyright (C) 2015-2023, OpenCV Foundation, all rights reserved.
Copyright (C) 2008-2016, Itseez Inc., all rights reserved.
Copyright (C) 2019-2023, Xperience AI, all rights reserved.
Copyright (C) 2019-2022, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
Copyright (C) 2022-2023, Southern University of Science And Technology, all rights reserved.

A copy of the Apache 2.0 is included in this file.

For the license of other third party components, please refer to the following URL:
https://github.com/opencv/opencv/tree/4.10.0/3rdparty


Open Source Model/Software Licensed under the BSD 3-Clause License:
--------------------------------------------------------------------
1. flask
Copyright 2010 Pallets

2. flask-restful
Copyright (c) 2013, Twilio, Inc.
All rights reserved.


Terms of the BSD 3-Clause License:
--------------------------------------------------------------------
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


Open Source Model/Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
--------------------------------------------------------------------
1. Megatron-LM
Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

A copy of the BSD 3-Clause is included in this file.

For the license of other third party components, please refer to the following URL:
https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE


Open Source Model/Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
--------------------------------------------------------------------
1. MindSpeed
Copyright (c) 2024, Bytedance Inc.
Copyright (c) 2023, Huawei Technologies Co., Ltd
Copyright (c) 2022, NVIDIA CORPORATION.
All rights reserved.

A copy of the BSD 3-Clause is included in this file.

For the license of other third party components, please refer to the following URL:
https://gitee.com/ascend/MindSpeed/blob/master/LICENSE


Open Source Model/Software Licensed under the MIT License:
--------------------------------------------------------------------
1. natsort
Copyright (c) 2012-2023 Seth M. Morton


Terms of the MIT License:
--------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md
ADDED
@@ -0,0 +1,278 @@
# VITA-Audio: Fast Interleaved Audio-Text Token Generation for Efficient Large Speech-Language Model

<p align="center">
    <img src="asset/VITA_audio_logos.png" width="50%" height="50%">
</p>

<p align="center">
    <a href="https://arxiv.org/abs/2502.05177" target="_blank"><img src="https://img.shields.io/badge/VITA%20Audio-Report-b5212f.svg?logo=arxiv" /></a>
    <a href="https://huggingface.co/collections/VITA-MLLM/vita-audio-680f036c174441e7cdf02575" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-ffc107?color=ffc107&logoColor=white" /></a>
</p>


## :fire: News

* **`2025.05.06`** 🌟 We are proud to launch VITA-Audio, an end-to-end large speech model with fast audio-text token generation.


## 📄 Contents <!-- omit in toc -->

- [Highlights](#-highlights)
- [Exhibition](#-exhibition)
- [Models](#-models)
- [Experimental Results](#-experimental-results)
- [Training](#-training)
- [Inference](#-inference)
- [Evaluation](#-evaluation)


## ✨ Highlights

- **Low Latency**. VITA-Audio is the first end-to-end speech model capable of generating audio during the initial forward pass. By utilizing a set of 32 prefill tokens, VITA-Audio reduces the time required to generate the first audio token chunk from 217 ms to just 47 ms.
- **Fast Inference**. VITA-Audio achieves an inference speedup of 3-5x at the 7B parameter scale.
- **Open Source**. VITA-Audio is trained on **open-source data** only, consisting of 200k hours of publicly available audio.
- **Strong Performance**. VITA-Audio achieves competitive results on ASR, TTS, and SQA benchmarks among cutting-edge models under 7B parameters.


## 📌 Exhibition

### Inference Acceleration
Model inference speed under different inference modes.

<p align="center">
    <img src="./asset/qa_speed.gif" alt="demogif" width="48%" style="display: inline-block; margin-right: 2%;">
    <img src="./asset/tts_speed.gif" alt="second_gif" width="48%" style="display: inline-block;">
</p>

### Time to Generate the First Audio Segment in Streaming Inference
<div align="center">
<img width="400" alt="first audio generate time" src="https://github.com/user-attachments/assets/165f943e-ac53-443f-abba-e5eb1e0c0f40" />
</div>


### Generated Audio Cases

> 打南边来了个哑巴,腰里别了个喇叭;打北边来了个喇嘛,手里提了个獭犸。
> 提着獭犸的喇嘛要拿獭犸换别着喇叭的哑巴的喇叭;别着喇叭的哑巴不愿拿喇叭换提着獭玛的喇嘛的獭犸。
> 不知是别着喇叭的哑巴打了提着獭玛的喇嘛一喇叭;还是提着獭玛的喇嘛打了别着喇叭的哑巴一獭玛。
> 喇嘛回家炖獭犸;哑巴嘀嘀哒哒吹喇叭。

https://github.com/user-attachments/assets/38da791f-5d72-4d9c-a9b2-cec97c2f2b2b


---

> To be or not to be--to live intensely and richly,
> merely to exist, that depends on ourselves. Let widen and intensify our relations.
> While we live, let live!

https://github.com/user-attachments/assets/fd478065-4041-4eb8-b331-0c03b304d853


---

> The hair has been so little, don't think about it, go to bed early, for your hair. Good night!

https://github.com/user-attachments/assets/4cfe4742-e237-42bd-9f17-7935b2285799


---
> 两个黄鹂鸣翠柳,
> 一行白鹭上青天。
> 窗含西岭千秋雪,
> 门泊东吴万里船。

https://github.com/user-attachments/assets/382620ee-bb2a-488e-9e00-71afd2342b56


---
## 🔔 Models

| Model                   | LLM Size | Huggingface Weights                                       |
|-------------------------|----------|-----------------------------------------------------------|
| VITA-Audio-Boost        | 7B       | https://huggingface.co/VITA-MLLM/VITA-Audio-Boost         |
| VITA-Audio-Balance      | 7B       | https://huggingface.co/VITA-MLLM/VITA-Audio-Balance       |
| VITA-Audio-Plus-Vanilla | 7B       | https://huggingface.co/VITA-MLLM/VITA-Audio-Plus-Vanilla  |


## 📈 Experimental Results

- **Comparison of Spoken Question Answering**.

- **Comparison of Text to Speech**.

- **Comparison of Automatic Speech Recognition**.

- **Effectiveness of Inference Acceleration**.


## 📔 Requirements and Installation

### Prepare Environment
```
docker pull shenyunhang/pytorch:24.11-py3_2024-1224
```

### Get the Code
```
git clone https://github.com/VITA-MLLM/VITA-Audio.git
cd VITA-Audio
pip install -r requirements_ds_gpu.txt
pip install -e .
```

### Prepare Pre-trained Weights

#### LLM

- Download the LLM from https://huggingface.co/Qwen/Qwen2.5-7B-Instruct.
- Put it into `../models/Qwen/Qwen2.5-7B-Instruct/`.

#### Audio Encoder and Audio Decoder

- Download the Audio Encoder from https://huggingface.co/THUDM/glm-4-voice-tokenizer.
- Put it into `../models/THUDM/glm-4-voice-tokenizer`.

- Download the Audio Decoder from https://huggingface.co/THUDM/glm-4-voice-decoder.
- Put it into `../models/THUDM/glm-4-voice-decoder`.

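If you prefer to fetch the checkpoints from a script, a minimal sketch using `huggingface_hub` (which `app.py` in this commit also relies on) is shown below; the target directories simply mirror the layout assumed above and can be changed to suit your setup.

```python
# Sketch: download the required checkpoints into the directory layout assumed above.
# Assumes huggingface_hub is installed; the local_dir paths are the ones used in this README.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="Qwen/Qwen2.5-7B-Instruct", local_dir="../models/Qwen/Qwen2.5-7B-Instruct")
snapshot_download(repo_id="THUDM/glm-4-voice-tokenizer", local_dir="../models/THUDM/glm-4-voice-tokenizer")
snapshot_download(repo_id="THUDM/glm-4-voice-decoder", local_dir="../models/THUDM/glm-4-voice-decoder")
```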
### Data Format
#### **Speech QA Interleaved Data Format**

> This format shows how text and audio sequences are interleaved in a structured JSON conversation between a user and an assistant.

```jsonc
{
  "messages": [
    {
      "role": "user",
      "content": "<|begin_of_audio|> audio_sequence <|end_of_audio|>"
    },
    {
      "role": "assistant",
      "content": "text_sequence_1 <|begin_of_audio|> audio_sequence_1 <|end_of_audio|> text_sequence_2 <|begin_of_audio|> audio_sequence_2 <|end_of_audio|>"
    }
  ]
}
```

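As a minimal sketch of producing such training records, the snippet below writes one record per line to a JSONL file; the output path is illustrative, and real records carry actual audio token sequences between the `<|begin_of_audio|>` and `<|end_of_audio|>` markers.

```python
# Sketch: append one Speech QA record in the interleaved format above to a JSONL file.
import json

record = {
    "messages": [
        {"role": "user", "content": "<|begin_of_audio|> audio_sequence <|end_of_audio|>"},
        {"role": "assistant", "content": "text_sequence_1 <|begin_of_audio|> audio_sequence_1 <|end_of_audio|>"},
    ]
}

with open("speech_qa_train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```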
## 🎲 Training


The following tutorial will take `VITA-Audio-Boost` as an example.

- To train `VITA-Audio-Balance` and the other variants, you should modify the `text-audio-interval-ratio`.

VITA-Audio-Boost:
```
--text-audio-interval-ratio 1 10 4 10 \
```

VITA-Audio-Balance:
```
--text-audio-interval-ratio 1 4 3 8 4 10 \
```

- To train `VITA-Audio-Plus-*`, you should use scripts such as `scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice...`

### Stage-1 (Audio-Text Alignment)

```
bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_stage1.sh 8192 `date +'%Y%m%d_%H%M%S'`
```

The above script may need some adjustments.

- Set `ROOT_PATH` to your code root folder.
- Set `LOCAL_ROOT_PATH` to a temporary code root folder.
- Modify other variables as needed for your environment.

### Stage-2 (Single MCTP Module Training)

```
bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh 8192 `date +'%Y%m%d_%H%M%S'`
```

The above script may need some adjustments.

- Set `ROOT_PATH` to your code root folder.
- Set `LOCAL_ROOT_PATH` to a temporary code root folder.
- Set `MODEL_NAME_OR_PATH` to the path of the model trained in Stage 1.
- Modify other variables as needed for your environment.

### Stage-3 (Multiple MCTP Modules Training)

```
bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh 8192 `date +'%Y%m%d_%H%M%S'`
```

The above script may need some adjustments.

- Set `ROOT_PATH` to your code root folder.
- Set `LOCAL_ROOT_PATH` to a temporary code root folder.
- Set `MODEL_NAME_OR_PATH` to the path of the model trained in Stage 2.
- Modify other variables as needed for your environment.

### Stage-4 (Supervised Fine-tuning)

```
bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh 2048 `date +'%Y%m%d_%H%M%S'`
```

The above script may need some adjustments.

- Set `ROOT_PATH` to your code root folder.
- Set `LOCAL_ROOT_PATH` to a temporary code root folder.
- Set `MODEL_NAME_OR_PATH` to the path of the model trained in Stage 3.
- Modify other variables as needed for your environment.



## 📐 Inference

Here we provide a simple script for inference.

It includes examples of speech-to-speech, ASR, and TTS tasks, as well as inference speed testing.

```
python tools/inference_sts.py
```

- Set `model_name_or_path` to the VITA-Audio weights.
- Set `audio_tokenizer_path` to the path of the audio encoder.
- Set `flow_path` to the path of the audio decoder.

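For reference, the Gradio demo (`app.py`, added in this commit) loads the same components programmatically; a condensed sketch of that loading path is shown below. The local checkpoint paths are illustrative and should point to wherever you placed the weights.

```python
# Condensed loading sketch, based on app.py from this commit; paths are illustrative.
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from vita_audio.tokenizer import get_audio_tokenizer

# GLM-4-Voice and its dependencies are vendored under third_party/.
sys.path.append("third_party/GLM-4-Voice/")
sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla"   # or a local checkpoint directory

# Audio encoder (tokenizer) and decoder (flow model); type follows app.py for the Plus-Vanilla model.
audio_tokenizer = get_audio_tokenizer(
    "../models/THUDM/glm-4-voice-tokenizer",
    "sensevoice_glm4voice",
    flow_path="../models/THUDM/glm-4-voice-decoder",
    rank=0,
)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).eval()
model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
```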
## 🔎 Evaluation

Evaluate the SQA, ASR, and TTS benchmarks:
```
bash scripts/deepspeed/evaluate_sts.sh
```

app.py
ADDED
@@ -0,0 +1,378 @@
import torch
import os
import numpy as np
import copy
import gradio as gr
import sys
from vita_audio.tokenizer import get_audio_tokenizer
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous


from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoConfig
from transformers.generation import GenerationConfig


PUNCTUATION = "!?。\"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."


import math
from numba import jit

@jit
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    # Scale a float waveform to 16-bit PCM, normalizing by the peak amplitude.
    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)


def is_wav(file_path):
    wav_extensions = {'.wav'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in wav_extensions


def _parse_text(text):
    # Escape chat text into HTML for the Gradio chatbot, preserving fenced code blocks.
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0

    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0 and count % 2 == 1:
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
            lines[i] = "<br>" + line

    return "".join(lines)


def _launch_demo(model, tokenizer, audio_tokenizer):
    def predict(_chatbot, task_history, task):
        # Build the chat messages for the selected task, run generation,
        # and split the response into text tokens and audio tokens.
        chat_query = task_history[-1][0]
        print(task_history)

        messages = []

        audio_path_list = []
        if task == 'Spoken QA':
            messages = [
                {
                    "role": "system",
                    #"content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
                    # "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.",
                    "content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.",
                },
            ]
            for i, (q, a) in enumerate(task_history):

                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f"\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": q,
                        },
                    ]
                if a != None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = False

        elif task == 'TTS':
            for i, (q, a) in enumerate(task_history):

                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f"\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f'Convert the text to speech.\n{q}',
                        },
                    ]
                if a != None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = True
        elif task == 'ASR':
            for i, (q, a) in enumerate(task_history):

                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f"Convert the speech to text.\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f"{q}",
                        },
                    ]
                if a != None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = False


        add_generation_prompt = True
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=add_generation_prompt,
            # return_tensors="pt",
        )

        input_ids, audios, audio_indices = add_audio_input_contiguous(
            input_ids, audio_path_list, tokenizer, audio_tokenizer
        )

        input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")

        # print("input", tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)

        if audio_path_list == []:
            audios = None
            audio_indices = None

        outputs = model.generate(
            input_ids,
            audios=audios,
            audio_indices=audio_indices,
        )

        output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        # print(f"{output=}", flush=True)

        audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
        begin_of_audio = tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
        end_of_audio = tokenizer.convert_tokens_to_ids("<|end_of_audio|>")
        im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
        response = outputs[0][len(input_ids[0]):]

        # Token ids at or above the audio offset are audio codec tokens;
        # everything else, except the control tokens, is text.
        audio_tokens = []
        text_tokens = []
        for token_id in response:
            if token_id >= audio_offset:
                audio_tokens.append(token_id - audio_offset)
            elif (token_id.item() != begin_of_audio) and (token_id.item() != end_of_audio) and (token_id.item() != im_end):
                text_tokens.append(token_id)

        if len(audio_tokens) > 0:
            # Decode audio tokens to a waveform and convert to 16-bit PCM at 22050 Hz.
            tts_speech = audio_tokenizer.decode(audio_tokens)
            audio_np = float_to_int16(tts_speech.cpu().numpy())
            tts_speech = (22050, audio_np)
        else:
            tts_speech = None

        # import pdb;pdb.set_trace()
        history_response = tokenizer.decode(text_tokens)
        task_history[-1] = (chat_query, history_response)

        _chatbot[-1] = (chat_query, history_response)
        # print("query",chat_query)
        # print("task_history",task_history)
        # print(_chatbot)
        # print("answer: ",outputs)
        return _chatbot, tts_speech


    def add_text(history, task_history, text):
        task_text = text
        # import pdb;pdb.set_trace()
        if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
            task_text = text[:-1]
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""


    def add_audio(history, task_history, file):
        print(file)
        if file is None:
            return history, task_history
        history = history + [((file,), None)]
        task_history = task_history + [((file,), None)]
        return history, task_history


    def reset_user_input():
        # import pdb;pdb.set_trace()
        return gr.update(value="")

    def reset_state(task_history):
        task_history.clear()
        return []


    with gr.Blocks(title="VITA-Audio-Plus-Vanilla") as demo:
        gr.Markdown("""<center><font size=8>VITA-Audio-Plus-Vanilla</center>""")
        gr.Markdown("""<center><font size=4>The deployment of the VITA-Audio-Plus-Vanilla model employs a non-streaming deployment approach. The currently deployed model is VITA-Audio-Plus-Vanilla. For the ASR and TTS tasks, only single-turn dialogues are supported. In the Spoken QA task, generated text is used as dialogue history to reduce the context length.</center>""")
        chatbot = gr.Chatbot(label='VITA-Audio-Plus-Vanilla', elem_classes="control-height", height=500)
        query = gr.Textbox(lines=2, label='Text Input')
        task_history = gr.State([])
        with gr.Row():
            add_text_button = gr.Button("Submit Text (提交文本)")
            add_audio_button = gr.Button("Submit Audio (提交音频)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            task = gr.Radio(
                choices=["ASR", "TTS", "Spoken QA"], label="TASK", value='Spoken QA'
            )

        with gr.Row(scale=1):

            record_btn = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
            audio_output = gr.Audio(label="Play", streaming=True,
                                    autoplay=True, show_download_button=True)


        add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
            reset_user_input, [], [query]
        ).then(
            predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
        )

        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)

        add_audio_button.click(add_audio, [chatbot, task_history, record_btn], [chatbot, task_history], show_progress=True).then(
            predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
        )

    server_port = 18806
    demo.launch(
        share=False,
        debug=True,
        server_name="0.0.0.0",
        server_port=server_port,
        show_api=False,
        show_error=False,
    )

def main():

    model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla"

    device_map = "cuda:0"

    sys.path.append("third_party/GLM-4-Voice/")
    sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
    sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

    from huggingface_hub import snapshot_download
    audio_tokenizer_path = snapshot_download(repo_id="THUDM/glm-4-voice-tokenizer")
    flow_path = snapshot_download(repo_id="THUDM/glm-4-voice-decoder")

    audio_tokenizer_rank = 0
    audio_tokenizer_type = "sensevoice_glm4voice"

    torch_dtype = torch.bfloat16
    audio_tokenizer = get_audio_tokenizer(
        audio_tokenizer_path, audio_tokenizer_type, flow_path=flow_path, rank=audio_tokenizer_rank
    )
    from evaluation.get_chat_template import qwen2_chat_template as chat_template

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print(f"{tokenizer=}")
    # print(f"{tokenizer.get_chat_template()=}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()

    # print(f"{model.config.model_type=}")

    model.generation_config = GenerationConfig.from_pretrained(
        model_name_or_path, trust_remote_code=True
    )

    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = True
    model.generation_config.temperature = 1.0
    model.generation_config.top_k = 50
    model.generation_config.top_p = 1.0
    model.generation_config.num_beams = 1
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    # Multi-token prediction inference schedule (see the MCTP training stages in the README).
    model.generation_config.mtp_inference_mode = [8192, 10]


    _launch_demo(model, tokenizer, audio_tokenizer)




if __name__ == '__main__':

    main()
configs/sts_finetune_stage1.yaml
ADDED
@@ -0,0 +1,273 @@
xlsx_sample_num: 5

dataset:

  wenet-e2e/wenetspeech:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/wenet-e2e/wenetspeech/L_fixed.jsonl
      - datasets/jsonl/wenet-e2e/wenetspeech/DEV_fixed.jsonl

  Wenetspeech4TTS/Wenetspeech4TTS:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/Wenetspeech4TTS/WenetSpeech4TTS/Basic.jsonl

  fixie-ai/librispeech_asr:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/fixie-ai/librispeech_asr/train.100.clean.jsonl
      - datasets/jsonl/fixie-ai/librispeech_asr/train.360.clean.jsonl
      - datasets/jsonl/fixie-ai/librispeech_asr/train.500.other.jsonl

  mythicinfinity/libritts:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/mythicinfinity/libritts/train.clean.100.jsonl
      - datasets/jsonl/mythicinfinity/libritts/train.clean.360.jsonl
      - datasets/jsonl/mythicinfinity/libritts/train.other.500.jsonl
      - datasets/jsonl/mythicinfinity/libritts_r/train.clean.100.jsonl
      - datasets/jsonl/mythicinfinity/libritts_r/train.clean.360.jsonl
      - datasets/jsonl/mythicinfinity/libritts_r/train.other.500.jsonl

  parler-tts/mls_eng:
    ratio: 1.0
    data_paths:
      #- datasets/jsonl/parler-tts/mls_eng_10k/train.jsonl
      - datasets/jsonl/parler-tts/mls_eng/train.jsonl

  mozilla-foundation/common_voice_17_0:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/mozilla-foundation/common_voice_17_0/en/train.jsonl
      - datasets/jsonl/mozilla-foundation/common_voice_17_0/zh-CN/train.jsonl

  MushanW/GLOBE_V2:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/MushanW/GLOBE_V2/train.jsonl

  amphion/Emilia-Dataset:
    ratio: 0.5
    data_paths:
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200.jsonl

  amphion/Emilia-Dataset/speaker_prompt:
    ratio: 0.5
    data_paths:
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200_speak_prompt.jsonl

  openslr:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/openslr/SLR68/train.jsonl
      - datasets/jsonl/openslr/SLR68/dev.jsonl

  speechcolab/gigaspeech:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/speechcolab/gigaspeech/xl.jsonl
      - datasets/jsonl/speechcolab/gigaspeech/dev.jsonl

  MLCommons/peoples_speech:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/MLCommons/peoples_speech/clean.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/clean_sa.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/dirty.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/dirty_sa.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/validation.jsonl

  facebook/voxpopuli:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/facebook/voxpopuli/en_train.jsonl
      - datasets/jsonl/facebook/voxpopuli/en_accented_test.jsonl

  shenyunhang:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/shenyunhang/AISHELL-1/train.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-1/dev.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-2/data.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-3/data.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-4/data.jsonl

  gpt-omni/VoiceAssistant-400K:
    ratio: 0.0
    data_paths:
      - datasets/jsonl/gpt-omni/VoiceAssistant-400K/data.jsonl

  VITA-MLLM/AudioQA-1M:
    ratio: 0.0
    data_paths:
      - datasets/jsonl/VITA-MLLM/AudioQA-1M/data.jsonl

  BAAI/Infinity-Instruct:
    ratio: 1.0
    data_paths:
      #- datasets/jsonl/BAAI/Infinity-Instruct/3M.jsonl
      #- datasets/jsonl/BAAI/Infinity-Instruct/7M.jsonl
      #- datasets/jsonl/BAAI/Infinity-Instruct/7M_domains.jsonl
      - datasets/jsonl/BAAI/Infinity-Instruct/0625.jsonl
      #- datasets/jsonl/BAAI/Infinity-Instruct/Gen.jsonl

  OpenHermes:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/teknium/OpenHermes-2.5/openhermes2_5.jsonl

  lima:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/GAIR/lima/train.jsonl

  databricks-dolly-15k:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/databricks/databricks-dolly-15k/databricks-dolly-15k.jsonl

  MetaMathQA:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/meta-math/MetaMathQA/MetaMathQA-395K.jsonl

  MathInstruct:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/TIGER-Lab/MathInstruct/MathInstruct.jsonl

  orca-math-word-problems-200k:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/microsoft/orca-math-word-problems-200k/data.jsonl

  atlas-math-sets:
    ratio: 1.0
    num: 100000
    data_paths:
      - datasets/jsonl/AtlasUnified/atlas-math-sets/train.jsonl

  goat:
    ratio: 1.0
    num: 30000
    data_paths:
      - datasets/jsonl/tiedong/goat/dataset.jsonl

  camel-ai:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/camel-ai/math/math.jsonl

  Long-Instruction-with-Paraphrasing:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_en.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_en_paraphrasing.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_en.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_alpaca_en.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_en_paraphrasing.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/translation_en2zh.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_zh.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_zh_paraphrasing.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_zh.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_llama_chinese.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_zh_paraphrasing.jsonl

  Long:
    ratio: 1.0
    data_paths:
      - datasets/jsonl/akoksal/LongForm/data.jsonl
      - datasets/jsonl/THUDM/LongAlign-10k/long.jsonl
      - datasets/jsonl/THUDM/LongCite-45k/long.jsonl
      - datasets/jsonl/THUDM/LongWriter-6k/long.jsonl
      - datasets/jsonl/YeungNLP/LongQLoRA-Dataset/LongQLoRA-SFT-Data-39k.jsonl
      - datasets/jsonl/Yukang/LongAlpaca-12k/LongAlpaca-12k.jsonl
      - datasets/jsonl/togethercomputer/Long-Data-Collections/natural_questions_10_200_docs.jsonl
      - datasets/jsonl/togethercomputer/Long-Data-Collections/booksum.jsonl
      - datasets/jsonl/KnutJaegersberg/longinstruct/longinstruct.jsonl

  open-thoughts/OpenThoughts2-1M:
    ratio: 0.0
    num: 200000
    data_paths:
      - datasets/jsonl/open-thoughts/OpenThoughts2-1M/data.jsonl

  nvidia/Llama-Nemotron-Post-Training-Dataset:
    ratio: 0.0
    num: 200000
    data_paths:
      - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_chat.jsonl
      - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_code.jsonl
      #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_math.jsonl
      #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_safety.jsonl
      - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_science.jsonl

  glaiveai/reasoning-v1-20m:
    ratio: 0.0
    num: 200000
    data_paths:
      - datasets/jsonl/glaiveai/reasoning-v1-20m/data.jsonl

  nvidia/OpenCodeReasoning:
    ratio: 0.0
    num: 200000
    data_paths:
      - datasets/jsonl/nvidia/OpenCodeReasoning/split_0.jsonl

  Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT:
    ratio: 0.0
    num: 200000
    data_paths:
      - datasets/jsonl/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT/data.jsonl

  open-r1/OpenR1-Math-220k:
    ratio: 0.0
    num: 200000
    data_paths:
      #- datasets/jsonl/open-r1/OpenR1-Math-220k/default.jsonl
      - datasets/jsonl/open-r1/OpenR1-Math-220k/all.jsonl
      #- datasets/jsonl/open-r1/OpenR1-Math-220k/extended.jsonl
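The stage-1 config above is presumably consumed by the stage-1 finetuning scripts. As a rough illustration of how such a file can be read, the sketch below is a minimal example only: it assumes PyYAML is available, and it assumes that ratio acts as a per-dataset sampling fraction and num as an optional cap on sampled examples; those semantics live in the project's data loader, not in this diff.

# Minimal sketch, not the project's actual loader. Assumptions: PyYAML installed,
# "ratio" = per-dataset sampling fraction, "num" = optional cap on sampled examples.
import json
import random
import yaml

with open("configs/sts_finetune_stage1.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

mixture = {}
for name, spec in cfg["dataset"].items():
    ratio = float(spec.get("ratio", 1.0))
    if ratio <= 0.0:
        continue  # ratio 0.0 effectively disables the dataset in this stage
    records = []
    for path in spec["data_paths"]:
        with open(path, "r", encoding="utf-8") as jf:
            records.extend(json.loads(line) for line in jf if line.strip())
    k = int(len(records) * min(ratio, 1.0))
    if "num" in spec:
        k = min(k, int(spec["num"]))
    mixture[name] = random.sample(records, k)
    print(f"{name}: sampled {len(mixture[name])} of {len(records)}")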
configs/sts_finetune_stage2.yaml
ADDED
@@ -0,0 +1,273 @@
xlsx_sample_num: 5

dataset:

  wenet-e2e/wenetspeech:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/wenet-e2e/wenetspeech/L_fixed.jsonl
      - datasets/jsonl/wenet-e2e/wenetspeech/DEV_fixed.jsonl

  Wenetspeech4TTS/Wenetspeech4TTS:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/Wenetspeech4TTS/WenetSpeech4TTS/Basic.jsonl

  fixie-ai/librispeech_asr:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/fixie-ai/librispeech_asr/train.100.clean.jsonl
      - datasets/jsonl/fixie-ai/librispeech_asr/train.360.clean.jsonl
      - datasets/jsonl/fixie-ai/librispeech_asr/train.500.other.jsonl

  mythicinfinity/libritts:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/mythicinfinity/libritts/train.clean.100.jsonl
      - datasets/jsonl/mythicinfinity/libritts/train.clean.360.jsonl
      - datasets/jsonl/mythicinfinity/libritts/train.other.500.jsonl
      - datasets/jsonl/mythicinfinity/libritts_r/train.clean.100.jsonl
      - datasets/jsonl/mythicinfinity/libritts_r/train.clean.360.jsonl
      - datasets/jsonl/mythicinfinity/libritts_r/train.other.500.jsonl

  parler-tts/mls_eng:
    ratio: 0.05
    data_paths:
      #- datasets/jsonl/parler-tts/mls_eng_10k/train.jsonl
      - datasets/jsonl/parler-tts/mls_eng/train.jsonl

  mozilla-foundation/common_voice_17_0:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/mozilla-foundation/common_voice_17_0/en/train.jsonl
      - datasets/jsonl/mozilla-foundation/common_voice_17_0/zh-CN/train.jsonl

  MushanW/GLOBE_V2:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/MushanW/GLOBE_V2/train.jsonl

  amphion/Emilia-Dataset:
    ratio: 0.025
    data_paths:
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200.jsonl

  amphion/Emilia-Dataset/speaker_prompt:
    ratio: 0.025
    data_paths:
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100_speak_prompt.jsonl
      - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200_speak_prompt.jsonl

  openslr:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/openslr/SLR68/train.jsonl
      - datasets/jsonl/openslr/SLR68/dev.jsonl

  speechcolab/gigaspeech:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/speechcolab/gigaspeech/xl.jsonl
      - datasets/jsonl/speechcolab/gigaspeech/dev.jsonl

  MLCommons/peoples_speech:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/MLCommons/peoples_speech/clean.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/clean_sa.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/dirty.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/dirty_sa.jsonl
      - datasets/jsonl/MLCommons/peoples_speech/validation.jsonl

  facebook/voxpopuli:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/facebook/voxpopuli/en_train.jsonl
      - datasets/jsonl/facebook/voxpopuli/en_accented_test.jsonl

  shenyunhang:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/shenyunhang/AISHELL-1/train.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-1/dev.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-2/data.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-3/data.jsonl
      - datasets/jsonl/shenyunhang/AISHELL-4/data.jsonl

  gpt-omni/VoiceAssistant-400K:
    ratio: 2.0
    data_paths:
      - datasets/jsonl/gpt-omni/VoiceAssistant-400K/data.jsonl

  VITA-MLLM/AudioQA-1M:
    ratio: 2.0
    data_paths:
      - datasets/jsonl/VITA-MLLM/AudioQA-1M/data.jsonl

  BAAI/Infinity-Instruct:
    ratio: 0.05
    data_paths:
      #- datasets/jsonl/BAAI/Infinity-Instruct/3M.jsonl
      #- datasets/jsonl/BAAI/Infinity-Instruct/7M.jsonl
      #- datasets/jsonl/BAAI/Infinity-Instruct/7M_domains.jsonl
      - datasets/jsonl/BAAI/Infinity-Instruct/0625.jsonl
      #- datasets/jsonl/BAAI/Infinity-Instruct/Gen.jsonl

  OpenHermes:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/teknium/OpenHermes-2.5/openhermes2_5.jsonl

  lima:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/GAIR/lima/train.jsonl

  databricks-dolly-15k:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/databricks/databricks-dolly-15k/databricks-dolly-15k.jsonl

  MetaMathQA:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/meta-math/MetaMathQA/MetaMathQA-395K.jsonl

  MathInstruct:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/TIGER-Lab/MathInstruct/MathInstruct.jsonl

  orca-math-word-problems-200k:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/microsoft/orca-math-word-problems-200k/data.jsonl

  atlas-math-sets:
    ratio: 0.05
    num: 100000
    data_paths:
      - datasets/jsonl/AtlasUnified/atlas-math-sets/train.jsonl

  goat:
    ratio: 0.05
    num: 30000
    data_paths:
      - datasets/jsonl/tiedong/goat/dataset.jsonl

  camel-ai:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/camel-ai/math/math.jsonl

  Long-Instruction-with-Paraphrasing:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_en.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_en_paraphrasing.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_en.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_alpaca_en.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_en_paraphrasing.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/translation_en2zh.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_zh.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_zh_paraphrasing.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_zh.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_llama_chinese.jsonl
      - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_zh_paraphrasing.jsonl

  Long:
    ratio: 0.05
    data_paths:
      - datasets/jsonl/akoksal/LongForm/data.jsonl
      - datasets/jsonl/THUDM/LongAlign-10k/long.jsonl
      - datasets/jsonl/THUDM/LongCite-45k/long.jsonl
      - datasets/jsonl/THUDM/LongWriter-6k/long.jsonl
      - datasets/jsonl/YeungNLP/LongQLoRA-Dataset/LongQLoRA-SFT-Data-39k.jsonl
      - datasets/jsonl/Yukang/LongAlpaca-12k/LongAlpaca-12k.jsonl
      - datasets/jsonl/togethercomputer/Long-Data-Collections/natural_questions_10_200_docs.jsonl
      - datasets/jsonl/togethercomputer/Long-Data-Collections/booksum.jsonl
      - datasets/jsonl/KnutJaegersberg/longinstruct/longinstruct.jsonl

  open-thoughts/OpenThoughts2-1M:
    ratio: 0.0
    num: 10000
    data_paths:
      - datasets/jsonl/open-thoughts/OpenThoughts2-1M/data.jsonl

  nvidia/Llama-Nemotron-Post-Training-Dataset:
    ratio: 0.0
    num: 10000
    data_paths:
      - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_chat.jsonl
      - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_code.jsonl
      #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_math.jsonl
      #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_safety.jsonl
      - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_science.jsonl

  glaiveai/reasoning-v1-20m:
    ratio: 0.0
    num: 10000
    data_paths:
      - datasets/jsonl/glaiveai/reasoning-v1-20m/data.jsonl

  nvidia/OpenCodeReasoning:
    ratio: 0.0
    num: 10000
    data_paths:
      - datasets/jsonl/nvidia/OpenCodeReasoning/split_0.jsonl

  Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT:
    ratio: 0.0
    num: 10000
    data_paths:
      - datasets/jsonl/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT/data.jsonl

  open-r1/OpenR1-Math-220k:
    ratio: 0.0
    num: 10000
    data_paths:
      #- datasets/jsonl/open-r1/OpenR1-Math-220k/default.jsonl
      - datasets/jsonl/open-r1/OpenR1-Math-220k/all.jsonl
      #- datasets/jsonl/open-r1/OpenR1-Math-220k/extended.jsonl
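The stage-2 config keeps exactly the same data_paths as stage 1 and only changes the mixing weights: the ASR/TTS corpora drop from 1.0 to 0.05 (Emilia from 0.5 to 0.025), gpt-omni/VoiceAssistant-400K and VITA-MLLM/AudioQA-1M go from 0.0 to 2.0, the text SFT sets drop to 0.05, and the reasoning sets stay at 0.0 with num reduced from 200000 to 10000. A small throwaway sketch (assuming PyYAML) to surface those differences:

# Compare per-dataset ratios between the two stage configs; a hypothetical helper,
# not part of the repository.
import yaml

def ratios(path):
    with open(path, "r", encoding="utf-8") as f:
        return {k: v.get("ratio") for k, v in yaml.safe_load(f)["dataset"].items()}

s1 = ratios("configs/sts_finetune_stage1.yaml")
s2 = ratios("configs/sts_finetune_stage2.yaml")
for name in sorted(set(s1) | set(s2)):
    if s1.get(name) != s2.get(name):
        print(f"{name}: {s1.get(name)} -> {s2.get(name)}")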
evaluation/compute-acc-of-contain.py
ADDED
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import re
import string
import sys
import unicodedata

from word2number import w2n

# def is_list_in_string(text, candidate):
#     return any([all(xx in text for xx in x.split(" ")) if isinstance(x, str) else all([xx in text for xx in x]) for x in candidate])


def is_string_in_string(text, candidate):
    return all(x in text for x in candidate.split(" "))


def is_list_in_string(text, candidate):
    return any(
        [
            is_string_in_string(text, x) if isinstance(x, str) else is_list_in_string(text, x)
            for x in candidate
        ]
    )


def clean_punctuation(value):
    punctuation = string.punctuation
    punctuation = punctuation.replace("'", "")
    value = re.sub(f"[{punctuation}]", " ", value)
    return value


if __name__ == "__main__":

    pred_gt_json_file = sys.argv[1]

    with open(pred_gt_json_file, "r") as f:
        pred_gt = json.load(f)

    acc = 0
    for line in pred_gt:

        pred = line[0]
        gt = line[1]

        # pred = clean_punctuation(pred)
        pred = pred.lower()

        if isinstance(gt, list):
            pass
        else:
            gt = [
                gt,
            ]
        gt = [clean_punctuation(x) for x in gt]
        gt = [x.lower().strip() for x in gt]

        try:
            gt_number = [str(w2n.word_to_num(x.lower())) for x in gt]
        except:
            gt_number = gt
            pass

        if is_list_in_string(pred, gt):
            acc += 1
        elif is_list_in_string(pred, gt_number):
            acc += 1
        else:
            print("======================================================")
            print(f"{line[0]=}")
            print(f"{line[1]=}")

    print("======================================================")
    print(f"{acc=}")
    print(f"{len(pred_gt)=}")
    print("======================================================")

    acc = acc / len(pred_gt) * 100

    print("======================================================")
    print(f"{acc=}")
    print("======================================================")
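compute-acc-of-contain.py takes a single argument: a JSON file containing a list of [prediction, ground_truth] pairs, where the ground truth may be a string or a list of acceptable answers. A prediction counts as correct if the (lower-cased, de-punctuated) answer words are contained in it; word-form numbers in the ground truth are also tried in digit form via word2number. A minimal sketch of preparing such a file and running the script; the file name pred_gt.json is only an example.

import json
import subprocess

pairs = [
    ["The answer is 7.", "seven"],            # "seven" -> "7" via word2number, then containment
    ["It is in Paris, France.", ["Paris"]],   # case-insensitive substring containment
]
with open("pred_gt.json", "w", encoding="utf-8") as f:
    json.dump(pairs, f, ensure_ascii=False)

subprocess.run(["python", "evaluation/compute-acc-of-contain.py", "pred_gt.json"], check=True)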
evaluation/compute-cer.py
ADDED
@@ -0,0 +1,559 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import unicodedata
|
6 |
+
import codecs
|
7 |
+
|
8 |
+
remove_tag = True
|
9 |
+
spacelist = [' ', '\t', '\r', '\n']
|
10 |
+
puncts = [
|
11 |
+
'!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
|
12 |
+
'《', '》'
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
def characterize(string):
|
17 |
+
res = []
|
18 |
+
i = 0
|
19 |
+
while i < len(string):
|
20 |
+
char = string[i]
|
21 |
+
if char in puncts:
|
22 |
+
i += 1
|
23 |
+
continue
|
24 |
+
cat1 = unicodedata.category(char)
|
25 |
+
# https://unicodebook.readthedocs.io/unicode.html#unicode-categories
|
26 |
+
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
|
27 |
+
i += 1
|
28 |
+
continue
|
29 |
+
if cat1 == 'Lo': # letter-other
|
30 |
+
res.append(char)
|
31 |
+
i += 1
|
32 |
+
else:
|
33 |
+
# some input looks like: <unk><noise>, we want to separate it to two words.
|
34 |
+
sep = ' '
|
35 |
+
if char == '<':
|
36 |
+
sep = '>'
|
37 |
+
j = i + 1
|
38 |
+
while j < len(string):
|
39 |
+
c = string[j]
|
40 |
+
if ord(c) >= 128 or (c in spacelist) or (c == sep):
|
41 |
+
break
|
42 |
+
j += 1
|
43 |
+
if j < len(string) and string[j] == '>':
|
44 |
+
j += 1
|
45 |
+
res.append(string[i:j])
|
46 |
+
i = j
|
47 |
+
return res
|
48 |
+
|
49 |
+
|
50 |
+
def stripoff_tags(x):
|
51 |
+
if not x:
|
52 |
+
return ''
|
53 |
+
chars = []
|
54 |
+
i = 0
|
55 |
+
T = len(x)
|
56 |
+
while i < T:
|
57 |
+
if x[i] == '<':
|
58 |
+
while i < T and x[i] != '>':
|
59 |
+
i += 1
|
60 |
+
i += 1
|
61 |
+
else:
|
62 |
+
chars.append(x[i])
|
63 |
+
i += 1
|
64 |
+
return ''.join(chars)
|
65 |
+
|
66 |
+
|
67 |
+
def normalize(sentence, ignore_words, cs, split=None):
|
68 |
+
""" sentence, ignore_words are both in unicode
|
69 |
+
"""
|
70 |
+
new_sentence = []
|
71 |
+
for token in sentence:
|
72 |
+
x = token
|
73 |
+
if not cs:
|
74 |
+
x = x.upper()
|
75 |
+
if x in ignore_words:
|
76 |
+
continue
|
77 |
+
if remove_tag:
|
78 |
+
x = stripoff_tags(x)
|
79 |
+
if not x:
|
80 |
+
continue
|
81 |
+
if split and x in split:
|
82 |
+
new_sentence += split[x]
|
83 |
+
if x.isalnum():
|
84 |
+
for k in x:
|
85 |
+
new_sentence.append(k)
|
86 |
+
else:
|
87 |
+
new_sentence.append(x)
|
88 |
+
return new_sentence
|
89 |
+
|
90 |
+
|
91 |
+
class Calculator:
|
92 |
+
|
93 |
+
def __init__(self):
|
94 |
+
self.data = {}
|
95 |
+
self.space = []
|
96 |
+
self.cost = {}
|
97 |
+
self.cost['cor'] = 0
|
98 |
+
self.cost['sub'] = 1
|
99 |
+
self.cost['del'] = 1
|
100 |
+
self.cost['ins'] = 1
|
101 |
+
|
102 |
+
def calculate(self, lab, rec):
|
103 |
+
# Initialization
|
104 |
+
lab.insert(0, '')
|
105 |
+
rec.insert(0, '')
|
106 |
+
while len(self.space) < len(lab):
|
107 |
+
self.space.append([])
|
108 |
+
for row in self.space:
|
109 |
+
for element in row:
|
110 |
+
element['dist'] = 0
|
111 |
+
element['error'] = 'non'
|
112 |
+
while len(row) < len(rec):
|
113 |
+
row.append({'dist': 0, 'error': 'non'})
|
114 |
+
for i in range(len(lab)):
|
115 |
+
self.space[i][0]['dist'] = i
|
116 |
+
self.space[i][0]['error'] = 'del'
|
117 |
+
for j in range(len(rec)):
|
118 |
+
self.space[0][j]['dist'] = j
|
119 |
+
self.space[0][j]['error'] = 'ins'
|
120 |
+
self.space[0][0]['error'] = 'non'
|
121 |
+
for token in lab:
|
122 |
+
if token not in self.data and len(token) > 0:
|
123 |
+
self.data[token] = {
|
124 |
+
'all': 0,
|
125 |
+
'cor': 0,
|
126 |
+
'sub': 0,
|
127 |
+
'ins': 0,
|
128 |
+
'del': 0
|
129 |
+
}
|
130 |
+
for token in rec:
|
131 |
+
if token not in self.data and len(token) > 0:
|
132 |
+
self.data[token] = {
|
133 |
+
'all': 0,
|
134 |
+
'cor': 0,
|
135 |
+
'sub': 0,
|
136 |
+
'ins': 0,
|
137 |
+
'del': 0
|
138 |
+
}
|
139 |
+
# Computing edit distance
|
140 |
+
for i, lab_token in enumerate(lab):
|
141 |
+
for j, rec_token in enumerate(rec):
|
142 |
+
if i == 0 or j == 0:
|
143 |
+
continue
|
144 |
+
min_dist = sys.maxsize
|
145 |
+
min_error = 'none'
|
146 |
+
dist = self.space[i - 1][j]['dist'] + self.cost['del']
|
147 |
+
error = 'del'
|
148 |
+
if dist < min_dist:
|
149 |
+
min_dist = dist
|
150 |
+
min_error = error
|
151 |
+
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
|
152 |
+
error = 'ins'
|
153 |
+
if dist < min_dist:
|
154 |
+
min_dist = dist
|
155 |
+
min_error = error
|
156 |
+
if lab_token == rec_token:
|
157 |
+
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
|
158 |
+
error = 'cor'
|
159 |
+
else:
|
160 |
+
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
|
161 |
+
error = 'sub'
|
162 |
+
if dist < min_dist:
|
163 |
+
min_dist = dist
|
164 |
+
min_error = error
|
165 |
+
self.space[i][j]['dist'] = min_dist
|
166 |
+
self.space[i][j]['error'] = min_error
|
167 |
+
# Tracing back
|
168 |
+
result = {
|
169 |
+
'lab': [],
|
170 |
+
'rec': [],
|
171 |
+
'all': 0,
|
172 |
+
'cor': 0,
|
173 |
+
'sub': 0,
|
174 |
+
'ins': 0,
|
175 |
+
'del': 0
|
176 |
+
}
|
177 |
+
i = len(lab) - 1
|
178 |
+
j = len(rec) - 1
|
179 |
+
while True:
|
180 |
+
if self.space[i][j]['error'] == 'cor': # correct
|
181 |
+
if len(lab[i]) > 0:
|
182 |
+
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
|
183 |
+
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
|
184 |
+
result['all'] = result['all'] + 1
|
185 |
+
result['cor'] = result['cor'] + 1
|
186 |
+
result['lab'].insert(0, lab[i])
|
187 |
+
result['rec'].insert(0, rec[j])
|
188 |
+
i = i - 1
|
189 |
+
j = j - 1
|
190 |
+
elif self.space[i][j]['error'] == 'sub': # substitution
|
191 |
+
if len(lab[i]) > 0:
|
192 |
+
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
|
193 |
+
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
|
194 |
+
result['all'] = result['all'] + 1
|
195 |
+
result['sub'] = result['sub'] + 1
|
196 |
+
result['lab'].insert(0, lab[i])
|
197 |
+
result['rec'].insert(0, rec[j])
|
198 |
+
i = i - 1
|
199 |
+
j = j - 1
|
200 |
+
elif self.space[i][j]['error'] == 'del': # deletion
|
201 |
+
if len(lab[i]) > 0:
|
202 |
+
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
|
203 |
+
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
|
204 |
+
result['all'] = result['all'] + 1
|
205 |
+
result['del'] = result['del'] + 1
|
206 |
+
result['lab'].insert(0, lab[i])
|
207 |
+
result['rec'].insert(0, "")
|
208 |
+
i = i - 1
|
209 |
+
elif self.space[i][j]['error'] == 'ins': # insertion
|
210 |
+
if len(rec[j]) > 0:
|
211 |
+
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
|
212 |
+
result['ins'] = result['ins'] + 1
|
213 |
+
result['lab'].insert(0, "")
|
214 |
+
result['rec'].insert(0, rec[j])
|
215 |
+
j = j - 1
|
216 |
+
elif self.space[i][j]['error'] == 'non': # starting point
|
217 |
+
break
|
218 |
+
else: # shouldn't reach here
|
219 |
+
print('this should not happen , i={i} , j={j} , \
|
220 |
+
error={error}'.format(i=i,
|
221 |
+
j=j,
|
222 |
+
error=self.space[i][j]['error']))
|
223 |
+
return result
|
224 |
+
|
225 |
+
def overall(self):
|
226 |
+
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
|
227 |
+
for token in self.data:
|
228 |
+
result['all'] = result['all'] + self.data[token]['all']
|
229 |
+
result['cor'] = result['cor'] + self.data[token]['cor']
|
230 |
+
result['sub'] = result['sub'] + self.data[token]['sub']
|
231 |
+
result['ins'] = result['ins'] + self.data[token]['ins']
|
232 |
+
result['del'] = result['del'] + self.data[token]['del']
|
233 |
+
return result
|
234 |
+
|
235 |
+
def cluster(self, data):
|
236 |
+
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
|
237 |
+
for token in data:
|
238 |
+
if token in self.data:
|
239 |
+
result['all'] = result['all'] + self.data[token]['all']
|
240 |
+
result['cor'] = result['cor'] + self.data[token]['cor']
|
241 |
+
result['sub'] = result['sub'] + self.data[token]['sub']
|
242 |
+
result['ins'] = result['ins'] + self.data[token]['ins']
|
243 |
+
result['del'] = result['del'] + self.data[token]['del']
|
244 |
+
return result
|
245 |
+
|
246 |
+
def keys(self):
|
247 |
+
return list(self.data.keys())
|
248 |
+
|
249 |
+
|
250 |
+
def width(string):
|
251 |
+
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
|
252 |
+
|
253 |
+
|
254 |
+
def default_cluster(word):
|
255 |
+
unicode_names = [unicodedata.name(char) for char in word]
|
256 |
+
for i in reversed(range(len(unicode_names))):
|
257 |
+
if unicode_names[i].startswith('DIGIT'): # 1
|
258 |
+
unicode_names[i] = 'Number' # 'DIGIT'
|
259 |
+
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH')
|
260 |
+
or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
|
261 |
+
# 明 / 郎
|
262 |
+
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
|
263 |
+
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER')
|
264 |
+
or unicode_names[i].startswith('LATIN SMALL LETTER')):
|
265 |
+
# A / a
|
266 |
+
unicode_names[i] = 'English' # 'LATIN LETTER'
|
267 |
+
elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
|
268 |
+
unicode_names[i] = 'Japanese' # 'GANA LETTER'
|
269 |
+
elif (unicode_names[i].startswith('AMPERSAND')
|
270 |
+
or unicode_names[i].startswith('APOSTROPHE')
|
271 |
+
or unicode_names[i].startswith('COMMERCIAL AT')
|
272 |
+
or unicode_names[i].startswith('DEGREE CELSIUS')
|
273 |
+
or unicode_names[i].startswith('EQUALS SIGN')
|
274 |
+
or unicode_names[i].startswith('FULL STOP')
|
275 |
+
or unicode_names[i].startswith('HYPHEN-MINUS')
|
276 |
+
or unicode_names[i].startswith('LOW LINE')
|
277 |
+
or unicode_names[i].startswith('NUMBER SIGN')
|
278 |
+
or unicode_names[i].startswith('PLUS SIGN')
|
279 |
+
or unicode_names[i].startswith('SEMICOLON')):
|
280 |
+
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
|
281 |
+
del unicode_names[i]
|
282 |
+
else:
|
283 |
+
return 'Other'
|
284 |
+
if len(unicode_names) == 0:
|
285 |
+
return 'Other'
|
286 |
+
if len(unicode_names) == 1:
|
287 |
+
return unicode_names[0]
|
288 |
+
for i in range(len(unicode_names) - 1):
|
289 |
+
if unicode_names[i] != unicode_names[i + 1]:
|
290 |
+
return 'Other'
|
291 |
+
return unicode_names[0]
|
292 |
+
|
293 |
+
|
294 |
+
def usage():
|
295 |
+
print("compute-wer.py : compute word error rate (WER) \
|
296 |
+
and align recognition results and references.")
|
297 |
+
print(" usage : python compute-wer.py [--cs={0,1}] \
|
298 |
+
[--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] \
|
299 |
+
[--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
|
300 |
+
|
301 |
+
|
302 |
+
if __name__ == '__main__':
|
303 |
+
if len(sys.argv) == 1:
|
304 |
+
usage()
|
305 |
+
sys.exit(0)
|
306 |
+
calculator = Calculator()
|
307 |
+
cluster_file = ''
|
308 |
+
ignore_words = set()
|
309 |
+
tochar = False
|
310 |
+
verbose = 1
|
311 |
+
padding_symbol = ' '
|
312 |
+
case_sensitive = False
|
313 |
+
max_words_per_line = sys.maxsize
|
314 |
+
split = None
|
315 |
+
while len(sys.argv) > 3:
|
316 |
+
a = '--maxw='
|
317 |
+
if sys.argv[1].startswith(a):
|
318 |
+
b = sys.argv[1][len(a):]
|
319 |
+
del sys.argv[1]
|
320 |
+
max_words_per_line = int(b)
|
321 |
+
continue
|
322 |
+
a = '--rt='
|
323 |
+
if sys.argv[1].startswith(a):
|
324 |
+
b = sys.argv[1][len(a):].lower()
|
325 |
+
del sys.argv[1]
|
326 |
+
remove_tag = (b == 'true') or (b != '0')
|
327 |
+
continue
|
328 |
+
a = '--cs='
|
329 |
+
if sys.argv[1].startswith(a):
|
330 |
+
b = sys.argv[1][len(a):].lower()
|
331 |
+
del sys.argv[1]
|
332 |
+
case_sensitive = (b == 'true') or (b != '0')
|
333 |
+
continue
|
334 |
+
a = '--cluster='
|
335 |
+
if sys.argv[1].startswith(a):
|
336 |
+
cluster_file = sys.argv[1][len(a):]
|
337 |
+
del sys.argv[1]
|
338 |
+
continue
|
339 |
+
a = '--splitfile='
|
340 |
+
if sys.argv[1].startswith(a):
|
341 |
+
split_file = sys.argv[1][len(a):]
|
342 |
+
del sys.argv[1]
|
343 |
+
split = dict()
|
344 |
+
with codecs.open(split_file, 'r', 'utf-8') as fh:
|
345 |
+
for line in fh: # line in unicode
|
346 |
+
words = line.strip().split()
|
347 |
+
if len(words) >= 2:
|
348 |
+
split[words[0]] = words[1:]
|
349 |
+
continue
|
350 |
+
a = '--ig='
|
351 |
+
if sys.argv[1].startswith(a):
|
352 |
+
ignore_file = sys.argv[1][len(a):]
|
353 |
+
del sys.argv[1]
|
354 |
+
with codecs.open(ignore_file, 'r', 'utf-8') as fh:
|
355 |
+
for line in fh: # line in unicode
|
356 |
+
line = line.strip()
|
357 |
+
if len(line) > 0:
|
358 |
+
ignore_words.add(line)
|
359 |
+
continue
|
360 |
+
a = '--char='
|
361 |
+
if sys.argv[1].startswith(a):
|
362 |
+
b = sys.argv[1][len(a):].lower()
|
363 |
+
del sys.argv[1]
|
364 |
+
tochar = (b == 'true') or (b != '0')
|
365 |
+
continue
|
366 |
+
a = '--v='
|
367 |
+
if sys.argv[1].startswith(a):
|
368 |
+
b = sys.argv[1][len(a):].lower()
|
369 |
+
del sys.argv[1]
|
370 |
+
verbose = 0
|
371 |
+
try:
|
372 |
+
verbose = int(b)
|
373 |
+
except Exception:
|
374 |
+
if b == 'true' or b != '0':
|
375 |
+
verbose = 1
|
376 |
+
continue
|
377 |
+
a = '--padding-symbol='
|
378 |
+
if sys.argv[1].startswith(a):
|
379 |
+
b = sys.argv[1][len(a):].lower()
|
380 |
+
del sys.argv[1]
|
381 |
+
if b == 'space':
|
382 |
+
padding_symbol = ' '
|
383 |
+
elif b == 'underline':
|
384 |
+
padding_symbol = '_'
|
385 |
+
continue
|
386 |
+
if True or sys.argv[1].startswith('-'):
|
387 |
+
# ignore invalid switch
|
388 |
+
del sys.argv[1]
|
389 |
+
continue
|
390 |
+
|
391 |
+
if not case_sensitive:
|
392 |
+
ig = set([w.upper() for w in ignore_words])
|
393 |
+
ignore_words = ig
|
394 |
+
|
395 |
+
default_clusters = {}
|
396 |
+
default_words = {}
|
397 |
+
|
398 |
+
ref_file = sys.argv[1]
|
399 |
+
hyp_file = sys.argv[2]
|
400 |
+
rec_set = {}
|
401 |
+
if split and not case_sensitive:
|
402 |
+
newsplit = dict()
|
403 |
+
for w in split:
|
404 |
+
words = split[w]
|
405 |
+
for i in range(len(words)):
|
406 |
+
words[i] = words[i].upper()
|
407 |
+
newsplit[w.upper()] = words
|
408 |
+
split = newsplit
|
409 |
+
|
410 |
+
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
|
411 |
+
for line in fh:
|
412 |
+
if tochar:
|
413 |
+
array = characterize(line)
|
414 |
+
else:
|
415 |
+
array = line.strip().split()
|
416 |
+
if len(array) == 0:
|
417 |
+
continue
|
418 |
+
fid = array[0]
|
419 |
+
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
|
420 |
+
split)
|
421 |
+
|
422 |
+
# compute error rate on the interaction of reference file and hyp file
|
423 |
+
for line in open(ref_file, 'r', encoding='utf-8'):
|
424 |
+
if tochar:
|
425 |
+
array = characterize(line)
|
426 |
+
else:
|
427 |
+
array = line.rstrip('\n').split()
|
428 |
+
if len(array) == 0:
|
429 |
+
continue
|
430 |
+
fid = array[0]
|
431 |
+
if fid not in rec_set:
|
432 |
+
continue
|
433 |
+
lab = normalize(array[1:], ignore_words, case_sensitive, split)
|
434 |
+
rec = rec_set[fid]
|
435 |
+
if verbose:
|
436 |
+
print('\nutt: %s' % fid)
|
437 |
+
|
438 |
+
for word in rec + lab:
|
439 |
+
if word not in default_words:
|
440 |
+
default_cluster_name = default_cluster(word)
|
441 |
+
if default_cluster_name not in default_clusters:
|
442 |
+
default_clusters[default_cluster_name] = {}
|
443 |
+
if word not in default_clusters[default_cluster_name]:
|
444 |
+
default_clusters[default_cluster_name][word] = 1
|
445 |
+
default_words[word] = default_cluster_name
|
446 |
+
|
447 |
+
result = calculator.calculate(lab, rec)
|
448 |
+
if verbose:
|
449 |
+
if result['all'] != 0:
|
450 |
+
wer = float(result['ins'] + result['sub'] +
|
451 |
+
result['del']) * 100.0 / result['all']
|
452 |
+
else:
|
453 |
+
wer = 0.0
|
454 |
+
print('WER: %4.2f %%' % wer, end=' ')
|
455 |
+
print('N=%d C=%d S=%d D=%d I=%d' %
|
456 |
+
(result['all'], result['cor'], result['sub'], result['del'],
|
457 |
+
result['ins']))
|
458 |
+
space = {}
|
459 |
+
space['lab'] = []
|
460 |
+
space['rec'] = []
|
461 |
+
for idx in range(len(result['lab'])):
|
462 |
+
len_lab = width(result['lab'][idx])
|
463 |
+
len_rec = width(result['rec'][idx])
|
464 |
+
length = max(len_lab, len_rec)
|
465 |
+
space['lab'].append(length - len_lab)
|
466 |
+
space['rec'].append(length - len_rec)
|
467 |
+
upper_lab = len(result['lab'])
|
468 |
+
upper_rec = len(result['rec'])
|
469 |
+
lab1, rec1 = 0, 0
|
470 |
+
while lab1 < upper_lab or rec1 < upper_rec:
|
471 |
+
if verbose > 1:
|
472 |
+
print('lab(%s):' % fid.encode('utf-8'), end=' ')
|
473 |
+
else:
|
474 |
+
print('lab:', end=' ')
|
475 |
+
lab2 = min(upper_lab, lab1 + max_words_per_line)
|
476 |
+
for idx in range(lab1, lab2):
|
477 |
+
token = result['lab'][idx]
|
478 |
+
print('{token}'.format(token=token), end='')
|
479 |
+
for n in range(space['lab'][idx]):
|
480 |
+
print(padding_symbol, end='')
|
481 |
+
print(' ', end='')
|
482 |
+
print()
|
483 |
+
if verbose > 1:
|
484 |
+
print('rec(%s):' % fid.encode('utf-8'), end=' ')
|
485 |
+
else:
|
486 |
+
print('rec:', end=' ')
|
487 |
+
rec2 = min(upper_rec, rec1 + max_words_per_line)
|
488 |
+
for idx in range(rec1, rec2):
|
489 |
+
token = result['rec'][idx]
|
490 |
+
print('{token}'.format(token=token), end='')
|
491 |
+
for n in range(space['rec'][idx]):
|
492 |
+
print(padding_symbol, end='')
|
493 |
+
print(' ', end='')
|
494 |
+
print('\n', end='\n')
|
495 |
+
lab1 = lab2
|
496 |
+
rec1 = rec2
|
497 |
+
|
498 |
+
if verbose:
|
499 |
+
print('==================================================='
|
500 |
+
'========================')
|
501 |
+
print()
|
502 |
+
|
503 |
+
result = calculator.overall()
|
504 |
+
if result['all'] != 0:
|
505 |
+
wer = float(result['ins'] + result['sub'] +
|
506 |
+
result['del']) * 100.0 / result['all']
|
507 |
+
else:
|
508 |
+
wer = 0.0
|
509 |
+
print('Overall -> %4.2f %%' % wer, end=' ')
|
510 |
+
print('N=%d C=%d S=%d D=%d I=%d' %
|
511 |
+
(result['all'], result['cor'], result['sub'], result['del'],
|
512 |
+
result['ins']))
|
513 |
+
if not verbose:
|
514 |
+
print()
|
515 |
+
|
516 |
+
if verbose:
|
517 |
+
for cluster_id in default_clusters:
|
518 |
+
result = calculator.cluster(k
|
519 |
+
for k in default_clusters[cluster_id])
|
520 |
+
if result['all'] != 0:
|
521 |
+
wer = float(result['ins'] + result['sub'] +
|
522 |
+
result['del']) * 100.0 / result['all']
|
523 |
+
else:
|
524 |
+
wer = 0.0
|
525 |
+
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
|
526 |
+
print('N=%d C=%d S=%d D=%d I=%d' %
|
527 |
+
(result['all'], result['cor'], result['sub'], result['del'],
|
528 |
+
result['ins']))
|
529 |
+
if len(cluster_file) > 0: # compute separated WERs for word clusters
|
530 |
+
cluster_id = ''
|
531 |
+
cluster = []
|
532 |
+
for line in open(cluster_file, 'r', encoding='utf-8'):
|
533 |
+
for token in line.rstrip('\n').split():
|
534 |
+
# end of cluster reached, like </Keyword>
|
535 |
+
if token[0:2] == '</' and token[len(token) - 1] == '>' and \
|
536 |
+
token.lstrip('</').rstrip('>') == cluster_id :
|
537 |
+
result = calculator.cluster(cluster)
|
538 |
+
if result['all'] != 0:
|
539 |
+
wer = float(result['ins'] + result['sub'] +
|
540 |
+
result['del']) * 100.0 / result['all']
|
541 |
+
else:
|
542 |
+
wer = 0.0
|
543 |
+
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
|
544 |
+
print('N=%d C=%d S=%d D=%d I=%d' %
|
545 |
+
(result['all'], result['cor'], result['sub'],
|
546 |
+
result['del'], result['ins']))
|
547 |
+
cluster_id = ''
|
548 |
+
cluster = []
|
549 |
+
# begin of cluster reached, like <Keyword>
|
550 |
+
elif (token[0] == '<' and token[len(token) - 1] == '>'
|
551 |
+
and cluster_id == ''):
|
552 |
+
cluster_id = token.lstrip('<').rstrip('>')
|
553 |
+
cluster = []
|
554 |
+
# general terms, like WEATHER / CAR / ...
|
555 |
+
else:
|
556 |
+
cluster.append(token)
|
557 |
+
print()
|
558 |
+
print('======================================='
|
559 |
+
'====================================')
|
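compute-cer.py is driven from the command line: both inputs are plain-text files with one utterance per line, the first whitespace-separated token being the utterance id and the rest the transcript, and passing --char=1 splits the transcript into individual characters (CJK-aware) so Mandarin text is scored at the character level. A minimal way to drive it from Python; the file names ref.txt and hyp.txt are illustrative only.

import subprocess

# One utterance per line: "<utt-id> <transcript>".
with open("ref.txt", "w", encoding="utf-8") as f:
    f.write("utt1 今天天气不错\n")
with open("hyp.txt", "w", encoding="utf-8") as f:
    f.write("utt1 今天天气很好\n")

# --v=1 prints the per-utterance alignment before the overall error rate.
subprocess.run(
    ["python", "evaluation/compute-cer.py", "--char=1", "--v=1", "ref.txt", "hyp.txt"],
    check=True,
)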
evaluation/compute-wer.py
ADDED
@@ -0,0 +1,553 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import re, sys, unicodedata
|
5 |
+
import codecs
|
6 |
+
|
7 |
+
remove_tag = True
|
8 |
+
spacelist = [' ', '\t', '\r', '\n']
|
9 |
+
puncts = [
|
10 |
+
'!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
|
11 |
+
'《', '》'
|
12 |
+
]
|
13 |
+
|
14 |
+
|
15 |
+
def characterize(string):
|
16 |
+
res = []
|
17 |
+
i = 0
|
18 |
+
while i < len(string):
|
19 |
+
char = string[i]
|
20 |
+
if char in puncts:
|
21 |
+
i += 1
|
22 |
+
continue
|
23 |
+
cat1 = unicodedata.category(char)
|
24 |
+
#https://unicodebook.readthedocs.io/unicode.html#unicode-categories
|
25 |
+
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
|
26 |
+
i += 1
|
27 |
+
continue
|
28 |
+
if cat1 == 'Lo': # letter-other
|
29 |
+
res.append(char)
|
30 |
+
i += 1
|
31 |
+
else:
|
32 |
+
# some input looks like: <unk><noise>, we want to separate it to two words.
|
33 |
+
sep = ' '
|
34 |
+
if char == '<': sep = '>'
|
35 |
+
j = i + 1
|
36 |
+
while j < len(string):
|
37 |
+
c = string[j]
|
38 |
+
if ord(c) >= 128 or (c in spacelist) or (c == sep):
|
39 |
+
break
|
40 |
+
j += 1
|
41 |
+
if j < len(string) and string[j] == '>':
|
42 |
+
j += 1
|
43 |
+
res.append(string[i:j])
|
44 |
+
i = j
|
45 |
+
return res
|
46 |
+
|
47 |
+
|
48 |
+
def stripoff_tags(x):
|
49 |
+
if not x: return ''
|
50 |
+
chars = []
|
51 |
+
i = 0
|
52 |
+
T = len(x)
|
53 |
+
while i < T:
|
54 |
+
if x[i] == '<':
|
55 |
+
while i < T and x[i] != '>':
|
56 |
+
i += 1
|
57 |
+
i += 1
|
58 |
+
else:
|
59 |
+
chars.append(x[i])
|
60 |
+
i += 1
|
61 |
+
return ''.join(chars)
|
62 |
+
|
63 |
+
|
64 |
+
def normalize(sentence, ignore_words, cs, split=None):
|
65 |
+
""" sentence, ignore_words are both in unicode
|
66 |
+
"""
|
67 |
+
new_sentence = []
|
68 |
+
for token in sentence:
|
69 |
+
x = token
|
70 |
+
if not cs:
|
71 |
+
x = x.upper()
|
72 |
+
if x in ignore_words:
|
73 |
+
continue
|
74 |
+
if remove_tag:
|
75 |
+
x = stripoff_tags(x)
|
76 |
+
if not x:
|
77 |
+
continue
|
78 |
+
if split and x in split:
|
79 |
+
new_sentence += split[x]
|
80 |
+
else:
|
81 |
+
new_sentence.append(x)
|
82 |
+
return new_sentence
|
83 |
+
|
84 |
+
|
85 |
+
class Calculator:
|
86 |
+
|
87 |
+
def __init__(self):
|
88 |
+
self.data = {}
|
89 |
+
        self.space = []
        self.cost = {}
        self.cost['cor'] = 0
        self.cost['sub'] = 1
        self.cost['del'] = 1
        self.cost['ins'] = 1

    def calculate(self, lab, rec):
        # Initialization
        lab.insert(0, '')
        rec.insert(0, '')
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element['dist'] = 0
                element['error'] = 'non'
            while len(row) < len(rec):
                row.append({'dist': 0, 'error': 'non'})
        for i in range(len(lab)):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(len(rec)):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = 'none'
                dist = self.space[i - 1][j]['dist'] + self.cost['del']
                error = 'del'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                error = 'ins'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                if lab_token == rec_token:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
                    error = 'cor'
                else:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
                    error = 'sub'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]['dist'] = min_dist
                self.space[i][j]['error'] = min_error
        # Tracing back
        result = {
            'lab': [],
            'rec': [],
            'all': 0,
            'cor': 0,
            'sub': 0,
            'ins': 0,
            'del': 0
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            if self.space[i][j]['error'] == 'cor':  # correct
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
                    result['all'] = result['all'] + 1
                    result['cor'] = result['cor'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'sub':  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                    result['all'] = result['all'] + 1
                    result['sub'] = result['sub'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'del':  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                    result['all'] = result['all'] + 1
                    result['del'] = result['del'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, "")
                i = i - 1
            elif self.space[i][j]['error'] == 'ins':  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                    result['ins'] = result['ins'] + 1
                result['lab'].insert(0, "")
                result['rec'].insert(0, rec[j])
                j = j - 1
            elif self.space[i][j]['error'] == 'non':  # starting point
                break
            else:  # shouldn't reach here
                print(
                    'this should not happen , i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=self.space[i][j]['error']))
        return result

    def overall(self):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in self.data:
            result['all'] = result['all'] + self.data[token]['all']
            result['cor'] = result['cor'] + self.data[token]['cor']
            result['sub'] = result['sub'] + self.data[token]['sub']
            result['ins'] = result['ins'] + self.data[token]['ins']
            result['del'] = result['del'] + self.data[token]['del']
        return result

    def cluster(self, data):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in data:
            if token in self.data:
                result['all'] = result['all'] + self.data[token]['all']
                result['cor'] = result['cor'] + self.data[token]['cor']
                result['sub'] = result['sub'] + self.data[token]['sub']
                result['ins'] = result['ins'] + self.data[token]['ins']
                result['del'] = result['del'] + self.data[token]['del']
        return result

    def keys(self):
        return list(self.data.keys())


def width(string):
    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)


def default_cluster(word):
    unicode_names = [unicodedata.name(char) for char in word]
    for i in reversed(range(len(unicode_names))):
        if unicode_names[i].startswith('DIGIT'):  # 1
            unicode_names[i] = 'Number'  # 'DIGIT'
        elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH')
              or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
            # 明 / 郎
            unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
        elif (unicode_names[i].startswith('LATIN CAPITAL LETTER')
              or unicode_names[i].startswith('LATIN SMALL LETTER')):
            # A / a
            unicode_names[i] = 'English'  # 'LATIN LETTER'
        elif unicode_names[i].startswith('HIRAGANA LETTER'):  # は こ め
            unicode_names[i] = 'Japanese'  # 'GANA LETTER'
        elif (unicode_names[i].startswith('AMPERSAND')
              or unicode_names[i].startswith('APOSTROPHE')
              or unicode_names[i].startswith('COMMERCIAL AT')
              or unicode_names[i].startswith('DEGREE CELSIUS')
              or unicode_names[i].startswith('EQUALS SIGN')
              or unicode_names[i].startswith('FULL STOP')
              or unicode_names[i].startswith('HYPHEN-MINUS')
              or unicode_names[i].startswith('LOW LINE')
              or unicode_names[i].startswith('NUMBER SIGN')
              or unicode_names[i].startswith('PLUS SIGN')
              or unicode_names[i].startswith('SEMICOLON')):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            del unicode_names[i]
        else:
            return 'Other'
    if len(unicode_names) == 0:
        return 'Other'
    if len(unicode_names) == 1:
        return unicode_names[0]
    for i in range(len(unicode_names) - 1):
        if unicode_names[i] != unicode_names[i + 1]:
            return 'Other'
    return unicode_names[0]


def usage():
    print(
        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
    )
    print(
        "usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
    )


if __name__ == '__main__':
    if len(sys.argv) == 1:
        usage()
        sys.exit(0)
    calculator = Calculator()
    cluster_file = ''
    ignore_words = set()
    tochar = False
    verbose = 1
    padding_symbol = ' '
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None
    while len(sys.argv) > 3:
        a = '--maxw='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):]
            del sys.argv[1]
            max_words_per_line = int(b)
            continue
        a = '--rt='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            remove_tag = (b == 'true') or (b != '0')
            continue
        a = '--cs='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            case_sensitive = (b == 'true') or (b != '0')
            continue
        a = '--cluster='
        if sys.argv[1].startswith(a):
            cluster_file = sys.argv[1][len(a):]
            del sys.argv[1]
            continue
        a = '--splitfile='
        if sys.argv[1].startswith(a):
            split_file = sys.argv[1][len(a):]
            del sys.argv[1]
            split = dict()
            with codecs.open(split_file, 'r', 'utf-8') as fh:
                for line in fh:  # line in unicode
                    words = line.strip().split()
                    if len(words) >= 2:
                        split[words[0]] = words[1:]
            continue
        a = '--ig='
        if sys.argv[1].startswith(a):
            ignore_file = sys.argv[1][len(a):]
            del sys.argv[1]
            with codecs.open(ignore_file, 'r', 'utf-8') as fh:
                for line in fh:  # line in unicode
                    line = line.strip()
                    if len(line) > 0:
                        ignore_words.add(line)
            continue
        a = '--char='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            tochar = (b == 'true') or (b != '0')
            continue
        a = '--v='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            verbose = 0
            try:
                verbose = int(b)
            except:
                if b == 'true' or b != '0':
                    verbose = 1
            continue
        a = '--padding-symbol='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            if b == 'space':
                padding_symbol = ' '
            elif b == 'underline':
                padding_symbol = '_'
            continue
        if True or sys.argv[1].startswith('-'):
            # ignore invalid switch
            del sys.argv[1]
            continue

    if not case_sensitive:
        ig = set([w.upper() for w in ignore_words])
        ignore_words = ig

    default_clusters = {}
    default_words = {}

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    rec_set = {}
    if split and not case_sensitive:
        newsplit = dict()
        for w in split:
            words = split[w]
            for i in range(len(words)):
                words[i] = words[i].upper()
            newsplit[w.upper()] = words
        split = newsplit

    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
        for line in fh:
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0:
                continue
            fid = array[0]
            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
                                     split)

    # compute error rate on the intersection of reference file and hyp file
    for line in open(ref_file, 'r', encoding='utf-8'):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0:
            continue
        fid = array[0]
        if fid not in rec_set:
            continue
        lab = normalize(array[1:], ignore_words, case_sensitive, split)
        rec = rec_set[fid]
        if verbose:
            print('\nutt: %s' % fid)

        for word in rec + lab:
            if word not in default_words:
                default_cluster_name = default_cluster(word)
                if default_cluster_name not in default_clusters:
                    default_clusters[default_cluster_name] = {}
                if word not in default_clusters[default_cluster_name]:
                    default_clusters[default_cluster_name][word] = 1
                default_words[word] = default_cluster_name

        result = calculator.calculate(lab, rec)
        if verbose:
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] +
                            result['del']) * 100.0 / result['all']
            else:
                wer = 0.0
            print('WER: %4.2f %%' % wer, end=' ')
            print('N=%d C=%d S=%d D=%d I=%d' %
                  (result['all'], result['cor'], result['sub'], result['del'],
                   result['ins']))
            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                if verbose > 1:
                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('lab:', end=' ')
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['lab'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                if verbose > 1:
                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
                else:
                    print('rec:', end=' ')
                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['rec'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2

    if verbose:
        print(
            '==========================================================================='
        )
        print()

    result = calculator.overall()
    if result['all'] != 0:
        wer = float(result['ins'] + result['sub'] +
                    result['del']) * 100.0 / result['all']
    else:
        wer = 0.0
    print('Overall -> %4.2f %%' % wer, end=' ')
    print('N=%d C=%d S=%d D=%d I=%d' %
          (result['all'], result['cor'], result['sub'], result['del'],
           result['ins']))
    if not verbose:
        print()

    if verbose:
        for cluster_id in default_clusters:
            result = calculator.cluster(
                [k for k in default_clusters[cluster_id]])
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] +
                            result['del']) * 100.0 / result['all']
            else:
                wer = 0.0
            print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
            print('N=%d C=%d S=%d D=%d I=%d' %
                  (result['all'], result['cor'], result['sub'], result['del'],
                   result['ins']))
        if len(cluster_file) > 0:  # compute separated WERs for word clusters
            cluster_id = ''
            cluster = []
            for line in open(cluster_file, 'r', encoding='utf-8'):
                for token in line.rstrip('\n').split():
                    # end of cluster reached, like </Keyword>
                    if token[0:2] == '</' and token[len(token) - 1] == '>' and \
                            token.lstrip('</').rstrip('>') == cluster_id:
                        result = calculator.cluster(cluster)
                        if result['all'] != 0:
                            wer = float(result['ins'] + result['sub'] +
                                        result['del']) * 100.0 / result['all']
                        else:
                            wer = 0.0
                        print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
                        print('N=%d C=%d S=%d D=%d I=%d' %
                              (result['all'], result['cor'], result['sub'],
                               result['del'], result['ins']))
                        cluster_id = ''
                        cluster = []
                    # begin of cluster reached, like <Keyword>
                    elif token[0] == '<' and token[len(token) - 1] == '>' and \
                            cluster_id == '':
                        cluster_id = token.lstrip('<').rstrip('>')
                        cluster = []
                    # general terms, like WEATHER / CAR / ...
                    else:
                        cluster.append(token)
        print()
        print(
            '==========================================================================='
        )
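Note: the per-utterance and overall reports above derive WER from the same counts. A minimal sketch of that arithmetic, using a made-up result dict shaped like the one returned by Calculator.overall():

# Illustrative only: the counts below are invented; the formula matches the script above.
result = {'all': 10, 'cor': 8, 'sub': 1, 'del': 1, 'ins': 0}
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
print('WER: %4.2f %%' % wer)  # prints "WER: 20.00 %"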
evaluation/evaluate_asr.py
ADDED
@@ -0,0 +1,379 @@
import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path

import torch
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

import torchaudio
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer


def collate_fn(batches):
    input_ids = [sample["input_ids"] for sample in batches]
    audios = [sample["audios"] for sample in batches]
    audio_indices = [sample["audio_indices"] for sample in batches]

    refs = [sample["ref"] for sample in batches]

    return input_ids, audios, audio_indices, refs


class ASRDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        json_path,
        tokenizer,
        audio_tokenizer,
        default_system_message=None,
        add_generation_prompt=True,
    ):
        data = load_dataset("json", data_files=json_path, keep_in_memory=False)
        self.data = data["train"]

        self.tokenizer = tokenizer
        self.add_generation_prompt = add_generation_prompt

        self.audio_tokenizer = audio_tokenizer
        self.default_system_message = default_system_message

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # print(f"sample {sample}")

        audio_path = sample["audios"][0]

        if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
        else:
            audio_tokens = None

        messages = []
        if len(sample["messages"]) == 2:
            assert sample["messages"][0]["role"] == "user"
            assert sample["messages"][1]["role"] == "assistant"

            if self.default_system_message is not None:
                messages = self.default_system_message + messages

        elif len(sample["messages"]) == 3:
            assert sample["messages"][0]["role"] == "system"
            assert sample["messages"][1]["role"] == "user"
            assert sample["messages"][2]["role"] == "assistant"

        else:
            raise NotImplementedError

        # print(sample)
        for conv in sample["messages"][:-1]:
            new_conv = {}
            new_conv["role"] = conv["role"]

            content = conv["content"]

            if audio_tokens is not None:
                content = content.replace(
                    "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
                )

            new_conv["content"] = content
            messages.append(new_conv)

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
            # return_tensors="pt",
        )

        ref = sample["messages"][-1]["content"]

        if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
            # contiguous codec
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long)
        return {
            "input_ids": input_ids,
            "audios": audios,
            "audio_indices": audio_indices,
            "ref": ref,
        }


class InferenceSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)


def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):

    audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")

    outputs = []

    for _, (batched_input_ids, batched_audios, batched_audio_indices, batched_ref) in enumerate(
        tqdm.tqdm(dataloader)
    ):

        for input_ids, audios, audio_indices, ref in zip(
            batched_input_ids, batched_audios, batched_audio_indices, batched_ref
        ):
            kwargs = {
                # "temperature": 0.2,
                # "top_p": 0.8,
                # "do_sample": False,
                # "temperature": 1.0,
                "max_new_tokens": max([len(x) for x in batched_ref]) + 10,
                "min_new_tokens": 1,
            }
            if audios is not None:
                kwargs["audios"] = audios
                kwargs["audio_indices"] = audio_indices

            responses = model.generate(
                input_ids=input_ids.cuda(),
                **kwargs,
            )

            response = responses[0][len(input_ids[0]) :]

            text_tokens = []
            audio_tokens = []
            for token_id in response:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

            hyp = tokenizer.decode(text_tokens, skip_special_tokens=True)

            outputs.append((hyp, ref))

            print("")
            print("=" * 100)
            print(f"{hyp=}")
            print(f"{ref=}")

    return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")

    parser.add_argument("--json_path", type=str, required=True, help="json_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")

    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)

    args = parser.parse_args()

    print(f"{args=}")

    torch.distributed.init_process_group(
        backend="nccl",
        world_size=int(os.getenv("WORLD_SIZE", "1")),
        rank=int(os.getenv("RANK", "0")),
        timeout=timedelta(seconds=7200),
    )

    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))

    random.seed(42)
    torch.manual_seed(42)

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # ================================================================
    if "glm" in config.model_type.lower():
        from get_chat_template import glm4_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = [
            {
                "role": "system",
                "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
            }
        ]

    if "qwen2" in config.model_type.lower():
        from get_chat_template import qwen2_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = []

    if "hunyuan" in config.model_type.lower():
        from get_chat_template import hunyuan_chat_template as chat_template

        add_generation_prompt = False

        default_system_message = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            }
        ]

    # ================================================================
    print("Loading model")
    # device_map = "auto"
    device_map = "cuda"
    # torch_dtype=torch.float16
    torch_dtype = torch.bfloat16

    rank = torch.distributed.get_rank()

    audio_tokenizer = get_audio_tokenizer(
        args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print("tokenizer", tokenizer)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print("model", model)

    model.generation_config = GenerationConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )

    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = False
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    if model.config.model_type == "hunyuan":
        model.generation_config.eos_token_id = tokenizer.eos_id

    # ================================================================
    print("Loading data")
    dataset = ASRDataset(
        json_path=args.json_path,
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        default_system_message=default_system_message,
        add_generation_prompt=add_generation_prompt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(
            collate_fn,
        ),
    )

    # ================================================================
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]

    if torch.distributed.get_rank() == 0:
        # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
        json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
        hyp_path = os.path.join(args.output_dir, f"{json_name}_hyp.txt")
        ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")

        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        os.makedirs(os.path.dirname(hyp_path), exist_ok=True)

        hyp_file = open(hyp_path, "w")
        ref_file = open(ref_path, "w")

        for sample_idx, (hyp, ref) in enumerate(merged_outputs):
            hyp_file.write(f"{sample_idx} {hyp}" + "\n")
            ref_file.write(f"{sample_idx} {ref}" + "\n")

        hyp_file.close()
        ref_file.close()

        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref.json")
        hyp_ref_file = open(hyp_ref_path, "w")
        json.dump(merged_outputs, hyp_ref_file, indent=4)
        hyp_ref_file.close()

    torch.distributed.barrier()
    print("Done.")
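Note: ASRDataset above reads JSON records whose user turn carries an <|audio|> placeholder that is replaced by discrete audio tokens. A minimal sketch of one such record; the path and text are hypothetical, only the structure follows __getitem__:

# Hypothetical record in the --json_path file consumed by ASRDataset (structure only).
import json

record = {
    "audios": ["wavs/utt_0001.wav"],  # hypothetical audio path
    "messages": [
        {"role": "user", "content": "Convert the speech to text.\n<|audio|>"},
        {"role": "assistant", "content": "hello world"},  # reference transcript
    ],
}
print(json.dumps(record, ensure_ascii=False))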
evaluation/evaluate_libritts.py
ADDED
@@ -0,0 +1,384 @@
import argparse
import itertools
import json
import os
import random
import re
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path

import torch
import tqdm
from datasets import load_dataset
from tn.english.normalizer import Normalizer as EnNormalizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

import torchaudio
from vita_audio.tokenizer import get_audio_tokenizer


def collate_fn(batches):
    input_ids = [sample["input_ids"] for sample in batches]

    refs = [sample["ref"] for sample in batches]
    filenames = [sample["filename"] for sample in batches]

    return input_ids, refs, filenames


class TTSDataset(torch.utils.data.Dataset):
    def __init__(self, json_path, tokenizer, audio_tokenizer, default_system_message=None, add_generation_prompt=True):
        data = load_dataset("json", data_files=json_path, keep_in_memory=False)
        self.data = data["train"]

        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        self.default_system_message = default_system_message
        self.add_generation_prompt = add_generation_prompt

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        messages = []

        if self.default_system_message is not None:
            messages = self.default_system_message + messages

        role = "user"
        content = sample["messages"][0]["content"]
        messages.append(
            {
                "role": role,
                "content": content,
            }
        )

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
            return_tensors="pt",
        )

        ref = sample["messages"][0]["content"]
        ref = ref.replace("Convert the text to speech.\n", "")
        ref = ref.strip()

        filepath = sample["audios"][0]
        filename = os.path.basename(filepath)

        return {
            "input_ids": input_ids,
            "ref": ref,
            "filename": filename,
        }


class InferenceSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)


def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir, asr_model):

    audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
    en_tn_model = EnNormalizer(overwrite_cache=True)

    outputs = []

    for _, (
        batched_input_ids,
        batched_ref,
        batched_filename,
    ) in enumerate(tqdm.tqdm(dataloader)):
        for input_ids, ref, filename in zip(
            batched_input_ids, batched_ref, batched_filename
        ):

            responses = model.generate(
                input_ids=input_ids.cuda(),
                # temperature=0.2,
                # top_p=0.8,
                # do_sample=False,
                # temperature=1.0,
                max_new_tokens=1024,
                min_new_tokens=1,
            )

            response = responses[0][len(input_ids[0]) :]

            text_tokens = []
            audio_tokens = []
            for token_id in response:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

            if len(audio_tokens) == 0:
                continue

            tts_speech = audio_tokenizer.decode(audio_tokens)

            wav_dir = os.path.join(output_dir, "audio")
            wav_path = os.path.join(wav_dir, filename + ".wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

            hyp = asr_model(wav_path, return_timestamps=True)["text"].strip()

            hyp = en_tn_model.normalize(hyp)
            ref = en_tn_model.normalize(ref)

            hyp = re.sub(r"\W+", " ", hyp)
            ref = re.sub(r"\W+", " ", ref)

            outputs.append((hyp, ref))

            print("")
            print("=" * 100)
            # print(f"{len(input_id)=}")
            # print(f"{len(response)=}")
            print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
            print(f"{filename=}")

    return outputs


def load_asr_model():
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

    rank = torch.distributed.get_rank()
    device = f"cuda:{rank}"
    torch_dtype = torch.float16

    model_id = "/data/models/openai/whisper-large-v3"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    return pipe


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")

    parser.add_argument("--json_path", type=str, required=True, help="json_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")

    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)

    parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)

    args = parser.parse_args()

    print(f"{args=}")

    torch.distributed.init_process_group(
        backend="nccl",
        world_size=int(os.getenv("WORLD_SIZE", "1")),
        rank=int(os.getenv("RANK", "0")),
        timeout=timedelta(seconds=7200),
    )

    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))

    random.seed(42)
    torch.manual_seed(42)

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # ================================================================
    if "glm" in config.model_type.lower():
        from get_chat_template import glm4_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = [
            {
                "role": "system",
                "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
            }
        ]

    if "qwen2" in config.model_type.lower():
        from get_chat_template import qwen2_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = []

    if "hunyuan" in config.model_type.lower():
        from get_chat_template import hunyuan_chat_template as chat_template

        add_generation_prompt = False

        default_system_message = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            }
        ]

    # ================================================================
    print("Loading model")
    device = "cuda"
    # device_map = "auto"
    device_map = "cuda"
    # torch_dtype=torch.float16
    torch_dtype = torch.bfloat16

    rank = torch.distributed.get_rank()

    audio_tokenizer = get_audio_tokenizer(
        args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print("tokenizer", tokenizer)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print("model", model)

    model.generation_config = GenerationConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )

    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = True
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    if model.config.model_type == "hunyuan":
        model.generation_config.eos_token_id = tokenizer.eos_id

    asr_model = load_asr_model()

    # ================================================================
    print("Loading data")
    dataset = TTSDataset(
        json_path=args.json_path,
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        default_system_message=default_system_message,
        add_generation_prompt=add_generation_prompt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(
            collate_fn,
        ),
    )

    # ================================================================
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir, asr_model)

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]

    if torch.distributed.get_rank() == 0:
        # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
        json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
        hyp_path = os.path.join(args.output_dir, f"{json_name}_hyp.txt")
        ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")

        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        os.makedirs(os.path.dirname(hyp_path), exist_ok=True)

        hyp_file = open(hyp_path, "w")
        ref_file = open(ref_path, "w")

        for sample_idx, (hyp, ref) in enumerate(merged_outputs):
            hyp_file.write(f"{sample_idx} {hyp}" + "\n")
            ref_file.write(f"{sample_idx} {ref}" + "\n")

        hyp_file.close()
        ref_file.close()

        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref.json")
        hyp_ref_file = open(hyp_ref_path, "w")
        json.dump(merged_outputs, hyp_ref_file, indent=4)
        hyp_ref_file.close()

    torch.distributed.barrier()
    print("Done.")
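Note: before scoring, inference() above normalizes both the Whisper hypothesis and the reference with the English text normalizer and then strips punctuation. A small sketch of that clean-up step, assuming the same tn package used by the script is installed:

# Sketch of the hyp/ref clean-up applied before WER scoring (mirrors inference() above).
import re
from tn.english.normalizer import Normalizer as EnNormalizer

en_tn_model = EnNormalizer(overwrite_cache=True)

def clean(text):
    # expand numbers/abbreviations, then keep word characters only
    text = en_tn_model.normalize(text)
    return re.sub(r"\W+", " ", text)

print(clean("It's 3 o'clock."))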
evaluation/evaluate_seedtts.py
ADDED
@@ -0,0 +1,394 @@
1 |
+
import argparse
|
2 |
+
import itertools
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import sys
|
7 |
+
import uuid
|
8 |
+
from datetime import timedelta
|
9 |
+
from functools import partial
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import tqdm
|
14 |
+
from datasets import load_dataset
|
15 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
16 |
+
from transformers.generation import GenerationConfig
|
17 |
+
|
18 |
+
import torchaudio
|
19 |
+
from vita_audio.tokenizer import get_audio_tokenizer
|
20 |
+
|
21 |
+
|
22 |
+
def collate_fn(batches):
|
23 |
+
input_ids = [sample["input_ids"] for sample in batches]
|
24 |
+
|
25 |
+
refs = [sample["ref"] for sample in batches]
|
26 |
+
filenames = [sample["filename"] for sample in batches]
|
27 |
+
prompt_audio_path = [sample["prompt_audio_path"] for sample in batches]
|
28 |
+
|
29 |
+
return input_ids, refs, filenames, prompt_audio_path
|
30 |
+
|
31 |
+
|
32 |
+
class SeedTTSDataset(torch.utils.data.Dataset):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
data_path,
|
36 |
+
tokenizer,
|
37 |
+
audio_tokenizer,
|
38 |
+
default_system_message=None,
|
39 |
+
speaker_prompt=False,
|
40 |
+
add_generation_prompt=True,
|
41 |
+
):
|
42 |
+
self.data = []
|
43 |
+
|
44 |
+
meta_path = os.path.join(data_path, f"seedtts_testset/zh/meta.lst")
|
45 |
+
with open(meta_path, "r") as f:
|
46 |
+
lines = f.readlines()
|
47 |
+
|
48 |
+
for line in lines:
|
49 |
+
line = line.strip().split("|")
|
50 |
+
filename = line[0]
|
51 |
+
prompt_text = line[1]
|
52 |
+
prompt_audio = line[2]
|
53 |
+
text = line[3]
|
54 |
+
self.data.append(["zh", filename, prompt_text, prompt_audio, text])
|
55 |
+
|
56 |
+
meta_path = os.path.join(data_path, f"seedtts_testset/zh/hardcase.lst")
|
57 |
+
with open(meta_path, "r") as f:
|
58 |
+
lines = f.readlines()
|
59 |
+
|
60 |
+
for line in lines:
|
61 |
+
line = line.strip().split("|")
|
62 |
+
filename = line[0]
|
63 |
+
prompt_text = line[1]
|
64 |
+
prompt_audio = line[2]
|
65 |
+
text = line[3]
|
66 |
+
self.data.append(["hardcase", filename, prompt_text, prompt_audio, text])
|
67 |
+
|
68 |
+
meta_path = os.path.join(data_path, f"seedtts_testset/en/meta.lst")
|
69 |
+
with open(meta_path, "r") as f:
|
70 |
+
lines = f.readlines()
|
71 |
+
|
72 |
+
for line in lines:
|
73 |
+
line = line.strip().split("|")
|
74 |
+
filename = line[0]
|
75 |
+
prompt_text = line[1]
|
76 |
+
prompt_audio = line[2]
|
77 |
+
text = line[3]
|
78 |
+
self.data.append(["en", filename, prompt_text, prompt_audio, text])
|
79 |
+
|
80 |
+
self.tokenizer = tokenizer
|
81 |
+
self.audio_tokenizer = audio_tokenizer
|
82 |
+
self.default_system_message = default_system_message
|
83 |
+
self.add_generation_prompt = add_generation_prompt
|
84 |
+
|
85 |
+
self.data_path = data_path
|
86 |
+
self.speaker_prompt = speaker_prompt
|
87 |
+
|
88 |
+
def __len__(self):
|
89 |
+
return len(self.data)
|
90 |
+
|
91 |
+
def __getitem__(self, idx):
|
92 |
+
sample = self.data[idx]
|
93 |
+
|
94 |
+
split, filename, prompt_text, prompt_audio, text = sample
|
95 |
+
|
96 |
+
messages = []
|
97 |
+
|
98 |
+
if self.default_system_message is not None:
|
99 |
+
messages = self.default_system_message + messages
|
100 |
+
|
101 |
+
if self.speaker_prompt:
|
102 |
+
if split == "hardcase":
|
103 |
+
prompt_audio_path = os.path.join(
|
104 |
+
self.data_path, "seedtts_testset", "zh", prompt_audio
|
105 |
+
)
|
106 |
+
else:
|
107 |
+
prompt_audio_path = os.path.join(
|
108 |
+
self.data_path, "seedtts_testset", split, prompt_audio
|
109 |
+
)
|
110 |
+
|
111 |
+
if self.audio_tokenizer.apply_to_role("system", is_discrete=True):
|
112 |
+
# discrete codec
|
113 |
+
prompt_audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
|
114 |
+
prompt_audio_tokens = "".join(f"<|audio_{i}|>" for i in prompt_audio_tokens)
|
115 |
+
|
116 |
+
prompt_text = f"Speaker Metadata:\nAudio: <|begin_of_audio|>{prompt_audio_tokens}<|end_of_audio|>\n"
|
117 |
+
|
118 |
+
if len(messages) > 0 and messages[0]["role"] == "system":
|
119 |
+
messages[0]["content"] += prompt_text
|
120 |
+
|
121 |
+
else:
|
122 |
+
messages.append(
|
123 |
+
{
|
124 |
+
"role": "system",
|
125 |
+
"content": prompt_text,
|
126 |
+
}
|
127 |
+
)
|
128 |
+
else:
|
129 |
+
prompt_audio_path = None
|
130 |
+
|
131 |
+
role = "user"
|
132 |
+
content = "Convert the text to speech.\n" + text
|
133 |
+
messages.append(
|
134 |
+
{
|
135 |
+
"role": role,
|
136 |
+
"content": content,
|
137 |
+
}
|
138 |
+
)
|
139 |
+
|
140 |
+
input_ids = self.tokenizer.apply_chat_template(
|
141 |
+
messages,
|
142 |
+
tokenize=True,
|
143 |
+
add_generation_prompt=self.add_generation_prompt,
|
144 |
+
return_tensors="pt",
|
145 |
+
)
|
146 |
+
|
147 |
+
ref = text
|
148 |
+
|
149 |
+
return {
|
150 |
+
"input_ids": input_ids,
|
151 |
+
"ref": ref,
|
152 |
+
"filename": split + "/" + filename,
|
153 |
+
"prompt_audio_path": prompt_audio_path,
|
154 |
+
}
|
155 |
+
|
156 |
+
|
157 |
+
class InferenceSampler(torch.utils.data.sampler.Sampler):
|
158 |
+
def __init__(self, size):
|
159 |
+
self._size = int(size)
|
160 |
+
assert size > 0
|
161 |
+
self._rank = torch.distributed.get_rank()
|
162 |
+
self._world_size = torch.distributed.get_world_size()
|
163 |
+
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
|
164 |
+
|
165 |
+
@staticmethod
|
166 |
+
def _get_local_indices(total_size, world_size, rank):
|
167 |
+
shard_size = total_size // world_size
|
168 |
+
left = total_size % world_size
|
169 |
+
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
|
170 |
+
|
171 |
+
begin = sum(shard_sizes[:rank])
|
172 |
+
end = min(sum(shard_sizes[: rank + 1]), total_size)
|
173 |
+
return range(begin, end)
|
174 |
+
|
175 |
+
def __iter__(self):
|
176 |
+
yield from self._local_indices
|
177 |
+
|
178 |
+
def __len__(self):
|
179 |
+
return len(self._local_indices)
|
180 |
+
|
181 |
+
|
182 |
+
def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
|
183 |
+
|
184 |
+
audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
|
185 |
+
|
186 |
+
outputs = []
|
187 |
+
|
188 |
+
for _, (
|
189 |
+
batched_input_ids,
|
190 |
+
batched_ref,
|
191 |
+
batched_filename,
|
192 |
+
batched_prompt_audio_path,
|
193 |
+
) in enumerate(tqdm.tqdm(dataloader)):
|
194 |
+
|
195 |
+
for input_ids, ref, filename, prompt_audio_path in zip(
|
196 |
+
batched_input_ids, batched_ref, batched_filename, batched_prompt_audio_path
|
197 |
+
):
|
198 |
+
responses = model.generate(
|
199 |
+
input_ids=input_ids.cuda(),
|
200 |
+
# temperature=0.2,
|
201 |
+
# top_p=0.8,
|
202 |
+
# do_sample=False,
|
203 |
+
# temperature=1.0,
|
204 |
+
max_new_tokens=1024,
|
205 |
+
min_new_tokens=1,
|
206 |
+
)
|
207 |
+
|
208 |
+
response = responses[0][len(input_ids[0]) :]
|
209 |
+
|
210 |
+
text_tokens = []
|
211 |
+
audio_tokens = []
|
212 |
+
for token_id in response:
|
213 |
+
if token_id >= audio_offset:
|
214 |
+
audio_tokens.append(token_id - audio_offset)
|
215 |
+
else:
|
216 |
+
text_tokens.append(token_id)
|
217 |
+
|
218 |
+
if len(audio_tokens) == 0:
|
219 |
+
continue
|
220 |
+
|
221 |
+
tts_speech = audio_tokenizer.decode(audio_tokens, source_speech_16k=prompt_audio_path)
|
222 |
+
|
223 |
+
wav_path = os.path.join(output_dir, filename + ".wav")
|
224 |
+
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
|
225 |
+
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

            outputs.append((wav_path, filename))

            print("")
            print("=" * 100)
            # print(f"{len(input_id)=}")
            # print(f"{len(response)=}")
            print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
            print(f"{filename=}")

    return outputs


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")

    parser.add_argument("--data_path", type=str, required=True, help="data_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")

    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)

    parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)

    args = parser.parse_args()

    print(f"{args=}")

    torch.distributed.init_process_group(
        backend="nccl",
        world_size=int(os.getenv("WORLD_SIZE", "1")),
        rank=int(os.getenv("RANK", "0")),
        timeout=timedelta(seconds=7200),
    )

    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))

    random.seed(42)
    torch.manual_seed(42)

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # ================================================================
    if "glm" in config.model_type.lower():
        from get_chat_template import glm4_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = [
            {
                "role": "system",
                "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
            }
        ]

    if "qwen2" in config.model_type.lower():
        from get_chat_template import qwen2_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = []

    if "hunyuan" in config.model_type.lower():
        from get_chat_template import hunyuan_chat_template as chat_template

        add_generation_prompt = False

        default_system_message = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            }
        ]

    # ================================================================
    print("Loading model")
    device = "cuda"
    # device_map = "auto"
    device_map = "cuda"
    # torch_dtype=torch.float16
    torch_dtype = torch.bfloat16

    rank = torch.distributed.get_rank()

    audio_tokenizer = get_audio_tokenizer(
        args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print("tokenizer", tokenizer)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print("model", model)

    model.generation_config = GenerationConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )

    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = True
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    if model.config.model_type == "hunyuan":
        model.generation_config.eos_token_id = tokenizer.eos_id

    # ================================================================
    print("Loading data")
    dataset = SeedTTSDataset(
        data_path=args.data_path,
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        default_system_message=default_system_message,
        speaker_prompt=args.speaker_prompt,
        add_generation_prompt=add_generation_prompt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(
            collate_fn,
        ),
    )

    # ================================================================
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]

    torch.distributed.barrier()
    print("Done.")
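Both speech evaluation scripts in this commit share the same multi-GPU bookkeeping: InferenceSampler hands each rank a contiguous slice of the dataset, and the JSON-serialized per-rank results are gathered with all_gather_object and flattened. The snippet below is not one of the committed files; it is a minimal illustrative sketch of that pattern with helper names of my own, and it assumes torch.distributed has already been initialized the way the scripts do.

import itertools
import json

import torch.distributed as dist


def shard_indices(total_size: int, world_size: int, rank: int) -> range:
    # Same arithmetic as InferenceSampler._get_local_indices: the first
    # (total_size % world_size) ranks receive one extra sample.
    shard_size, left = divmod(total_size, world_size)
    shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
    begin = sum(shard_sizes[:rank])
    return range(begin, begin + shard_sizes[rank])


def gather_outputs(local_outputs: list) -> list:
    # Mirrors the scripts: JSON round-trip through all_gather_object, then flatten.
    gathered = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(gathered, json.dumps(local_outputs))
    return list(itertools.chain.from_iterable(json.loads(x) for x in gathered))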
evaluation/evaluate_sqa.py
ADDED
@@ -0,0 +1,451 @@
import argparse
import itertools
import json
import os
import random
import sys
import uuid
from datetime import timedelta
from functools import partial
from pathlib import Path

import torch
import tqdm
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

import torchaudio
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer


def collate_fn(batches):
    input_ids = [sample["input_ids"] for sample in batches]
    audios = [sample["audios"] for sample in batches]
    audio_indices = [sample["audio_indices"] for sample in batches]

    refs = [sample["ref"] for sample in batches]
    filenames = [sample["filename"] for sample in batches]

    return input_ids, audios, audio_indices, refs, filenames


class STSDataset(torch.utils.data.Dataset):
    def __init__(self, json_path, tokenizer, audio_tokenizer, default_system_message=None, add_generation_prompt=True):
        data = load_dataset("json", data_files=json_path, keep_in_memory=False)
        self.data = data["train"]

        self.tokenizer = tokenizer
        self.add_generation_prompt = add_generation_prompt

        self.audio_tokenizer = audio_tokenizer
        self.default_system_message = default_system_message

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        assert len(sample["audios"]) == 1

        audio_path = sample["audios"][0]

        if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
        else:
            audio_tokens = None

        messages = []

        if len(sample["messages"]) == 2:
            assert len(sample["messages"]) == 2
            assert sample["messages"][0]["role"] == "user"
            assert sample["messages"][1]["role"] == "assistant"

            if self.default_system_message is not None:
                messages = self.default_system_message + messages

        elif len(sample["messages"]) == 3:
            assert len(sample["messages"]) == 3
            assert sample["messages"][0]["role"] == "system"
            assert sample["messages"][1]["role"] == "user"
            assert sample["messages"][2]["role"] == "assistant"

        else:
            raise NotImplementedError

        for conv in sample["messages"][:-1]:
            new_conv = {}
            new_conv["role"] = conv["role"]

            content = conv["content"]
            if isinstance(content, list):
                assert len(content) == 1
                content = content[0]

            if audio_tokens is not None:
                content = content.replace(
                    "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
                )

            new_conv["content"] = content
            messages.append(new_conv)

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
            # return_tensors="pt",
        )

        ref = sample["messages"][-1]["content"]

        if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
            # contiguous codec
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long)

        filename = os.path.basename(audio_path)
        filename = os.path.splitext(filename)[0]

        return {
            "input_ids": input_ids,
            "audios": audios,
            "audio_indices": audio_indices,
            "ref": ref,
            "filename": filename,
        }


class InferenceSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[: rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)


def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir, asr_model):

    audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")

    outputs = []

    for _, (batched_input_ids, batched_audios, batched_audio_indices, batched_ref, batched_filename) in enumerate(
        tqdm.tqdm(dataloader)
    ):
        for input_ids, audios, audio_indices, ref, filename in zip(
            batched_input_ids, batched_audios, batched_audio_indices, batched_ref, batched_filename
        ):

            responses = model.generate(
                input_ids=input_ids.cuda(),
                audios=audios,
                audio_indices=audio_indices,
                # temperature=0.2,
                # top_p=0.8,
                # do_sample=False,
                # temperature=1.0,
                max_new_tokens=1024,
                min_new_tokens=1,
            )

            response = responses[0][len(input_ids[0]) :]

            text_tokens = []
            audio_tokens = []
            for token_id in response:
                if token_id >= audio_offset:
                    audio_tokens.append(token_id - audio_offset)
                else:
                    text_tokens.append(token_id)

            hyp_text = tokenizer.decode(text_tokens, skip_special_tokens=True)

            if len(audio_tokens) == 0:
                continue

            tts_speech = audio_tokenizer.decode(audio_tokens)

            wav_dir = os.path.join(output_dir, "audio")
            wav_path = os.path.join(wav_dir, filename + ".wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

            # hyp_speech = asr_model.transcribe(wav_path)["text"].strip()
            hyp_speech = asr_model(wav_path, return_timestamps=True)["text"].strip()
            # hyp_speech = ""

            outputs.append((hyp_text, hyp_speech, ref))

            print("")
            print("=" * 100)
            print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
            print(f" {hyp_text=}")
            print(f"{hyp_speech=}")
            print(f" {ref=}")
            print(f"{filename=}")

    return outputs


def load_asr_model():
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

    rank = torch.distributed.get_rank()
    device = f"cuda:{rank}"
    torch_dtype = torch.float16

    model_id = "/data/models/openai/whisper-large-v3"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    return pipe


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
    parser.add_argument(
        "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
    )
    parser.add_argument(
        "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
    )
    parser.add_argument("--flow_path", type=str, required=True, help="flow_path")

    parser.add_argument("--json_path", type=str, required=True, help="json_path")
    parser.add_argument("--output_dir", type=str, required=True, help="output_dir")

    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=0)

    args = parser.parse_args()

    print(f"{args=}")

    torch.distributed.init_process_group(
        backend="nccl",
        world_size=int(os.getenv("WORLD_SIZE", "1")),
        rank=int(os.getenv("RANK", "0")),
        timeout=timedelta(seconds=7200),
    )

    torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))

    random.seed(42)
    torch.manual_seed(42)

    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )

    # ================================================================
    if "glm" in config.model_type.lower():
        from get_chat_template import glm4_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = [
            {
                "role": "system",
                "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
            }
        ]

    if "qwen2" in config.model_type.lower():
        from get_chat_template import qwen2_chat_template as chat_template

        add_generation_prompt = True

        default_system_message = []

    if "hunyuan" in config.model_type.lower():
        from get_chat_template import hunyuan_chat_template as chat_template

        add_generation_prompt = False

        default_system_message = [
            {
                "role": "system",
                "content": "You are a helpful AI assistant.",
            }
        ]

    default_system_message = [
        {
            "role": "system",
            # "content": "Your Name: Luke\nYour Gender: male\nRespond in a text-audio interleaved manner.",
            # "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.",
            "content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.",
        },
    ]

    # ================================================================
    print("Loading model")
    device = "cuda"
    # device_map = "auto"
    device_map = "cuda"
    # torch_dtype=torch.float16
    torch_dtype = torch.bfloat16

    rank = torch.distributed.get_rank()

    audio_tokenizer = get_audio_tokenizer(
        args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print("tokenizer", tokenizer)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print("model", model)

    model.generation_config = GenerationConfig.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )

    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = False
    model.generation_config.temperature = None
    model.generation_config.top_p = None
    model.generation_config.top_k = None
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    if model.config.model_type == "hunyuan":
        model.generation_config.eos_token_id = tokenizer.eos_id

    asr_model = load_asr_model()

    # ================================================================
    print("Loading data")
    dataset = STSDataset(
        json_path=args.json_path,
        tokenizer=tokenizer,
        audio_tokenizer=audio_tokenizer,
        default_system_message=default_system_message,
        add_generation_prompt=add_generation_prompt,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(
            collate_fn,
        ),
    )

    # ================================================================
    outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir, asr_model)

    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

    merged_outputs = [json.loads(_) for _ in merged_outputs]
    merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]

    if torch.distributed.get_rank() == 0:
        # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
        json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
        hyp_text_path = os.path.join(args.output_dir, f"{json_name}_hyp_text.txt")
        hyp_speech_path = os.path.join(args.output_dir, f"{json_name}_hyp_speech.txt")
        ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")

        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        os.makedirs(os.path.dirname(hyp_text_path), exist_ok=True)
        os.makedirs(os.path.dirname(hyp_speech_path), exist_ok=True)

        hyp_text_file = open(hyp_text_path, "w")
        hyp_speech_file = open(hyp_speech_path, "w")
        ref_file = open(ref_path, "w")

        for sample_idx, (hyp_text, hyp_speech, ref) in enumerate(merged_outputs):
            hyp_text_file.write(f"{sample_idx} {hyp_text}" + "\n")
            hyp_speech_file.write(f"{sample_idx} {hyp_speech}" + "\n")
            ref_file.write(f"{sample_idx} {ref}" + "\n")

        hyp_text_file.close()
        hyp_speech_file.close()
        ref_file.close()

        outputs_speech = [[x[1], x[2]] for x in merged_outputs]
        outputs_text = [[x[0], x[2]] for x in merged_outputs]

        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_text.json")
        hyp_ref_file = open(hyp_ref_path, "w")
        json.dump(outputs_text, hyp_ref_file, indent=4)
        hyp_ref_file.close()

        hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_speech.json")
        hyp_ref_file = open(hyp_ref_path, "w")
        json.dump(outputs_speech, hyp_ref_file, indent=4)
        hyp_ref_file.close()

    torch.distributed.barrier()
    print("Done.")
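The central post-processing step in evaluate_sqa.py (and the other evaluation scripts) is splitting the generated token ids into a text stream and an audio-codec stream around the id of the <|audio_0|> token, exactly as the inference loop above does. Below is a small standalone sketch of that split; it is illustrative only, and the offset value in the example is made up.

from typing import List, Tuple


def split_response(response: List[int], audio_offset: int) -> Tuple[List[int], List[int]]:
    # Ids at or above audio_offset are audio codec tokens; re-base them so they
    # can be passed to the codec decoder. Everything else is ordinary text.
    text_tokens: List[int] = []
    audio_tokens: List[int] = []
    for token_id in response:
        if token_id >= audio_offset:
            audio_tokens.append(token_id - audio_offset)
        else:
            text_tokens.append(token_id)
    return text_tokens, audio_tokens


# Toy example with an assumed offset of 100: ids 100 and above count as audio.
assert split_response([5, 101, 7, 150], audio_offset=100) == ([5, 7], [1, 50])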
evaluation/get_chat_template.py
ADDED
@@ -0,0 +1,59 @@
1 |
+
qwen2_chat_template = """
|
2 |
+
{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n
|
3 |
+
"""
|
4 |
+
|
5 |
+
qwen3_chat_template = """
|
6 |
+
"{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n 
{{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
|
7 |
+
"""
|
8 |
+
|
9 |
+
hunyuan_chat_template = """
|
10 |
+
{% set context = {'has_head': true} %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = message['content'] %}{% if loop.index0 == 0 %}{% if content == '' %}{% set _ = context.update({'has_head': false}) %}{% else %}{% set content = '<|startoftext|>' + content + '<|extra_4|>' %}{% endif %}{% endif %}{% if message['role'] == 'user' %}{% if loop.index0 == 1 and not context.has_head %}{% set content = '<|startoftext|>' + content %}{% endif %}{% if loop.index0 == 1 and context.has_head %}{% set content = content + '<|extra_0|>' %}{% else %}{% set content = '<|startoftext|>' + content + '<|extra_0|>' %}{% endif %}{% elif message['role'] == 'assistant' %}{% set content = content + '<|eos|>' %}{% endif %}{{ content }}{% endfor %}
|
11 |
+
"""
|
12 |
+
|
13 |
+
glm4_chat_template = """
|
14 |
+
{%- for message in messages %} {%- if (message.role == "system") %} {{- '<|system|>' + '\n' + message.content }} {%- elif (message.role == "user") %} {{- '<|user|>' + '\n' + message.content }} {%- elif message.role == "assistant" %} {{- '<|assistant|>' }} {%- if message.content %} {{- 'streaming_transcription\n' + message.content }} {%- endif %} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} {{- '<|assistant|>streaming_transcription\n' }} {%- endif %}
|
15 |
+
"""
|
16 |
+
|
17 |
+
if __name__ == "__main__":
|
18 |
+
from transformers import AutoTokenizer
|
19 |
+
|
20 |
+
chat = [
|
21 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
22 |
+
{"role": "user", "content": "Hello, how are you?"},
|
23 |
+
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
24 |
+
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
25 |
+
]
|
26 |
+
|
27 |
+
# print("=" * 100)
|
28 |
+
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")
|
29 |
+
# print(tokenizer.get_chat_template())
|
30 |
+
# message = tokenizer.apply_chat_template(chat, tokenize=False)
|
31 |
+
# print(message)
|
32 |
+
|
33 |
+
print("=" * 100)
|
34 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-14B-Instruct")
|
35 |
+
print(tokenizer.get_chat_template())
|
36 |
+
message = tokenizer.apply_chat_template(chat, tokenize=False)
|
37 |
+
print(message)
|
38 |
+
|
39 |
+
print("=" * 100)
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-70B-Instruct")
|
41 |
+
print(tokenizer.get_chat_template())
|
42 |
+
message = tokenizer.apply_chat_template(chat, tokenize=False)
|
43 |
+
print(message)
|
44 |
+
|
45 |
+
print("=" * 100)
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
|
47 |
+
print(tokenizer.get_chat_template())
|
48 |
+
message = tokenizer.apply_chat_template(chat, tokenize=False)
|
49 |
+
print(message)
|
50 |
+
message = tokenizer.apply_chat_template(chat, tokenize=True)
|
51 |
+
print(message)
|
52 |
+
|
53 |
+
print("=" * 100)
|
54 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
55 |
+
print(tokenizer.get_chat_template())
|
56 |
+
message = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, enable_thinking=True)
|
57 |
+
print(message)
|
58 |
+
message = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
59 |
+
print(message)
|
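These template strings are meant to be handed straight to the tokenizer, which is what the evaluation scripts do via AutoTokenizer.from_pretrained(..., chat_template=...). A short usage sketch follows; the checkpoint name below is only a placeholder, not something this repository pins.

from transformers import AutoTokenizer

from get_chat_template import qwen2_chat_template

# Any Qwen2-family checkpoint would work here; the name is an example only.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    chat_template=qwen2_chat_template,
)
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)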
requirements.txt
ADDED
@@ -0,0 +1 @@
-r requirements_ds_gpu.txt
requirements_ds_gpu.txt
ADDED
@@ -0,0 +1,44 @@
expecttest
peft
xlsxwriter
termcolor
tabulate
tiktoken
matplotlib
datasets
einops
pybind11
tensorboardX
pyarrow
transformers==4.48.3
deepspeed
accelerate>=1.1.1
timm
flask
flask_restful
decord
natsort
# setuptools==69.5.1
setuptools

# cosyvoice2
pyworld
evaluate
hyperpyyaml
diffusers
conformer
hydra-core
lightning
gdown
wget
funasr
zhconv
jiwer
zhon
WeTextProcessing
inflect
openai-whisper
onnxruntime
modelscope
word2number
scripts/deepspeed/ds_config_zero1.json
ADDED
@@ -0,0 +1,61 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_ratio": 0,
            "warmup_num_steps": "auto",
            "cos_min_ratio": 0.1
        }
    },

    "zero_optimization": {
        "stage": 1,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "offload_param": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "round_robin_gradients": true,
        "sub_group_size": 1e12
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "dump_state": false

}
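This file (and its siblings below) follows the HuggingFace Trainer DeepSpeed integration convention: every "auto" value is resolved from TrainingArguments when the engine is built. The snippet below is a hedged illustration of that wiring, not code from this repository; the output path and hyperparameters are placeholders.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",                 # placeholder
    per_device_train_batch_size=1,         # resolves train_micro_batch_size_per_gpu="auto"
    gradient_accumulation_steps=4,         # resolves gradient_accumulation_steps="auto"
    learning_rate=2e-5,                    # resolves the optimizer's lr="auto"
    bf16=True,                             # matches "bf16": {"enabled": "auto"}
    deepspeed="scripts/deepspeed/ds_config_zero1.json",
)
# A Trainer constructed with these arguments initializes DeepSpeed ZeRO-1 from
# the JSON above; model and dataset creation are omitted in this sketch.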
scripts/deepspeed/ds_config_zero2.json
ADDED
@@ -0,0 +1,61 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_ratio": 0,
            "warmup_num_steps": "auto",
            "cos_min_ratio": 0.1
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "offload_param": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "round_robin_gradients": true,
        "sub_group_size": 1e12
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "dump_state": false

}
scripts/deepspeed/ds_config_zero2_no_optimizer.json
ADDED
@@ -0,0 +1,52 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },

    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_ratio": 0,
            "warmup_num_steps": "auto",
            "cos_min_ratio": 0.1
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "offload_param": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "round_robin_gradients": true,
        "sub_group_size": 1e12
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "dump_state": false

}
scripts/deepspeed/ds_config_zero2_offload.json
ADDED
@@ -0,0 +1,61 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_ratio": 0,
            "warmup_num_steps": "auto",
            "cos_min_ratio": 0.1
        }
    },

    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "round_robin_gradients": true,
        "sub_group_size": 1e12
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "dump_state": false

}
scripts/deepspeed/ds_config_zero3.json
ADDED
@@ -0,0 +1,63 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_ratio": 0,
            "warmup_num_steps": "auto",
            "cos_min_ratio": 0.1
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "none",
            "pin_memory": true
        },
        "offload_param": {
            "device": "none",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 1e9,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "stage3_prefetch_bucket_size": 1e9,
        "stage3_param_persistence_threshold": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
scripts/deepspeed/ds_config_zero3_offload.json
ADDED
@@ -0,0 +1,75 @@
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_ratio": 0,
            "warmup_num_steps": "auto",
            "cos_min_ratio": 0.1
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "round_robin_gradients": true,
        "sub_group_size": 1e12,
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e5,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "flops_profiler": {
        "enabled": false,
        "profile_step": 1,
        "module_depth": -1,
        "top_modules": 1,
        "detailed": true,
        "output_file": null
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 100,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "dump_state": false

}
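The six ZeRO configs above differ mainly in the optimization stage and in whether optimizer and parameter state are offloaded to CPU. A quick, illustrative way to compare them (this snippet is not part of the repository and assumes it is run from the repository root):

import glob
import json

for path in sorted(glob.glob("scripts/deepspeed/ds_config_zero*.json")):
    with open(path) as f:
        cfg = json.load(f)
    zero = cfg["zero_optimization"]
    print(
        path,
        "stage:", zero["stage"],
        "optimizer offload:", zero["offload_optimizer"]["device"],
        "param offload:", zero["offload_param"]["device"],
    )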
scripts/deepspeed/evaluate_sts.sh
ADDED
@@ -0,0 +1,348 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H%M%S'`
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt install -y rsync
mkdir -p ${LOCAL_CODE_PATH}
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME/"
mkdir -p ${HF_HOME}
export HF_ENDPOINT=https://hf-mirror.com

export MODELSCOPE_CACHE="${ROOT_PATH}/data/MODELSCOPE_CACHE/"
mkdir -p ${MODELSCOPE_CACHE}

export LC_ALL="en_US.utf8"

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"


######################################################################
if true
#if false
then
MODEL_NAME_OR_PATH="/data/output/LM/scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh/VITA-Audio-Boost/"
MODEL_NAME_OR_PATH="/data/output/LM/scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh/VITA-Audio-Balance/"

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
FLOW_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-decoder
AUDIO_TOKENIZER_TYPE="glm4voice"

export PYTHONPATH=${PYTHONPATH}:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/cosyvoice/:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

fi

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

######################################################################
if true
#if false
then
apt-get update && apt install -y ffmpeg

JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/llama-questions/test.jsonl
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_sqa.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/llama-questions/

python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/llama-questions/test_hyp_ref_text.json
echo "copypaste ACC: ${JSON_PATH}"
python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/llama-questions/test_hyp_ref_speech.json
echo "copypaste ACC: ${JSON_PATH}"


JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/trivia_qa-audio/validation.jsonl
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_sqa.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/trivia_qa-audio/

python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/trivia_qa-audio/validation_hyp_ref_text.json
echo "copypaste ACC: ${JSON_PATH}"
python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/trivia_qa-audio/validation_hyp_ref_speech.json
echo "copypaste ACC: ${JSON_PATH}"


JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/spoken-web-questions/test.jsonl
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_sqa.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/spoken-web-questions/

python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/spoken-web-questions/test_hyp_ref_text.json
echo "copypaste ACC: ${JSON_PATH}"
python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/spoken-web-questions/test_hyp_ref_speech.json
echo "copypaste ACC: ${JSON_PATH}"

fi


######################################################################
if true
#if false
then
JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/validation.clean.jsonl

torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/librispeech_asr/

#python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.clean_hyp.txt
#echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.clean_hyp.txt
echo "copypaste WER: ${JSON_PATH}"

JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/validation.other.jsonl

torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/librispeech_asr/

#python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.other_hyp.txt
#echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.other_hyp.txt
echo "copypaste WER: ${JSON_PATH}"


JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/test.clean.jsonl

torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/librispeech_asr/

#python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.clean_hyp.txt
#echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.clean_hyp.txt
echo "copypaste WER: ${JSON_PATH}"

JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/test.other.jsonl

torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/librispeech_asr/

#python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.other_hyp.txt
#echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.other_hyp.txt
echo "copypaste WER: ${JSON_PATH}"

fi


######################################################################
if true
#if false
then
JSON_PATH=${ROOT_PATH}/data/jsonl/wenet-e2e/wenetspeech/TEST_MEETING.jsonl
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/wenetspeech/

python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_hyp.txt
echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_hyp.txt
echo "copypaste WER: ${JSON_PATH}"

JSON_PATH=${ROOT_PATH}/data/jsonl/wenet-e2e/wenetspeech/TEST_NET.jsonl
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/wenetspeech/

python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_NET_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_NET_hyp.txt
echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_NET_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_NET_hyp.txt
echo "copypaste WER: ${JSON_PATH}"
fi


######################################################################
if true
#if false
then
JSON_PATH=${ROOT_PATH}/data/jsonl/shenyunhang/AISHELL-1/test.jsonl

torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/AISHELL-1/

#python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/AISHELL-1/_test.clean_ref.txt ${OUTPUT_DIR}/AISHELL-1/test.clean_hyp.txt
#echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/AISHELL-1/test_ref.txt ${OUTPUT_DIR}/AISHELL-1/test_hyp.txt
echo "copypaste WER: ${JSON_PATH}"


fi


######################################################################
if true
#if false
then
JSON_PATH=${ROOT_PATH}/data/jsonl/mythicinfinity/libritts/test.clean.jsonl
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_libritts.py \
    --json_path ${JSON_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/libritts/ \

#python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/libritts/test.clean_ref.txt ${OUTPUT_DIR}/libritts/test.clean_hyp.txt
#echo "copypaste CER: ${JSON_PATH}"
python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/libritts/test.clean_ref.txt ${OUTPUT_DIR}/libritts/test.clean_hyp.txt
echo "copypaste WER: ${JSON_PATH}"
fi


######################################################################
if true
#if false
then

DATA_PATH=${ROOT_PATH}/data/BytedanceSpeech/seed-tts-eval/
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_seedtts.py \
    --data_path ${DATA_PATH} \
    --model_name_or_path ${MODEL_NAME_OR_PATH} \
    --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
    --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
    --flow_path ${FLOW_PATH} \
    --output_dir ${OUTPUT_DIR}/seed-tts/ \
    --speaker_prompt \

export ARNOLD_WORKER_GPU=${NPROC_PER_NODE}
cd ${LOCAL_CODE_PATH}/third_party/seed-tts-eval

bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/meta.lst ${OUTPUT_DIR}/seed-tts/zh/ zh
echo "copypaste WER: ${DATA_PATH} zh"
bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/hardcase.lst ${OUTPUT_DIR}/seed-tts/hardcase/ zh
echo "copypaste WER: ${DATA_PATH} hardcase"
bash cal_wer.sh ${DATA_PATH}/seedtts_testset/en/meta.lst ${OUTPUT_DIR}/seed-tts/en/ en
echo "copypaste WER: ${DATA_PATH} en"

bash cal_sim.sh ${DATA_PATH}/seedtts_testset/zh/meta.lst ${OUTPUT_DIR}/seed-tts/zh/ ${DATA_PATH}/wavlm_large_finetune.pth
|
310 |
+
echo "copypaste SIM: ${DATA_PATH} zh"
|
311 |
+
bash cal_sim.sh ${DATA_PATH}/seedtts_testset/zh/hardcase.lst ${OUTPUT_DIR}/seed-tts/hardcase/ ${DATA_PATH}/wavlm_large_finetune.pth
|
312 |
+
echo "copypaste SIM: ${DATA_PATH} hardcase"
|
313 |
+
bash cal_sim.sh ${DATA_PATH}/seedtts_testset/en/meta.lst ${OUTPUT_DIR}/seed-tts/en/ ${DATA_PATH}/wavlm_large_finetune.pth
|
314 |
+
echo "copypaste SIM: ${DATA_PATH} en"
|
315 |
+
|
316 |
+
cd ${LOCAL_CODE_PATH}
|
317 |
+
|
318 |
+
fi
|
319 |
+
|
320 |
+
|
321 |
+
######################################################################
|
322 |
+
if false
|
323 |
+
then
|
324 |
+
DATA_PATH=${ROOT_PATH}/data/BytedanceSpeech/seed-tts-eval/
|
325 |
+
torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_seedtts.py \
|
326 |
+
--data_path ${DATA_PATH} \
|
327 |
+
--model_name_or_path ${MODEL_NAME_OR_PATH} \
|
328 |
+
--audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
|
329 |
+
--audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
|
330 |
+
--flow_path ${FLOW_PATH} \
|
331 |
+
--output_dir ${OUTPUT_DIR}/seed-tts/ \
|
332 |
+
|
333 |
+
export ARNOLD_WORKER_GPU=${NPROC_PER_NODE}
|
334 |
+
cd ${LOCAL_CODE_PATH}/third_party/seed-tts-eval
|
335 |
+
|
336 |
+
bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/meta.lst ${OUTPUT_DIR}/seed-tts/zh/ zh
|
337 |
+
echo "copypaste WER: ${DATA_PATH} zh"
|
338 |
+
bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/hardcase.lst ${OUTPUT_DIR}/seed-tts/hardcase/ zh
|
339 |
+
echo "copypaste WER: ${DATA_PATH} hardcase"
|
340 |
+
bash cal_wer.sh ${DATA_PATH}/seedtts_testset/en/meta.lst ${OUTPUT_DIR}/seed-tts/en/ en
|
341 |
+
echo "copypaste WER: ${DATA_PATH} en"
|
342 |
+
|
343 |
+
cd ${LOCAL_CODE_PATH}
|
344 |
+
|
345 |
+
fi
|
346 |
+
|
347 |
+
|
348 |
+
set +x
|
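Note: compute-wer.py and compute-cer.py score a reference file against a hypothesis file written by the evaluation scripts above. A minimal standalone sketch follows; the per-line layout (utterance id followed by its transcript) is an assumption, not something this diff confirms:

# hypothetical ref/hyp files, one utterance per line: "<utt_id> <text>"
# ref.txt: utt_0001 the quick brown fox
# hyp.txt: utt_0001 the quick brown fox
python evaluation/compute-wer.py --char=1 --v=1 ref.txt hyp.txt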
scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh
ADDED
@@ -0,0 +1,137 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh/20250313_040353/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name vita_audio/models/qwen2_mtp_v4_48_3/config_7B_mtp10.json \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 8000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 1.00e-3 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 8 \
    --language-model-freeze \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
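A minimal launch sketch for this stage-1 script, assuming a single node where set_env_ds_gpu.sh fills in the torchrun defaults; both positional arguments (sequence length, then timestamp) are optional and the values shown are illustrative:

bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh 32768 20250101_000000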
scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh
ADDED
@@ -0,0 +1,137 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage2.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/s2s_qwen25/finetune_glm4voice_mtp10_stage1.sh/20250315_022047/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name ${MODEL_NAME_OR_PATH} \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 4000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 5.00e-5 \
    --max_grad_norm 1.0 \
    --weight_decay 0.1 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2_no_optimizer.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 8 \
    --mtp_model_lr_mult 1.00e1 \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
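The stage-2 script resumes from a stage-1 output directory; MODEL_NAME_OR_PATH is hard-coded above, so point it at an existing stage-1 checkpoint folder before launching. A hedged example (the checkpoint path is illustrative):

# edit MODEL_NAME_OR_PATH in the script to your stage-1 output, e.g.
# /data/output/LM/scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh/<timestamp>/
bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh 32768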
scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh
ADDED
@@ -0,0 +1,137 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/s2s_qwen25/finetune_glm4voice_stage1.sh/20250222_043913/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name vita_audio/models/qwen2_mtp_v4_48_3/config_7B_mtp1.json \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 4000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 1.00e-3 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 8 \
    --language-model-freeze \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
scripts/deepspeed/sts_qwen25/finetune_glm4voice_stage1.sh
ADDED
@@ -0,0 +1,136 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/models/Qwen/Qwen2.5-7B-Instruct/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name ${MODEL_NAME_OR_PATH} \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 8000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 6.00e-5 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 8 \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage1.sh
ADDED
@@ -0,0 +1,138 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp1_stage1.sh/20250418_075843/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp10.json \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "sensevoice_glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 8000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 1.00e-3 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 8 \
    --audio-model-freeze \
    --language-model-freeze \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage2.sh
ADDED
@@ -0,0 +1,138 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage2.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage1.sh/20250421_180624/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp10.json \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "sensevoice_glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 4000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 5.00e-5 \
    --max_grad_norm 1.0 \
    --weight_decay 0.1 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2_no_optimizer.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 2 \
    --mtp_model_lr_mult 1.00e1 \
    --audio-model-freeze \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp1_stage1.sh
ADDED
@@ -0,0 +1,138 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_stage1.sh/20250409_161438/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp1.json \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "sensevoice_glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 4000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 1.00e-3 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 8 \
    --audio-model-freeze \
    --language-model-freeze \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_stage1.sh
ADDED
@@ -0,0 +1,137 @@
#!/bin/bash

set -e
set -x

SEQ_LENGTH="$1"
if [ -z "$SEQ_LENGTH" ]
then
SEQ_LENGTH=32768
fi

timestamp="$2"
if [ -z "$timestamp" ]
then
timestamp=`date +'%Y%m%d_%H'`0000
fi

######################################################################
export ROOT_PATH=/data/
export CODE_PATH=${ROOT_PATH}/VITA-Audio/

export LOCAL_ROOT_PATH=/data_local/
export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
mkdir -p ${LOCAL_ROOT_PATH}
mkdir -p ${LOCAL_CODE_PATH}

apt update
apt install -y rsync
rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/

cd ${LOCAL_CODE_PATH}
rm -fr datasets
ln -s ${ROOT_PATH}/data datasets

######################################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
pip3 install transformers==4.48.3
#pip3 install --no-index --find-links=/data/software/ transformers==4.48.3

######################################################################
OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/

mkdir -p ${OUTPUT_DIR}
rsync -avh $0 ${OUTPUT_DIR}

export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
mkdir -p ${HF_HOME}

export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}

export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/

######################################################################
LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
exec &> >(tee -a "$LOG")
echo Logging output to "$LOG"

echo ${@}

######################################################################
DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml

MODEL_NAME_OR_PATH=${ROOT_PATH}/models/Qwen/Qwen2.5-7B-Instruct/

AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer

rsync -avh ${DATA_PATH} ${OUTPUT_DIR}

######################################################################
DISTRIBUTED_ARGS="
    --nproc_per_node $NPROC_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
    --log_level "info" \
    --do_train \
    --overwrite_output_dir \
    --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp0.json \
    --tokenizer_name $MODEL_NAME_OR_PATH \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
    --audio_tokenizer_type "sensevoice_glm4voice" \
    --dataset_name $DATA_PATH \
    --bf16 True \
    --tf32 True \
    --torch_dtype bfloat16 \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --max_steps 8000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --save_strategy "steps" \
    --save_steps 0.1 \
    --save_total_limit 2 \
    --learning_rate 6.00e-5 \
    --max_grad_norm 1.0 \
    --weight_decay 0.0 \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --adam_epsilon 1e-8 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "tensorboard" \
    --model_max_length ${SEQ_LENGTH} \
    --gradient_checkpointing True \
    --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
    --trust_remote_code False \
    --ddp_timeout 7200 \
    --ddp_backend ${DISTRIBUTED_BACKEND} \
    --attn_implementation flash_attention_2 \
    --seed 42 \
    --data_seed 42 \
    --reset_attention_mask \
    --reset_position_ids \
    --create_attention_mask false \
    --create_attention_mask_2d false \
    --dataloader_num_workers 8 \
    --audio-model-freeze \
    --text-audio-interval-ratio 1 10 4 10 \

    #--language-model-freeze \
    #--dataset_joint false \
    #--variable_length true \
    #--tokenizer_name_or_path Qwen2Tokenizer \

    #--bf16 True \
    #--fp16 True \
    #--tf32 True \

set +x
scripts/set_env_ds_gpu.sh
ADDED
@@ -0,0 +1,53 @@
#set -e
#set -x

######################################################################
#export NCCL_NET=IB

#export NCCL_SOCKET_IFNAME="bond1"
#export GLOO_SOCKET_IFNAME="bond1"
#export NCCL_DEBUG=INFO
#export NCCL_IB_QPS_PER_CONNECTION=2

#export GLOO_SOCKET_IFNAME=eth0
#export NCCL_DEBUG=INFO
#export NCCL_IB_QPS_PER_CONNECTION=2

#export NCCL_IB_DISABLE=1

export DISTRIBUTED_BACKEND="nccl"
export CUDA_DEVICE_MAX_CONNECTIONS=1

######################################################################
pip3 install -r requirements_ds_gpu.txt
#pip3 install --no-index --find-links=/data/software/ -r requirements_ds_gpu.txt

pip3 install deepspeed==0.15.4
#pip3 install --no-index --find-links=/data/software/ deepspeed==0.15.4
#pip3 install deepspeed==0.16.1
#pip3 install deepspeed==0.14.2

pip3 install -e `pwd`

######################################################################
#export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512

#apt update
#apt install -y openssh-server
#apt install -y rsync

######################################################################

export NNODES=${WORLD_SIZE}
export NODE_RANK=${RANK}
export MASTER_PORT=34567

if [ -z "$NPROC_PER_NODE" ]
then
export NPROC_PER_NODE=8
export NNODES=1
export NODE_RANK=0
export MASTER_ADDR=127.0.0.1
fi

######################################################################
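set_env_ds_gpu.sh derives the torchrun topology from environment variables and falls back to a single-node default when NPROC_PER_NODE is unset. A hedged multi-node sketch (the counts and address are placeholders, not values from this repo) would export these before the finetune script sources it:

export WORLD_SIZE=2          # number of nodes
export RANK=0                # this node's rank
export NPROC_PER_NODE=8      # GPUs per node
export MASTER_ADDR=10.0.0.1  # placeholder address of rank-0 node
source scripts/set_env_ds_gpu.sh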
setup.py
ADDED
@@ -0,0 +1,12 @@
from setuptools import find_packages, setup

setup(
    name='vita_audio',
    version='0.0.1',
    packages=[
        "vita_audio",
    ],
    install_requires=[
    ],
)
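setup.py registers the vita_audio package; set_env_ds_gpu.sh installs it in editable mode, which can also be done by hand from the repository root:

pip3 install -e .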
third_party/GLM-4-Voice/.gitignore
ADDED
@@ -0,0 +1,4 @@
*venv
*.DS_Store
*.idea/
test*
third_party/GLM-4-Voice/.gitmodules
ADDED
@@ -0,0 +1,3 @@
[submodule "third_party/Matcha-TTS"]
    path = third_party/Matcha-TTS
    url = https://github.com/shivammehta25/Matcha-TTS
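Matcha-TTS is tracked as a git submodule of GLM-4-Voice, so it has to be fetched separately after cloning; the standard git command applies:

git submodule update --init --recursive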
third_party/GLM-4-Voice/LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2024 GLM-4-Voice Model Team @ Zhipu AI
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
third_party/GLM-4-Voice/README.md
ADDED
@@ -0,0 +1,159 @@
# GLM-4-Voice
<p align="center">
    📄<a href="https://arxiv.org/abs/2412.02612" target="_blank"> Report </a> • 🤗 <a href="https://huggingface.co/THUDM/glm-4-voice-9b" target="_blank">HF Repo</a> • 🤖 <a href="https://modelscope.cn/studios/ZhipuAI/GLM-4-Voice-Demo" target="_blank">Demo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a>
</p>

Read this in [English](./README_en.md)

GLM-4-Voice is an end-to-end speech model released by Zhipu AI. It can directly understand and generate Chinese and English speech, hold real-time voice conversations, and change attributes of the speech such as emotion, intonation, speech rate, and dialect according to the user's instructions.

## Model Architecture
![Model Architecture](./resources/architecture.jpeg)

GLM-4-Voice consists of three components:
* GLM-4-Voice-Tokenizer: converts continuous speech input into discrete tokens by adding vector quantization to the encoder of [Whisper](https://github.com/openai/whisper) and training it with supervision on ASR data. On average, each second of audio is represented by only 12.5 discrete tokens.
* GLM-4-Voice-Decoder: a speech decoder supporting streaming inference, trained on the flow-matching architecture of [CosyVoice](https://github.com/FunAudioLLM/CosyVoice); it converts discrete speech tokens back into continuous speech output. Generation can start from as few as 10 speech tokens, lowering end-to-end conversation latency.
* GLM-4-Voice-9B: pre-trained and aligned on the speech modality on top of [GLM-4-9B](https://github.com/THUDM/GLM-4), so that it can understand and generate discretized speech tokens.

For pre-training, to tackle the two challenges of model intelligence and synthesis expressiveness in the speech modality, we decouple the speech-to-speech task into two sub-tasks, "produce a text reply from the user's audio" and "synthesize the reply speech from the text reply and the user's speech", and design two pre-training objectives that synthesize speech-text interleaved data from text pre-training corpora and unsupervised audio to match these two task forms. Built on the GLM-4-9B base model, GLM-4-Voice-9B is pre-trained on millions of hours of audio and hundreds of billions of tokens of audio-text interleaved data, giving it strong audio understanding and modeling capability.

For alignment, to support high-quality voice dialogue, we design a streaming "thinking" architecture: given the user's speech, GLM-4-Voice alternately streams tokens of the text and speech modalities, where the text serves as a reference that guarantees the quality of the reply and the voice is varied according to the user's spoken instructions. This preserves the language model's intelligence as much as possible while keeping end-to-end modeling and low latency: synthesis can start after as few as 20 output tokens.

## Model List

| Model | Type | Download |
|:---------------------:|:----------------:|:------------------------------------------------------------------------------------------------------------------------------------------------:|
| GLM-4-Voice-Tokenizer | Speech Tokenizer | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-tokenizer)  [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-tokenizer) |
| GLM-4-Voice-9B | Chat Model | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-9b)  [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-9b) |
| GLM-4-Voice-Decoder | Speech Decoder | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-decoder)  [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-decoder) |

## Usage
We provide a Web Demo that can be launched directly. Users can input speech or text, and the model replies with both speech and text.

![Web Demo](./resources/web_demo.png)

### Preparation

First, clone the repository
```shell
git clone --recurse-submodules https://github.com/THUDM/GLM-4-Voice
cd GLM-4-Voice
```
Then install the dependencies. You can also use our docker image `zhipuai/glm-4-voice:0.1` to skip this step.
```shell
pip install -r requirements.txt
```
Since the Decoder model cannot be initialized via `transformers`, its checkpoint needs to be downloaded separately.

```shell
# Download the model via git; make sure git-lfs is installed
git lfs install
git clone https://huggingface.co/THUDM/glm-4-voice-decoder
```

### Launch Web Demo

1. Start the model server

```shell
python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype bfloat16 --device cuda:0
```

If you need to launch with Int4 precision, run

```shell
python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype int4 --device cuda:0
```

This command automatically downloads `glm-4-voice-9b`. If network conditions are poor, you can also download it manually and specify the local path via `--model-path`.

2. Start the web service

```shell
python web_demo.py --tokenizer-path THUDM/glm-4-voice-tokenizer --model-path THUDM/glm-4-voice-9b --flow-path ./glm-4-voice-decoder
```

The web demo is then available at http://127.0.0.1:8888.

This command automatically downloads `glm-4-voice-tokenizer` and `glm-4-voice-9b`. Note that `glm-4-voice-decoder` must be downloaded manually.

If network conditions are poor, you can download all three models manually and specify the local paths via `--tokenizer-path`, `--flow-path`, and `--model-path`.

### Known Issues

* Gradio's streaming audio playback can be unstable. Audio quality is higher if you click the audio in the dialogue box after generation is complete.

## Cases

We provide some dialogue examples of GLM-4-Voice, covering emotion control, changing the speech rate, dialect generation, and more.

* Guide me to relax with a gentle voice

https://github.com/user-attachments/assets/4e3d9200-076d-4c28-a641-99df3af38eb0

* Commentate a football match with an excited voice

https://github.com/user-attachments/assets/0163de2d-e876-4999-b1bc-bbfa364b799b

* Tell a ghost story with a mournful voice

https://github.com/user-attachments/assets/a75b2087-d7bc-49fa-a0c5-e8c99935b39a

* Describe how cold winter is in a Northeastern dialect

https://github.com/user-attachments/assets/91ba54a1-8f5c-4cfe-8e87-16ed1ecf4037

* Say "eat grapes without spitting out the skins" in Chongqing dialect

https://github.com/user-attachments/assets/7eb72461-9e84-4d8e-9c58-1809cf6a8a9b

* Recite a tongue twister with a Beijing accent

https://github.com/user-attachments/assets/a9bb223e-9c0a-440d-8537-0a7f16e31651

* Speed up the speech

https://github.com/user-attachments/assets/c98a4604-366b-4304-917f-3c850a82fe9f

* Even faster

https://github.com/user-attachments/assets/d5ff0815-74f8-4738-b0f1-477cfc8dcc2d

## Acknowledgements

Some code in this project comes from:
* [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
* [transformers](https://github.com/huggingface/transformers)
* [GLM-4](https://github.com/THUDM/GLM-4)

## License

+ Use of the GLM-4 model weights must follow the [Model License](https://huggingface.co/THUDM/glm-4-voice-9b/blob/main/LICENSE).

+ The code in this open-source repository is released under the [Apache 2.0](LICENSE) license.

## Citation

```
@misc{zeng2024glm4,
      title={GLM-4-Voice: Towards Intelligent and Human-Like End-to-End Spoken Chatbot},
      author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Kedong Wang and Shengmin Jiang and Lei Zhao and Yuxiao Dong and Jie Tang},
      year={2024},
      eprint={2412.02612},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2412.02612},
}
```

```
@misc{zeng2024scaling,
      title={Scaling Speech-Text Pre-training with Synthetic Interleaved Data},
      author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Lei Zhang and Shengmin Jiang and Yuxiao Dong and Jie Tang},
      year={2024},
      eprint={2411.17607},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2411.17607},
}
```
third_party/GLM-4-Voice/README_en.md
ADDED
@@ -0,0 +1,148 @@
# GLM-4-Voice
<p align="center">
    📄<a href="https://arxiv.org/abs/2412.02612" target="_blank"> Report </a> • 🤗 <a href="https://huggingface.co/THUDM/glm-4-voice-9b" target="_blank">HF Repo</a> • 🤖 <a href="https://modelscope.cn/studios/ZhipuAI/GLM-4-Voice-Demo" target="_blank">Demo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a>
</p>

GLM-4-Voice is an end-to-end voice model launched by Zhipu AI. GLM-4-Voice can directly understand and generate Chinese and English speech, engage in real-time voice conversations, and change attributes such as emotion, intonation, speech rate, and dialect based on user instructions.

## Model Architecture

![Model Architecture](./resources/architecture.jpeg)
We provide the three components of GLM-4-Voice:
* GLM-4-Voice-Tokenizer: Trained by adding vector quantization to the encoder part of [Whisper](https://github.com/openai/whisper), converting continuous speech input into discrete tokens. Each second of audio is converted into 12.5 discrete tokens.
* GLM-4-Voice-9B: Pre-trained and aligned on speech modality based on [GLM-4-9B](https://github.com/THUDM/GLM-4), enabling understanding and generation of discretized speech.
* GLM-4-Voice-Decoder: A speech decoder supporting streaming inference, retrained based on [CosyVoice](https://github.com/FunAudioLLM/CosyVoice), converting discrete speech tokens into continuous speech output. Generation can start with as few as 10 audio tokens, reducing conversation latency.

## Model List

| Model | Type | Download |
|:---------------------:|:----------------:|:--------------------------------------------------------------------:|
| GLM-4-Voice-Tokenizer | Speech Tokenizer | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-tokenizer) |
| GLM-4-Voice-9B | Chat Model | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-9b) |
| GLM-4-Voice-Decoder | Speech Decoder | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-decoder) |

## Usage
We provide a Web Demo that can be launched directly. Users can input speech or text, and the model will respond with both speech and text.

![Web Demo](./resources/web_demo.png)

### Preparation

First, download the repository
```shell
git clone --recurse-submodules https://github.com/THUDM/GLM-4-Voice
cd GLM-4-Voice
```
Then, install the dependencies. You can also use our pre-built docker image `zhipuai/glm-4-voice:0.1` to skip this step.
```shell
pip install -r requirements.txt
```
Since the Decoder model does not support initialization via `transformers`, the checkpoint needs to be downloaded separately.

```shell
# Git model download, please ensure git-lfs is installed
git clone https://huggingface.co/THUDM/glm-4-voice-decoder
```

### Launch Web Demo

1. Start the model server

```shell
python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype bfloat16 --device cuda:0
```

If you need to launch with Int4 precision, run

```shell
python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype int4 --device cuda:0
```

This command will automatically download `glm-4-voice-9b`. If network conditions are poor, you can manually download it and specify the local path using `--model-path`.

2. Start the web service

```shell
python web_demo.py --tokenizer-path THUDM/glm-4-voice-tokenizer --model-path THUDM/glm-4-voice-9b --flow-path ./glm-4-voice-decoder
```

You can access the web demo at [http://127.0.0.1:8888](http://127.0.0.1:8888).
This command will automatically download `glm-4-voice-tokenizer` and `glm-4-voice-9b`. Please note that `glm-4-voice-decoder` needs to be downloaded manually.
If the network connection is poor, you can manually download these three models and specify the local paths using `--tokenizer-path`, `--flow-path`, and `--model-path`.

### Known Issues
* Gradio's streaming audio playback can be unstable. The audio quality will be higher when clicking on the audio in the dialogue box after generation is complete.

## Examples
We provide some dialogue cases for GLM-4-Voice, including emotion control, speech rate alteration, dialect generation, etc. (The examples are in Chinese.)

* Use a gentle voice to guide me to relax

https://github.com/user-attachments/assets/4e3d9200-076d-4c28-a641-99df3af38eb0

* Use an excited voice to commentate a football match

https://github.com/user-attachments/assets/0163de2d-e876-4999-b1bc-bbfa364b799b

* Tell a ghost story with a mournful voice

https://github.com/user-attachments/assets/a75b2087-d7bc-49fa-a0c5-e8c99935b39a

* Introduce how cold winter is with a Northeastern dialect

https://github.com/user-attachments/assets/91ba54a1-8f5c-4cfe-8e87-16ed1ecf4037

* Say "Eat grapes without spitting out the skins" in Chongqing dialect

https://github.com/user-attachments/assets/7eb72461-9e84-4d8e-9c58-1809cf6a8a9b

* Recite a tongue twister with a Beijing accent

https://github.com/user-attachments/assets/a9bb223e-9c0a-440d-8537-0a7f16e31651

* Increase the speech rate

https://github.com/user-attachments/assets/c98a4604-366b-4304-917f-3c850a82fe9f

* Even faster

https://github.com/user-attachments/assets/d5ff0815-74f8-4738-b0f1-477cfc8dcc2d

## Acknowledgements

Some code in this project is from:
* [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
* [transformers](https://github.com/huggingface/transformers)
* [GLM-4](https://github.com/THUDM/GLM-4)

## License Agreement

+ The use of GLM-4 model weights must follow the [Model License Agreement](https://huggingface.co/THUDM/glm-4-voice-9b/blob/main/LICENSE).

+ The code in this open-source repository is licensed under the [Apache 2.0](LICENSE) License.

## Citation

```
@misc{zeng2024glm4,
      title={GLM-4-Voice: Towards Intelligent and Human-Like End-to-End Spoken Chatbot},
      author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Kedong Wang and Shengmin Jiang and Lei Zhao and Yuxiao Dong and Jie Tang},
      year={2024},
      eprint={2412.02612},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2412.02612},
}
```

```
@misc{zeng2024scaling,
      title={Scaling Speech-Text Pre-training with Synthetic Interleaved Data},
      author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Lei Zhang and Shengmin Jiang and Yuxiao Dong and Jie Tang},
      year={2024},
      eprint={2411.17607},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2411.17607},
}
```
third_party/GLM-4-Voice/audio_process.py
ADDED
@@ -0,0 +1,93 @@
import os
import librosa
import soundfile as sf
import numpy as np
from pathlib import Path
import io

# Split audio stream at silence points to prevent playback stuttering issues
# caused by AAC encoder frame padding when streaming audio through Gradio audio components.
class AudioStreamProcessor:
    def __init__(self, sr=22050, min_silence_duration=0.1, threshold_db=-40):
        self.sr = sr
        self.min_silence_duration = min_silence_duration
        self.threshold_db = threshold_db
        self.buffer = np.array([])

    def process(self, audio_data, last=False):
        """
        Add audio data and process it
        params:
            audio_data: audio data in numpy array
            last: whether this is the last chunk of data
        returns:
            Processed audio data, returns None if no split point is found
        """

        # Add new data to buffer
        self.buffer = np.concatenate([self.buffer, audio_data]) if len(self.buffer) > 0 else audio_data

        if last:
            result = self.buffer
            self.buffer = np.array([])
            return self._to_wav_bytes(result)

        # Find silence boundary
        split_point = self._find_silence_boundary(self.buffer)

        if split_point is not None:
            # Modified: Extend split point to the end of silence
            silence_end = self._find_silence_end(split_point)
            result = self.buffer[:silence_end]
            self.buffer = self.buffer[silence_end:]
            return self._to_wav_bytes(result)

        return None

    def _find_silence_boundary(self, audio):
        """
        Find the starting point of silence boundary in audio
        """
        # Convert audio to decibels
        db = librosa.amplitude_to_db(np.abs(audio), ref=np.max)

        # Find points below threshold
        silence_points = np.where(db < self.threshold_db)[0]

        if len(silence_points) == 0:
            return None

        # Calculate minimum silence samples
        min_silence_samples = int(self.min_silence_duration * self.sr)

        # Search backwards for continuous silence segment starting point
        for i in range(len(silence_points) - min_silence_samples, -1, -1):
            if i < 0:
                break
            if np.all(np.diff(silence_points[i:i+min_silence_samples]) == 1):
                return silence_points[i]

        return None

    def _find_silence_end(self, start_point):
        """
        Find the end point of silence segment
        """
        db = librosa.amplitude_to_db(np.abs(self.buffer[start_point:]), ref=np.max)
        silence_points = np.where(db >= self.threshold_db)[0]

        if len(silence_points) == 0:
            return len(self.buffer)

        return start_point + silence_points[0]

    def _to_wav_bytes(self, audio_data):
        """
        trans_to_wav_bytes
        """
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, audio_data, self.sr, format='WAV')
        return wav_buffer.getvalue()
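Below is a minimal usage sketch for the `AudioStreamProcessor` above (not part of the original file). The synthetic chunks and output file names are illustrative assumptions; in the web demo the chunks would come from the streaming speech decoder instead.

```python
import numpy as np
# from audio_process import AudioStreamProcessor  # import path depends on how the module is used

processor = AudioStreamProcessor(sr=22050, min_silence_duration=0.1, threshold_db=-40)
# Low-level synthetic noise stands in for streamed model audio (illustrative only).
chunks = [0.01 * np.random.randn(4096).astype(np.float32) for _ in range(5)]

for i, chunk in enumerate(chunks):
    wav_bytes = processor.process(chunk, last=(i == len(chunks) - 1))
    if wav_bytes is not None:
        # Each returned value is a complete WAV payload, cut at a silence boundary
        # (or the remaining buffer when last=True), ready to hand to a Gradio audio component.
        with open(f"segment_{i}.wav", "wb") as f:
            f.write(wav_bytes)
```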
third_party/GLM-4-Voice/cosyvoice/__init__.py
ADDED
File without changes
third_party/GLM-4-Voice/cosyvoice/bin/inference.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os

import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from cosyvoice.cli.model import CosyVoiceModel

from cosyvoice.dataset.dataset import Dataset

def get_args():
    parser = argparse.ArgumentParser(description='inference with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--prompt_data', required=True, help='prompt data file')
    parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
    parser.add_argument('--tts_text', required=True, help='tts input file')
    parser.add_argument('--llm_model', required=True, help='llm model file')
    parser.add_argument('--flow_model', required=True, help='flow model file')
    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot'],
                        help='inference mode')
    parser.add_argument('--result_dir', required=True, help='asr result file')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Init cosyvoice models from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f)

    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
    model.load(args.llm_model, args.flow_model, args.hifigan_model)

    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

    del configs
    os.makedirs(args.result_dir, exist_ok=True)
    fn = os.path.join(args.result_dir, 'wav.scp')
    f = open(fn, 'w')
    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
            utts = batch["utts"]
            assert len(utts) == 1, "inference mode only support batchsize 1"
            text = batch["text"]
            text_token = batch["text_token"].to(device)
            text_token_len = batch["text_token_len"].to(device)
            tts_text = batch["tts_text"]
            tts_index = batch["tts_index"]
            tts_text_token = batch["tts_text_token"].to(device)
            tts_text_token_len = batch["tts_text_token_len"].to(device)
            speech_token = batch["speech_token"].to(device)
            speech_token_len = batch["speech_token_len"].to(device)
            speech_feat = batch["speech_feat"].to(device)
            speech_feat_len = batch["speech_feat_len"].to(device)
            utt_embedding = batch["utt_embedding"].to(device)
            spk_embedding = batch["spk_embedding"].to(device)
            if args.mode == 'sft':
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
            else:
                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
            model_output = model.inference(**model_input)
            tts_key = '{}_{}'.format(utts[0], tts_index[0])
            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
            torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
            f.write('{} {}\n'.format(tts_key, tts_fn))
            f.flush()
    f.close()
    logging.info('Result wav.scp saved in {}'.format(fn))


if __name__ == '__main__':
    main()
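The script above writes one `<tts_key> <wav_path>` pair per line into `wav.scp` and saves each waveform at 22050 Hz. The following small sketch, not part of the repository, shows how those results could be iterated over afterwards; the result directory is a placeholder for whatever was passed as `--result_dir`.

```python
import torchaudio

result_dir = "exp/cosyvoice/results"  # placeholder for the --result_dir used above
with open(f"{result_dir}/wav.scp") as f:
    for line in f:
        tts_key, wav_path = line.strip().split(maxsplit=1)
        waveform, sample_rate = torchaudio.load(wav_path)  # saved above at 22050 Hz
        print(tts_key, waveform.shape, sample_rate)
```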
third_party/GLM-4-Voice/cosyvoice/bin/train.py
ADDED
@@ -0,0 +1,140 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import torch
import torch.distributed as dist
# import deepspeed
import pdb
from hyperpyyaml import load_hyperpyyaml

from torch.distributed.elastic.multiprocessing.errors import record

from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
    init_distributed,
    init_dataset_and_dataloader,
    init_optimizer_and_scheduler,
    init_summarywriter, save_model,
    wrap_cuda_model, check_modify_and_save_config)


def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--train_engine',
                        default='torch_ddp',
                        choices=['torch_ddp', 'deepspeed'],
                        help='Engine for paralleled training')
    parser.add_argument('--model', required=True, help='model which will be trained')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers used for reading')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        default='model_only',
                        choices=['model_only', 'model+optimizer'],
                        help='save model/optimizer states')
    parser.add_argument('--timeout',
                        default=30,
                        type=int,
                        help='timeout (in seconds) of cosyvoice_join.')
    # parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args


@record
def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
    with open(args.config, 'r') as f:
        configs = load_hyperpyyaml(f, overrides=override_dict)
    configs['train_conf'].update(vars(args))

    # Init env for ddp
    init_distributed(args)

    # Get dataset & dataloader
    train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
        init_dataset_and_dataloader(args, configs)

    # Do some sanity checks and save config to args.model_dir
    configs = check_modify_and_save_config(args, configs)

    # Tensorboard summary
    writer = init_summarywriter(args)

    # load checkpoint
    model = configs[args.model]
    if args.checkpoint is not None:
        model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))

    # Dispatch model from cpu to gpu
    model = wrap_cuda_model(args, model)

    # Get optimizer & scheduler
    model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
    # pdb.set_trace()
    # Save init checkpoints
    info_dict = deepcopy(configs['train_conf'])
    save_model(model, 'init', info_dict)

    # Get executor
    executor = Executor()

    # Start training loop
    for epoch in range(info_dict['max_epoch']):
        executor.epoch = epoch
        train_dataset.set_epoch(epoch)
        dist.barrier()
        # try:
        #     dist.barrier()
        # except RuntimeError as e:
        #     logging.info('except RuntimeError as e: {}'.format(e))
        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
        dist.destroy_process_group(group_join)


if __name__ == '__main__':
    main()
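One detail worth calling out in `main()` above is the `override_dict` passed to `load_hyperpyyaml`: only the sub-model selected by `--model` is instantiated, while the other top-level keys (`llm`, `flow`, `hift`) are overridden to `None` so they are never built. Below is a standalone sketch of that behaviour, with a hypothetical config path; it is an illustration, not part of the training script.

```python
from hyperpyyaml import load_hyperpyyaml

selected = "flow"  # what --model would select; one of 'llm', 'flow', 'hift'
override_dict = {k: None for k in ["llm", "flow", "hift"] if k != selected}

with open("conf/cosyvoice.yaml", "r") as f:  # hypothetical config path
    configs = load_hyperpyyaml(f, overrides=override_dict)

model = configs[selected]  # the only sub-model that is actually constructed
```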
third_party/GLM-4-Voice/cosyvoice/cli/__init__.py
ADDED
File without changes
third_party/GLM-4-Voice/cosyvoice/cli/cosyvoice.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel

class CosyVoice:

    def __init__(self, model_dir):
        instruct = True if '-Instruct' in model_dir else False
        self.model_dir = model_dir
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
            configs = load_hyperpyyaml(f)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          instruct,
                                          configs['allowed_special'])
        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        del configs

    def list_avaliable_spks(self):
        spks = list(self.frontend.spk2info.keys())
        return spks

    def inference_sft(self, tts_text, spk_id):
        tts_speeches = []
        for i in self.frontend.text_normalize(tts_text, split=True):
            model_input = self.frontend.frontend_sft(i, spk_id)
            model_output = self.model.inference(**model_input)
            tts_speeches.append(model_output['tts_speech'])
        return {'tts_speech': torch.concat(tts_speeches, dim=1)}

    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
        prompt_text = self.frontend.text_normalize(prompt_text, split=False)
        tts_speeches = []
        for i in self.frontend.text_normalize(tts_text, split=True):
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
            model_output = self.model.inference(**model_input)
            tts_speeches.append(model_output['tts_speech'])
        return {'tts_speech': torch.concat(tts_speeches, dim=1)}

    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
        if self.frontend.instruct is True:
            raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
        tts_speeches = []
        for i in self.frontend.text_normalize(tts_text, split=True):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
            model_output = self.model.inference(**model_input)
            tts_speeches.append(model_output['tts_speech'])
        return {'tts_speech': torch.concat(tts_speeches, dim=1)}

    def inference_instruct(self, tts_text, spk_id, instruct_text):
        if self.frontend.instruct is False:
            raise ValueError('{} do not support instruct inference'.format(self.model_dir))
        instruct_text = self.frontend.text_normalize(instruct_text, split=False)
        tts_speeches = []
        for i in self.frontend.text_normalize(tts_text, split=True):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            model_output = self.model.inference(**model_input)
            tts_speeches.append(model_output['tts_speech'])
        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
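A brief usage sketch for the `CosyVoice` wrapper above (not part of the original file). The model directory and speaker id are assumptions; the corresponding CosyVoice checkpoint has to be available locally or downloadable through ModelScope.

```python
import torchaudio
# from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice("pretrained_models/CosyVoice-300M-SFT")  # assumed model directory
print(cosyvoice.list_avaliable_spks())  # speaker ids registered in spk2info.pt

# Assumed speaker id; pick one of the ids printed above.
output = cosyvoice.inference_sft("你好,欢迎使用语音合成。", "中文女")
torchaudio.save("sft_output.wav", output["tts_speech"], 22050)  # HiFT vocoder output is 22.05 kHz
```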
third_party/GLM-4-Voice/cosyvoice/cli/frontend.py
ADDED
@@ -0,0 +1,168 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
    import ttsfrd
    use_ttsfrd = True
except ImportError:
    print("failed to import ttsfrd, use WeTextProcessing instead")
    from tn.chinese.normalizer import Normalizer as ZhNormalizer
    from tn.english.normalizer import Normalizer as EnNormalizer
    use_ttsfrd = False
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph


class CosyVoiceFrontEnd:

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 instruct: bool = False,
                 allowed_special: str = 'all'):
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
        self.instruct = instruct
        self.allowed_special = allowed_special
        self.inflect_parser = inflect.engine()
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyin')
            self.frd.enable_pinyin_mix(True)
            self.frd.set_breakmodel_index(1)
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
            self.en_tn_model = EnNormalizer()

    def _extract_text_token(self, text):
        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech):
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
                                                                self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech):
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech):
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def text_normalize(self, text, split=True):
        text = text.strip()
        if contains_chinese(text):
            if self.use_ttsfrd:
                text = self.frd.get_frd_extra_info(text, 'input')
            else:
                text = self.zh_tn_model.normalize(text)
            text = text.replace("\n", "")
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = text.replace(".", "、")
            text = text.replace(" - ", ",")
            text = remove_bracket(text)
            text = re.sub(r'[,,]+$', '。', text)
            texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                                token_min_n=60, merge_len=20,
                                                comma_split=False)]
        else:
            if self.use_ttsfrd:
                text = self.frd.get_frd_extra_info(text, 'input')
            else:
                text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
            texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                                token_min_n=60, merge_len=20,
                                                comma_split=False)]
        if split is False:
            return text
        return texts

    def frontend_sft(self, tts_text, spk_id):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                       'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                       'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input
third_party/GLM-4-Voice/cosyvoice/cli/model.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
        self.llm.to(self.device).eval()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
        self.flow.to(self.device).eval()
        self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
        self.hift.to(self.device).eval()

    def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
                  prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
        tts_speech_token = self.llm.inference(text=text.to(self.device),
                                              text_len=text_len.to(self.device),
                                              prompt_text=prompt_text.to(self.device),
                                              prompt_text_len=prompt_text_len.to(self.device),
                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
                                              embedding=llm_embedding.to(self.device),
                                              beam_size=1,
                                              sampling=25,
                                              max_token_text_ratio=30,
                                              min_token_text_ratio=3)
        tts_mel = self.flow.inference(token=tts_speech_token,
                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
                                      prompt_token=flow_prompt_speech_token.to(self.device),
                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
                                      prompt_feat=prompt_speech_feat.to(self.device),
                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
                                      embedding=flow_embedding.to(self.device))
        tts_speech = self.hift.inference(mel=tts_mel).cpu()
        torch.cuda.empty_cache()
        return {'tts_speech': tts_speech}

    def text_to_token(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
                      prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                      llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                      flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                      prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
        tts_speech_token = self.llm.inference(text=text.to(self.device),
                                              text_len=text_len.to(self.device),
                                              prompt_text=prompt_text.to(self.device),
                                              prompt_text_len=prompt_text_len.to(self.device),
|
71 |
+
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
72 |
+
prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
|
73 |
+
embedding=llm_embedding.to(self.device),
|
74 |
+
beam_size=1,
|
75 |
+
sampling=25,
|
76 |
+
max_token_text_ratio=30,
|
77 |
+
min_token_text_ratio=3)
|
78 |
+
return tts_speech_token
|
79 |
+
|
80 |
+
def token_to_speech(self, tts_speech_token, flow_embedding, llm_embedding=torch.zeros(0, 192),
|
81 |
+
prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
|
82 |
+
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
83 |
+
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
|
84 |
+
prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
|
85 |
+
|
86 |
+
tts_mel = self.flow.inference(token=tts_speech_token,
|
87 |
+
token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
|
88 |
+
prompt_token=flow_prompt_speech_token.to(self.device),
|
89 |
+
prompt_token_len=flow_prompt_speech_token_len.to(self.device),
|
90 |
+
prompt_feat=prompt_speech_feat.to(self.device),
|
91 |
+
prompt_feat_len=prompt_speech_feat_len.to(self.device),
|
92 |
+
embedding=flow_embedding.to(self.device))
|
93 |
+
tts_speech = self.hift.inference(mel=tts_mel).cpu()
|
94 |
+
torch.cuda.empty_cache()
|
95 |
+
return {'tts_speech': tts_speech}
|
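Besides the end-to-end inference, the class exposes the two halves separately: text_to_token runs only the LLM, token_to_speech only the flow + HiFT vocoder. That split is handy when speech tokens are cached or produced by another model. A minimal sketch, assuming model_input was built by the frontend shown earlier:

import torch

def two_stage_synthesis(model, model_input):
    with torch.no_grad():
        # Stage 1: text (plus prompt conditioning) -> discrete speech tokens.
        speech_token = model.text_to_token(**model_input)
        # Stage 2: tokens -> mel -> waveform; the text fields are no longer needed.
        token_input = {k: v for k, v in model_input.items() if k not in ('text', 'text_len')}
        output = model.token_to_speech(speech_token, **token_input)
    return output['tts_speech']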
third_party/GLM-4-Voice/cosyvoice/dataset/__init__.py
ADDED
File without changes
|
third_party/GLM-4-Voice/cosyvoice/dataset/dataset.py
ADDED
@@ -0,0 +1,160 @@
1 |
+
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
2 |
+
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import random
|
17 |
+
import json
|
18 |
+
import math
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import IterableDataset
|
24 |
+
from cosyvoice.utils.file_utils import read_lists, read_json_lists
|
25 |
+
|
26 |
+
|
27 |
+
class Processor(IterableDataset):
|
28 |
+
|
29 |
+
def __init__(self, source, f, *args, **kw):
|
30 |
+
assert callable(f)
|
31 |
+
self.source = source
|
32 |
+
self.f = f
|
33 |
+
self.args = args
|
34 |
+
self.kw = kw
|
35 |
+
|
36 |
+
def set_epoch(self, epoch):
|
37 |
+
self.source.set_epoch(epoch)
|
38 |
+
|
39 |
+
def __iter__(self):
|
40 |
+
""" Return an iterator over the source dataset processed by the
|
41 |
+
given processor.
|
42 |
+
"""
|
43 |
+
assert self.source is not None
|
44 |
+
assert callable(self.f)
|
45 |
+
return self.f(iter(self.source), *self.args, **self.kw)
|
46 |
+
|
47 |
+
def apply(self, f):
|
48 |
+
assert callable(f)
|
49 |
+
return Processor(self, f, *self.args, **self.kw)
|
50 |
+
|
51 |
+
|
52 |
+
class DistributedSampler:
|
53 |
+
|
54 |
+
def __init__(self, shuffle=True, partition=True):
|
55 |
+
self.epoch = -1
|
56 |
+
self.update()
|
57 |
+
self.shuffle = shuffle
|
58 |
+
self.partition = partition
|
59 |
+
|
60 |
+
def update(self):
|
61 |
+
assert dist.is_available()
|
62 |
+
if dist.is_initialized():
|
63 |
+
self.rank = dist.get_rank()
|
64 |
+
self.world_size = dist.get_world_size()
|
65 |
+
else:
|
66 |
+
self.rank = 0
|
67 |
+
self.world_size = 1
|
68 |
+
worker_info = torch.utils.data.get_worker_info()
|
69 |
+
if worker_info is None:
|
70 |
+
self.worker_id = 0
|
71 |
+
self.num_workers = 1
|
72 |
+
else:
|
73 |
+
self.worker_id = worker_info.id
|
74 |
+
self.num_workers = worker_info.num_workers
|
75 |
+
return dict(rank=self.rank,
|
76 |
+
world_size=self.world_size,
|
77 |
+
worker_id=self.worker_id,
|
78 |
+
num_workers=self.num_workers)
|
79 |
+
|
80 |
+
def set_epoch(self, epoch):
|
81 |
+
self.epoch = epoch
|
82 |
+
|
83 |
+
def sample(self, data):
|
84 |
+
""" Sample data according to rank/world_size/num_workers
|
85 |
+
|
86 |
+
Args:
|
87 |
+
data(List): input data list
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
List: data list after sample
|
91 |
+
"""
|
92 |
+
data = list(range(len(data)))
|
93 |
+
# force datalist even
|
94 |
+
if self.partition:
|
95 |
+
if self.shuffle:
|
96 |
+
random.Random(self.epoch).shuffle(data)
|
97 |
+
if len(data) < self.world_size:
|
98 |
+
data = data * math.ceil(self.world_size / len(data))
|
99 |
+
data = data[:self.world_size]
|
100 |
+
data = data[self.rank::self.world_size]
|
101 |
+
if len(data) < self.num_workers:
|
102 |
+
data = data * math.ceil(self.num_workers / len(data))
|
103 |
+
data = data[:self.num_workers]
|
104 |
+
data = data[self.worker_id::self.num_workers]
|
105 |
+
return data
|
106 |
+
|
107 |
+
|
108 |
+
class DataList(IterableDataset):
|
109 |
+
|
110 |
+
def __init__(self, lists, shuffle=True, partition=True):
|
111 |
+
self.lists = lists
|
112 |
+
self.sampler = DistributedSampler(shuffle, partition)
|
113 |
+
|
114 |
+
def set_epoch(self, epoch):
|
115 |
+
self.sampler.set_epoch(epoch)
|
116 |
+
|
117 |
+
def __iter__(self):
|
118 |
+
sampler_info = self.sampler.update()
|
119 |
+
indexes = self.sampler.sample(self.lists)
|
120 |
+
for index in indexes:
|
121 |
+
data = dict(src=self.lists[index])
|
122 |
+
data.update(sampler_info)
|
123 |
+
yield data
|
124 |
+
|
125 |
+
|
126 |
+
def Dataset(data_list_file,
|
127 |
+
data_pipeline,
|
128 |
+
mode='train',
|
129 |
+
shuffle=True,
|
130 |
+
partition=True,
|
131 |
+
tts_file='',
|
132 |
+
prompt_utt2data=''):
|
133 |
+
""" Construct dataset from arguments
|
134 |
+
|
135 |
+
We have two shuffle stages in the Dataset. The first is global
|
136 |
+
shuffle at shards tar/raw file level. The second is global shuffle
|
137 |
+
at training samples level.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
data_type(str): raw/shard
|
141 |
+
tokenizer (BaseTokenizer): tokenizer to tokenize
|
142 |
+
partition(bool): whether to do data partition in terms of rank
|
143 |
+
"""
|
144 |
+
assert mode in ['train', 'inference']
|
145 |
+
lists = read_lists(data_list_file)
|
146 |
+
# import pdb
|
147 |
+
# pdb.set_trace()
|
148 |
+
if mode == 'inference':
|
149 |
+
with open(tts_file) as f:
|
150 |
+
tts_data = json.load(f)
|
151 |
+
utt2lists = read_json_lists(prompt_utt2data)
|
152 |
+
# filter unnecessary file in inference mode
|
153 |
+
lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
|
154 |
+
dataset = DataList(lists,shuffle=shuffle,partition=partition)
|
155 |
+
if mode == 'inference':
|
156 |
+
# map partial arg tts_data in inference mode
|
157 |
+
data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
|
158 |
+
for func in data_pipeline:
|
159 |
+
dataset = Processor(dataset, func, mode=mode)
|
160 |
+
return dataset
|
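Every element of data_pipeline is a generator function over the previous stage's iterator, so Processor chains the stages lazily while DataList handles rank/worker sharding. A toy illustration of that chaining (the extra stage and shard names are made up for the example):

from cosyvoice.dataset.dataset import DataList, Processor

def add_length(data, mode='train'):
    # Toy stage: annotate each sample in place, just like the real processor stages.
    for sample in data:
        sample['n_chars'] = len(sample['src'])
        yield sample

datalist = DataList(['shard-000.tar', 'shard-001.tar'], shuffle=True, partition=True)
pipeline = Processor(datalist, add_length, mode='train')
pipeline.set_epoch(0)  # drives the epoch-dependent shuffle in DistributedSampler
for sample in pipeline:
    print(sample['src'], sample['rank'], sample['n_chars'])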
third_party/GLM-4-Voice/cosyvoice/dataset/processor.py
ADDED
@@ -0,0 +1,965 @@
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import logging
|
15 |
+
import random
|
16 |
+
import json
|
17 |
+
import tarfile
|
18 |
+
import json
|
19 |
+
import io
|
20 |
+
import pyarrow.parquet as pq
|
21 |
+
from io import BytesIO
|
22 |
+
import torch
|
23 |
+
import torchaudio
|
24 |
+
from torch.nn.utils.rnn import pad_sequence
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import tarfile
|
27 |
+
import json
|
28 |
+
import io
|
29 |
+
import wave
|
30 |
+
import numpy as np
|
31 |
+
import torchaudio
|
32 |
+
import os
|
33 |
+
import sys
|
34 |
+
import json
|
35 |
+
import random
|
36 |
+
import pickle
|
37 |
+
import argparse
|
38 |
+
import itertools
|
39 |
+
import mmap
|
40 |
+
import struct
|
41 |
+
import collections
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
import shutil
|
46 |
+
import multiprocessing as mp
|
47 |
+
from pathlib import Path
|
48 |
+
|
49 |
+
from tqdm import tqdm
|
50 |
+
from collections import defaultdict
|
51 |
+
from copy import deepcopy
|
52 |
+
from datetime import datetime
|
53 |
+
import pickle
|
54 |
+
|
55 |
+
from wids import wids
|
56 |
+
import math
|
57 |
+
|
58 |
+
torchaudio.set_audio_backend('soundfile')
|
59 |
+
|
60 |
+
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
|
61 |
+
|
62 |
+
try:
|
63 |
+
MAIN_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
|
64 |
+
GPT_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
|
65 |
+
except:
|
66 |
+
MAIN_SPK_EMBEDDING=torch.zeros(1,192)
|
67 |
+
GPT_SPK_EMBEDDING=torch.zeros(1,192)
|
68 |
+
|
69 |
+
def parquet_opener(data, mode='train', tts_data={}):
|
70 |
+
""" Give url or local file, return file descriptor
|
71 |
+
Inplace operation.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
data(Iterable[str]): url or local file list
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Iterable[{src, stream}]
|
78 |
+
"""
|
79 |
+
for sample in data:
|
80 |
+
assert 'src' in sample
|
81 |
+
url = sample['src']
|
82 |
+
try:
|
83 |
+
df = pq.read_table(url).to_pandas()
|
84 |
+
for i in range(len(df)):
|
85 |
+
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
|
86 |
+
continue
|
87 |
+
sample.update(dict(df.loc[i]))
|
88 |
+
if mode == 'train':
|
89 |
+
# NOTE do not return sample directly, must initialize a new dict
|
90 |
+
yield {**sample}
|
91 |
+
else:
|
92 |
+
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
|
93 |
+
yield {**sample, 'tts_index': index, 'tts_text': text}
|
94 |
+
except Exception as ex:
|
95 |
+
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
def parse_tar_header(header_bytes):
|
101 |
+
header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
|
102 |
+
return TarHeader(*header)
|
103 |
+
|
104 |
+
TarHeader = collections.namedtuple(
|
105 |
+
"TarHeader",
|
106 |
+
[
|
107 |
+
"name",
|
108 |
+
"mode",
|
109 |
+
"uid",
|
110 |
+
"gid",
|
111 |
+
"size",
|
112 |
+
"mtime",
|
113 |
+
"chksum",
|
114 |
+
"typeflag",
|
115 |
+
"linkname",
|
116 |
+
"magic",
|
117 |
+
"version",
|
118 |
+
"uname",
|
119 |
+
"gname",
|
120 |
+
"devmajor",
|
121 |
+
"devminor",
|
122 |
+
"prefix",
|
123 |
+
],
|
124 |
+
)
|
125 |
+
|
126 |
+
class MMTar:
|
127 |
+
def __init__(self, file_path: Path | str):
|
128 |
+
self.stream = open(file_path, "rb")
|
129 |
+
self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
|
130 |
+
|
131 |
+
def __del__(self):
|
132 |
+
try:
|
133 |
+
self.mmap.close()
|
134 |
+
self.stream.close()
|
135 |
+
except: # noqa
|
136 |
+
pass
|
137 |
+
|
138 |
+
def get_at_offset(self, offset) -> tuple[str, bytes]:
|
139 |
+
header = parse_tar_header(self.mmap[offset : offset + 500])
|
140 |
+
name = header.name.decode("utf-8").strip("\x00")
|
141 |
+
start = offset + 512
|
142 |
+
end = start + int(header.size.decode("utf-8")[:-1], 8)
|
143 |
+
return name, self.mmap[start:end]
|
144 |
+
|
145 |
+
|
146 |
+
class Tar:
|
147 |
+
def __init__(self, path: Path):
|
148 |
+
self.tar = MMTar(path)
|
149 |
+
indices_path = path.with_suffix(".index")
|
150 |
+
self.index = pickle.loads(indices_path.read_bytes())
|
151 |
+
self.name_mapping = {}
|
152 |
+
for name, offset, _ in self.index:
|
153 |
+
self.name_mapping[name] = offset
|
154 |
+
|
155 |
+
def read(self, name: str) -> bytes:
|
156 |
+
return self.tar.get_at_offset(self.name_mapping[name])[1]
|
157 |
+
|
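The jsonl openers below pull waveforms through this mmap-backed indexed tar reader instead of untarring shards. A short sketch of the access pattern, assuming each .tar shard sits next to a pickled .index file of (name, offset, ...) entries; the shard path and member name are placeholders:

import io
from pathlib import Path
import torchaudio
from cosyvoice.dataset.processor import Tar

shard = Tar(Path('/data/shards/shard-000.tar'))   # also loads /data/shards/shard-000.index
wav_bytes = shard.read('utt_00001.wav')           # payload sliced straight out of the mmap
waveform, sample_rate = torchaudio.load(io.BytesIO(wav_bytes))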
158 |
+
def cosy_jsonl_opener(data, mode='train', tts_data={}):
|
159 |
+
""" Give url or local file, return file descriptor
|
160 |
+
Inplace operation.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
data(Iterable[str]): url or local file list
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
Iterable[{src, stream}]
|
167 |
+
"""
|
168 |
+
for sample in data:
|
169 |
+
assert 'src' in sample
|
170 |
+
cosy_jsonl_path = sample['src']
|
171 |
+
tar_file_path=cosy_jsonl_path.replace(".vq0907.jsonl",".tar")
|
172 |
+
try:
|
173 |
+
tar_data=Tar(Path(tar_file_path))
|
174 |
+
with open(cosy_jsonl_path, 'r') as f:
|
175 |
+
for line in f:
|
176 |
+
item=json.loads(line)
|
177 |
+
cosy_token = item['cosy_token']
|
178 |
+
sample['speech_token']=torch.tensor(cosy_token)
|
179 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
|
180 |
+
# print(item['filename'])
|
181 |
+
yield {**sample}
|
182 |
+
|
183 |
+
except Exception as ex:
|
184 |
+
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
|
185 |
+
|
186 |
+
|
187 |
+
def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}):
|
188 |
+
""" Give url or local file, return file descriptor
|
189 |
+
Inplace operation.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
data(Iterable[str]): url or local file list
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
Iterable[{src, stream}]
|
196 |
+
"""
|
197 |
+
for sample in data:
|
198 |
+
assert 'src' in sample
|
199 |
+
cosy_jsonl_path = sample['src']
|
200 |
+
tar_file_path=cosy_jsonl_path.replace(".vq0918-nopool.jsonl",".tar")
|
201 |
+
|
202 |
+
|
203 |
+
try:
|
204 |
+
tar_data=Tar(Path(tar_file_path))
|
205 |
+
with open(cosy_jsonl_path, 'r') as f:
|
206 |
+
# cosy_data = [json.loads(line) for line in f]
|
207 |
+
for line in f:
|
208 |
+
item=json.loads(line)
|
209 |
+
cosy_token = item['cosy_token']
|
210 |
+
sample['speech_token']=torch.tensor(cosy_token)
|
211 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
|
212 |
+
# print(item['filename'])
|
213 |
+
yield {**sample}
|
214 |
+
|
215 |
+
except Exception as ex:
|
216 |
+
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}):
|
221 |
+
""" Give url or local file, return file descriptor
|
222 |
+
Inplace operation.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
data(Iterable[str]): url or local file list
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
Iterable[{src, stream}]
|
229 |
+
"""
|
230 |
+
for sample in data:
|
231 |
+
assert 'src' in sample
|
232 |
+
cosy_jsonl_path = sample['src']
|
233 |
+
tar_file_path=cosy_jsonl_path.replace(".vq0918-pool2.jsonl",".tar")
|
234 |
+
|
235 |
+
try:
|
236 |
+
tar_data=Tar(Path(tar_file_path))
|
237 |
+
with open(cosy_jsonl_path, 'r') as f:
|
238 |
+
for line in f:
|
239 |
+
item=json.loads(line)
|
240 |
+
cosy_token = item['cosy_token']
|
241 |
+
sample['speech_token']=torch.tensor(cosy_token)
|
242 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
|
243 |
+
|
244 |
+
yield {**sample}
|
245 |
+
|
246 |
+
except Exception as ex:
|
247 |
+
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
|
248 |
+
|
249 |
+
|
250 |
+
def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}):
|
251 |
+
""" Give url or local file, return file descriptor
|
252 |
+
Inplace operation.
|
253 |
+
|
254 |
+
Args:
|
255 |
+
data(Iterable[str]): url or local file list
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
Iterable[{src, stream}]
|
259 |
+
"""
|
260 |
+
for sample in data:
|
261 |
+
assert 'src' in sample
|
262 |
+
cosy_jsonl_path = sample['src']
|
263 |
+
tar_file_path=cosy_jsonl_path.replace(".vq0918-pool4.jsonl",".tar")
|
264 |
+
try:
|
265 |
+
tar_data=Tar(Path(tar_file_path))
|
266 |
+
with open(cosy_jsonl_path, 'r') as f:
|
267 |
+
# cosy_data = [json.loads(line) for line in f]
|
268 |
+
for line in f:
|
269 |
+
item=json.loads(line)
|
270 |
+
cosy_token = item['cosy_token']
|
271 |
+
sample['speech_token']=torch.tensor(cosy_token)
|
272 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
|
273 |
+
# print(item['filename'])
|
274 |
+
yield {**sample}
|
275 |
+
|
276 |
+
except Exception as ex:
|
277 |
+
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
|
278 |
+
|
279 |
+
|
280 |
+
def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}):
|
281 |
+
""" Give url or local file, return file descriptor
|
282 |
+
Inplace operation.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
data(Iterable[str]): url or local file list
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
Iterable[{src, stream}]
|
289 |
+
"""
|
290 |
+
for sample in data:
|
291 |
+
assert 'src' in sample
|
292 |
+
cosy_jsonl_path = sample['src']
|
293 |
+
tar_file_path=cosy_jsonl_path.replace(".vq0918-pool8.jsonl",".tar")
|
294 |
+
|
295 |
+
try:
|
296 |
+
tar_data=Tar(Path(tar_file_path))
|
297 |
+
with open(cosy_jsonl_path, 'r') as f:
|
298 |
+
# cosy_data = [json.loads(line) for line in f]
|
299 |
+
for line in f:
|
300 |
+
item=json.loads(line)
|
301 |
+
cosy_token = item['cosy_token']
|
302 |
+
sample['speech_token']=torch.tensor(cosy_token)
|
303 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
|
304 |
+
# print(item['filename'])
|
305 |
+
yield {**sample}
|
306 |
+
|
307 |
+
except Exception as ex:
|
308 |
+
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
|
313 |
+
for sample in data:
|
314 |
+
assert 'src' in sample
|
315 |
+
|
316 |
+
token_npy_path = sample['src']
|
317 |
+
wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
|
318 |
+
|
319 |
+
# wav_path,token_npy_path=sample['src'].split(' ')
|
320 |
+
try:
|
321 |
+
sample['speech_token']=torch.tensor(np.load(token_npy_path))
|
322 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
|
323 |
+
if sample['speech'].shape[0] > 1:
|
324 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
325 |
+
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
326 |
+
yield {**sample}
|
327 |
+
except Exception as ex:
|
328 |
+
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
329 |
+
logging.warning('Failed to open {}'.format(wav_path))
|
330 |
+
|
331 |
+
|
332 |
+
def process_sft_vq0918_pool4_split(data, mode='train',split_token=25, tts_data={}):
|
333 |
+
for sample in data:
|
334 |
+
assert 'src' in sample
|
335 |
+
|
336 |
+
token_npy_path = sample['src']
|
337 |
+
wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
|
338 |
+
|
339 |
+
# wav_path,token_npy_path=sample['src'].split(' ')
|
340 |
+
try:
|
341 |
+
# sample['speech_token']=torch.tensor(np.load(token_npy_path))
|
342 |
+
# sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
|
343 |
+
# if sample['speech'].shape[0] > 1:
|
344 |
+
# sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
345 |
+
|
346 |
+
|
347 |
+
# sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
348 |
+
|
349 |
+
|
350 |
+
speech_token=torch.tensor(np.load(token_npy_path))
|
351 |
+
speech,sample_rate= torchaudio.load(wav_path)
|
352 |
+
# split_speech=int(split_token / 12.5 * sample_rate)
|
353 |
+
if speech.shape[0] > 1:
|
354 |
+
speech = speech.mean(dim=0, keepdim=True)
|
355 |
+
|
356 |
+
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
357 |
+
sample['sample_rate']=sample_rate
|
358 |
+
|
359 |
+
num_splits = (speech_token.size(0) + split_token - 1) // split_token
|
360 |
+
|
361 |
+
for split_id in range(num_splits):
|
362 |
+
end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
|
363 |
+
end_speech_idx=int(np.ceil(end_token_idx / 12.5 * sample_rate))
|
364 |
+
sample['speech_token']=speech_token[:end_token_idx]
|
365 |
+
sample['speech']=speech[:,:end_speech_idx]
|
366 |
+
print(sample['speech_token'].size(),sample['speech'].size())
|
367 |
+
yield {**sample}
|
368 |
+
except Exception as ex:
|
369 |
+
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
370 |
+
logging.warning('Failed to open {}'.format(wav_path))
|
371 |
+
|
372 |
+
|
373 |
+
def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
|
374 |
+
for sample in data:
|
375 |
+
assert 'src' in sample
|
376 |
+
|
377 |
+
token_npy_path = sample['src'].replace(".vq0918-pool4.npy",".vq0918-pool2.npy")
|
378 |
+
wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
|
379 |
+
|
380 |
+
# wav_path,token_npy_path=sample['src'].split(' ')
|
381 |
+
try:
|
382 |
+
sample['speech_token']=torch.tensor(np.load(token_npy_path))
|
383 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
|
384 |
+
if sample['speech'].shape[0] > 1:
|
385 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
386 |
+
|
387 |
+
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
388 |
+
yield {**sample}
|
389 |
+
except Exception as ex:
|
390 |
+
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
391 |
+
logging.warning('Failed to open {}'.format(wav_path))
|
392 |
+
|
393 |
+
|
394 |
+
def process_sft_vq0918_pool2_split(data, mode='train',split_token=50, tts_data={}):
|
395 |
+
for sample in data:
|
396 |
+
assert 'src' in sample
|
397 |
+
|
398 |
+
token_npy_path = sample['src']
|
399 |
+
wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
|
400 |
+
|
401 |
+
# wav_path,token_npy_path=sample['src'].split(' ')
|
402 |
+
try:
|
403 |
+
# sample['speech_token']=torch.tensor(np.load(token_npy_path))
|
404 |
+
# sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
|
405 |
+
# if sample['speech'].shape[0] > 1:
|
406 |
+
# sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
407 |
+
|
408 |
+
|
409 |
+
# sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
410 |
+
|
411 |
+
|
412 |
+
speech_token=torch.tensor(np.load(token_npy_path))
|
413 |
+
speech,sample_rate= torchaudio.load(wav_path)
|
414 |
+
# split_speech=int(split_token / 12.5 * sample_rate)
|
415 |
+
if speech.shape[0] > 1:
|
416 |
+
speech = speech.mean(dim=0, keepdim=True)
|
417 |
+
|
418 |
+
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
419 |
+
sample['sample_rate']=sample_rate
|
420 |
+
|
421 |
+
num_splits = (speech_token.size(0) + split_token - 1) // split_token
|
422 |
+
|
423 |
+
for split_id in range(num_splits):
|
424 |
+
end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
|
425 |
+
end_speech_idx=int(np.ceil(end_token_idx / 25 * sample_rate))
|
426 |
+
sample['speech_token']=speech_token[:end_token_idx]
|
427 |
+
sample['speech']=speech[:,:end_speech_idx]
|
428 |
+
print(sample['speech_token'].size(),sample['speech'].size())
|
429 |
+
yield {**sample}
|
430 |
+
except Exception as ex:
|
431 |
+
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
432 |
+
logging.warning('Failed to open {}'.format(wav_path))
|
433 |
+
|
434 |
+
def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
|
435 |
+
for sample in data:
|
436 |
+
assert 'src' in sample
|
437 |
+
try:
|
438 |
+
entry=json.loads(sample['src'])
|
439 |
+
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
440 |
+
|
441 |
+
for conv in entry["conversations"]:
|
442 |
+
if "response_wav" in conv:
|
443 |
+
wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
|
444 |
+
token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
|
445 |
+
sample['speech_token']=torch.tensor(np.load(token_npy_path))
|
446 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
|
447 |
+
if sample['speech'].shape[0] > 1:
|
448 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
449 |
+
sample['spk_embedding']=spk_embedding
|
450 |
+
yield {**sample}
|
451 |
+
except Exception as ex:
|
452 |
+
# logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
453 |
+
logging.warning('Failed to open {}'.format(wav_path))
|
454 |
+
|
455 |
+
|
456 |
+
def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
|
457 |
+
for sample in data:
|
458 |
+
assert 'src' in sample
|
459 |
+
try:
|
460 |
+
entry=json.loads(sample['src'])
|
461 |
+
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
|
462 |
+
|
463 |
+
for conv in entry["conversations"]:
|
464 |
+
if "response_wav" in conv:
|
465 |
+
wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
|
466 |
+
token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
|
467 |
+
sample['speech_token']=torch.tensor(np.load(token_npy_path))
|
468 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
|
469 |
+
if sample['speech'].shape[0] > 1:
|
470 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
471 |
+
sample['spk_embedding']=spk_embedding
|
472 |
+
yield {**sample}
|
473 |
+
if "prompt_wav" in conv:
|
474 |
+
wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
|
475 |
+
token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
|
476 |
+
sample['speech_token']=torch.tensor(np.load(token_npy_path))
|
477 |
+
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
|
478 |
+
if sample['speech'].shape[0] > 1:
|
479 |
+
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
480 |
+
sample['spk_embedding']=spk_embedding
|
481 |
+
yield {**sample}
|
482 |
+
except Exception as ex:
|
483 |
+
# logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
|
484 |
+
logging.warning('Failed to open {}'.format(wav_path))
|
485 |
+
|
486 |
+
|
487 |
+
def filter(data,
|
488 |
+
max_length=10240,
|
489 |
+
min_length=10,
|
490 |
+
token_max_length=200,
|
491 |
+
token_min_length=1,
|
492 |
+
min_output_input_ratio=0.0005,
|
493 |
+
max_output_input_ratio=1,
|
494 |
+
mode='train'):
|
495 |
+
""" Filter sample according to feature and label length
|
496 |
+
Inplace operation.
|
497 |
+
|
498 |
+
Args:
|
499 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
500 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
501 |
+
min_length: drop utterance which is less than min_length(10ms)
|
502 |
+
token_max_length: drop utterance which is greater than
|
503 |
+
token_max_length, especially when use char unit for
|
504 |
+
english modeling
|
505 |
+
token_min_length: drop utterance which is
|
506 |
+
less than token_min_length
|
507 |
+
min_output_input_ratio: minimal ratio of
|
508 |
+
token_length / feats_length(10ms)
|
509 |
+
max_output_input_ratio: maximum ratio of
|
510 |
+
token_length / feats_length(10ms)
|
511 |
+
|
512 |
+
Returns:
|
513 |
+
Iterable[{key, wav, label, sample_rate}]
|
514 |
+
"""
|
515 |
+
for sample in data:
|
516 |
+
# sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
517 |
+
# del sample['audio_data']
|
518 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
519 |
+
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
520 |
+
if num_frames < min_length:
|
521 |
+
continue
|
522 |
+
if num_frames > max_length:
|
523 |
+
continue
|
524 |
+
if len(sample['text_token']) < token_min_length:
|
525 |
+
continue
|
526 |
+
if len(sample['text_token']) > token_max_length:
|
527 |
+
continue
|
528 |
+
if len(sample['speech_token']) == 0:
|
529 |
+
continue
|
530 |
+
if num_frames != 0:
|
531 |
+
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
532 |
+
continue
|
533 |
+
if len(sample['text_token']) / num_frames > max_output_input_ratio:
|
534 |
+
continue
|
535 |
+
yield sample
|
536 |
+
|
537 |
+
|
538 |
+
def filter_speech_token(data,
|
539 |
+
max_length=10240,
|
540 |
+
min_length=10,
|
541 |
+
token_max_length=5000,
|
542 |
+
token_min_length=1,
|
543 |
+
min_output_input_ratio=0.0005,
|
544 |
+
max_output_input_ratio=30,
|
545 |
+
mode='train'):
|
546 |
+
""" Filter sample according to feature and label length
|
547 |
+
Inplace operation.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
551 |
+
max_length: drop utterance which is greater than max_length(10ms)
|
552 |
+
min_length: drop utterance which is less than min_length(10ms)
|
553 |
+
token_max_length: drop utterance which is greater than
|
554 |
+
token_max_length, especially when use char unit for
|
555 |
+
english modeling
|
556 |
+
token_min_length: drop utterance which is
|
557 |
+
less than token_min_length
|
558 |
+
min_output_input_ratio: minimal ratio of
|
559 |
+
token_length / feats_length(10ms)
|
560 |
+
max_output_input_ratio: maximum ratio of
|
561 |
+
token_length / feats_length(10ms)
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
Iterable[{key, wav, label, sample_rate}]
|
565 |
+
"""
|
566 |
+
for sample in data:
|
567 |
+
# sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
568 |
+
# del sample['audio_data']
|
569 |
+
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
570 |
+
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
571 |
+
if num_frames < min_length:
|
572 |
+
continue
|
573 |
+
if num_frames > max_length:
|
574 |
+
continue
|
575 |
+
if len(sample['speech_token']) < token_min_length:
|
576 |
+
continue
|
577 |
+
if len(sample['speech_token']) > token_max_length:
|
578 |
+
continue
|
579 |
+
if len(sample['speech_token']) == 0:
|
580 |
+
continue
|
581 |
+
if num_frames != 0:
|
582 |
+
if len(sample['speech_token']) / num_frames < min_output_input_ratio:
|
583 |
+
continue
|
584 |
+
if len(sample['speech_token']) / num_frames > max_output_input_ratio:
|
585 |
+
continue
|
586 |
+
yield sample
|
587 |
+
|
588 |
+
|
589 |
+
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
590 |
+
""" Resample data.
|
591 |
+
Inplace operation.
|
592 |
+
|
593 |
+
Args:
|
594 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
595 |
+
resample_rate: target resample rate
|
596 |
+
|
597 |
+
Returns:
|
598 |
+
Iterable[{key, wav, label, sample_rate}]
|
599 |
+
"""
|
600 |
+
for sample in data:
|
601 |
+
assert 'sample_rate' in sample
|
602 |
+
assert 'speech' in sample
|
603 |
+
sample_rate = sample['sample_rate']
|
604 |
+
waveform = sample['speech']
|
605 |
+
if sample_rate != resample_rate:
|
606 |
+
if sample_rate < min_sample_rate:
|
607 |
+
continue
|
608 |
+
sample['sample_rate'] = resample_rate
|
609 |
+
sample['speech'] = torchaudio.transforms.Resample(
|
610 |
+
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
611 |
+
max_val = sample['speech'].abs().max()
|
612 |
+
if max_val > 1:
|
613 |
+
sample['speech'] /= max_val
|
614 |
+
yield sample
|
615 |
+
|
616 |
+
|
617 |
+
def compute_fbank(data,
|
618 |
+
feat_extractor,
|
619 |
+
mode='train'):
|
620 |
+
""" Extract fbank
|
621 |
+
|
622 |
+
Args:
|
623 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
624 |
+
|
625 |
+
Returns:
|
626 |
+
Iterable[{key, feat, label}]
|
627 |
+
"""
|
628 |
+
for sample in data:
|
629 |
+
assert 'sample_rate' in sample
|
630 |
+
assert 'speech' in sample
|
631 |
+
# assert 'utt' in sample
|
632 |
+
# assert 'text_token' in sample
|
633 |
+
waveform = sample['speech']
|
634 |
+
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
|
635 |
+
sample['speech_feat'] = mat
|
636 |
+
del sample['speech']
|
637 |
+
yield sample
|
638 |
+
|
639 |
+
|
640 |
+
def parse_embedding(data, normalize, mode='train'):
|
641 |
+
""" Parse utt_embedding/spk_embedding
|
642 |
+
|
643 |
+
Args:
|
644 |
+
data: Iterable[{key, wav, label, sample_rate}]
|
645 |
+
|
646 |
+
Returns:
|
647 |
+
Iterable[{key, feat, label}]
|
648 |
+
"""
|
649 |
+
for sample in data:
|
650 |
+
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
651 |
+
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
652 |
+
if normalize:
|
653 |
+
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
654 |
+
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
655 |
+
yield sample
|
656 |
+
|
657 |
+
|
658 |
+
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
659 |
+
""" Decode text to chars or BPE
|
660 |
+
Inplace operation
|
661 |
+
|
662 |
+
Args:
|
663 |
+
data: Iterable[{key, wav, txt, sample_rate}]
|
664 |
+
|
665 |
+
Returns:
|
666 |
+
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
667 |
+
"""
|
668 |
+
tokenizer = get_tokenizer()
|
669 |
+
for sample in data:
|
670 |
+
assert 'text' in sample
|
671 |
+
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
672 |
+
if mode == 'inference':
|
673 |
+
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
|
674 |
+
yield sample
|
675 |
+
|
676 |
+
|
677 |
+
def shuffle(data, shuffle_size=10000, mode='train'):
|
678 |
+
""" Local shuffle the data
|
679 |
+
|
680 |
+
Args:
|
681 |
+
data: Iterable[{key, feat, label}]
|
682 |
+
shuffle_size: buffer size for shuffle
|
683 |
+
|
684 |
+
Returns:
|
685 |
+
Iterable[{key, feat, label}]
|
686 |
+
"""
|
687 |
+
buf = []
|
688 |
+
for sample in data:
|
689 |
+
buf.append(sample)
|
690 |
+
if len(buf) >= shuffle_size:
|
691 |
+
random.shuffle(buf)
|
692 |
+
for x in buf:
|
693 |
+
yield x
|
694 |
+
buf = []
|
695 |
+
# The sample left over
|
696 |
+
random.shuffle(buf)
|
697 |
+
for x in buf:
|
698 |
+
yield x
|
699 |
+
|
700 |
+
|
701 |
+
def sort(data, sort_size=500, mode='train'):
|
702 |
+
""" Sort the data by feature length.
|
703 |
+
Sort is used after shuffle and before batch, so we can group
|
704 |
+
utts with similar lengths into a batch, and `sort_size` should
|
705 |
+
be less than `shuffle_size`
|
706 |
+
|
707 |
+
Args:
|
708 |
+
data: Iterable[{key, feat, label}]
|
709 |
+
sort_size: buffer size for sort
|
710 |
+
|
711 |
+
Returns:
|
712 |
+
Iterable[{key, feat, label}]
|
713 |
+
"""
|
714 |
+
|
715 |
+
buf = []
|
716 |
+
for sample in data:
|
717 |
+
buf.append(sample)
|
718 |
+
if len(buf) >= sort_size:
|
719 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
720 |
+
for x in buf:
|
721 |
+
yield x
|
722 |
+
buf = []
|
723 |
+
# The sample left over
|
724 |
+
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
725 |
+
for x in buf:
|
726 |
+
yield x
|
727 |
+
|
728 |
+
|
729 |
+
def static_batch(data, batch_size=16):
|
730 |
+
""" Static batch the data by `batch_size`
|
731 |
+
|
732 |
+
Args:
|
733 |
+
data: Iterable[{key, feat, label}]
|
734 |
+
batch_size: batch size
|
735 |
+
|
736 |
+
Returns:
|
737 |
+
Iterable[List[{key, feat, label}]]
|
738 |
+
"""
|
739 |
+
buf = []
|
740 |
+
for sample in data:
|
741 |
+
buf.append(sample)
|
742 |
+
if len(buf) >= batch_size:
|
743 |
+
yield buf
|
744 |
+
buf = []
|
745 |
+
if len(buf) > 0:
|
746 |
+
yield buf
|
747 |
+
|
748 |
+
|
749 |
+
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
750 |
+
""" Dynamic batch the data until the total frames in batch
|
751 |
+
reach `max_frames_in_batch`
|
752 |
+
|
753 |
+
Args:
|
754 |
+
data: Iterable[{key, feat, label}]
|
755 |
+
max_frames_in_batch: max_frames in one batch
|
756 |
+
|
757 |
+
Returns:
|
758 |
+
Iterable[List[{key, feat, label}]]
|
759 |
+
"""
|
760 |
+
buf = []
|
761 |
+
longest_frames = 0
|
762 |
+
for sample in data:
|
763 |
+
assert 'speech_feat' in sample
|
764 |
+
assert isinstance(sample['speech_feat'], torch.Tensor)
|
765 |
+
new_sample_frames = sample['speech_feat'].size(0)
|
766 |
+
longest_frames = max(longest_frames, new_sample_frames)
|
767 |
+
frames_after_padding = longest_frames * (len(buf) + 1)
|
768 |
+
if frames_after_padding > max_frames_in_batch:
|
769 |
+
yield buf
|
770 |
+
buf = [sample]
|
771 |
+
longest_frames = new_sample_frames
|
772 |
+
else:
|
773 |
+
buf.append(sample)
|
774 |
+
if len(buf) > 0:
|
775 |
+
yield buf
|
776 |
+
|
777 |
+
|
778 |
+
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
779 |
+
""" Wrapper for static/dynamic batch
|
780 |
+
"""
|
781 |
+
if mode == 'inference':
|
782 |
+
return static_batch(data, 1)
|
783 |
+
else:
|
784 |
+
if batch_type == 'static':
|
785 |
+
return static_batch(data, batch_size)
|
786 |
+
elif batch_type == 'dynamic':
|
787 |
+
return dynamic_batch(data, max_frames_in_batch)
|
788 |
+
else:
|
789 |
+
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
790 |
+
|
791 |
+
|
792 |
+
def padding(data, use_spk_embedding, mode='train'):
|
793 |
+
""" Padding the data into training data
|
794 |
+
|
795 |
+
Args:
|
796 |
+
data: Iterable[List[{key, feat, label}]]
|
797 |
+
|
798 |
+
Returns:
|
799 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
800 |
+
"""
|
801 |
+
for sample in data:
|
802 |
+
assert isinstance(sample, list)
|
803 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
804 |
+
dtype=torch.int32)
|
805 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
806 |
+
|
807 |
+
utts = [sample[i]['utt'] for i in order]
|
808 |
+
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
809 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
810 |
+
speech_token = pad_sequence(speech_token,
|
811 |
+
batch_first=True,
|
812 |
+
padding_value=0)
|
813 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
814 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
815 |
+
speech_feat = pad_sequence(speech_feat,
|
816 |
+
batch_first=True,
|
817 |
+
padding_value=0)
|
818 |
+
text = [sample[i]['text'] for i in order]
|
819 |
+
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
820 |
+
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
821 |
+
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
|
822 |
+
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
823 |
+
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
824 |
+
batch = {
|
825 |
+
"utts": utts,
|
826 |
+
"speech_token": speech_token,
|
827 |
+
"speech_token_len": speech_token_len,
|
828 |
+
"speech_feat": speech_feat,
|
829 |
+
"speech_feat_len": speech_feat_len,
|
830 |
+
"text": text,
|
831 |
+
"text_token": text_token,
|
832 |
+
"text_token_len": text_token_len,
|
833 |
+
"utt_embedding": utt_embedding,
|
834 |
+
"spk_embedding": spk_embedding,
|
835 |
+
}
|
836 |
+
if mode == 'inference':
|
837 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
838 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
839 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
840 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
841 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
842 |
+
batch.update({'tts_text': tts_text,
|
843 |
+
'tts_index': tts_index,
|
844 |
+
'tts_text_token': tts_text_token,
|
845 |
+
'tts_text_token_len': tts_text_token_len})
|
846 |
+
if use_spk_embedding is True:
|
847 |
+
batch["embedding"] = batch["spk_embedding"]
|
848 |
+
else:
|
849 |
+
batch["embedding"] = batch["utt_embedding"]
|
850 |
+
yield batch
|
851 |
+
|
852 |
+
|
853 |
+
|
854 |
+
def padding_speech_token(data, use_spk_embedding, mode='train'):
|
855 |
+
""" Padding the data into training data
|
856 |
+
|
857 |
+
Args:
|
858 |
+
data: Iterable[List[{key, feat, label}]]
|
859 |
+
|
860 |
+
Returns:
|
861 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
862 |
+
"""
|
863 |
+
for sample in data:
|
864 |
+
assert isinstance(sample, list)
|
865 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
866 |
+
dtype=torch.int32)
|
867 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
868 |
+
|
869 |
+
# utts = [sample[i]['utt'] for i in order]
|
870 |
+
# speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
871 |
+
try:
|
872 |
+
speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
|
873 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
874 |
+
speech_token = pad_sequence(speech_token,
|
875 |
+
batch_first=True,
|
876 |
+
padding_value=0)
|
877 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
878 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
879 |
+
speech_feat = pad_sequence(speech_feat,
|
880 |
+
batch_first=True,
|
881 |
+
padding_value=0)
|
882 |
+
batch = {
|
883 |
+
"speech_token": speech_token,
|
884 |
+
"speech_token_len": speech_token_len,
|
885 |
+
"speech_feat": speech_feat,
|
886 |
+
"speech_feat_len": speech_feat_len,
|
887 |
+
}
|
888 |
+
if mode == 'inference':
|
889 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
890 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
891 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
892 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
893 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
894 |
+
batch.update({'tts_text': tts_text,
|
895 |
+
'tts_index': tts_index,
|
896 |
+
'tts_text_token': tts_text_token,
|
897 |
+
'tts_text_token_len': tts_text_token_len})
|
898 |
+
# if use_spk_embedding is True:
|
899 |
+
# batch["embedding"] = batch["spk_embedding"]
|
900 |
+
# else:
|
901 |
+
# batch["embedding"] = batch["utt_embedding"]
|
902 |
+
batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
|
903 |
+
yield batch
|
904 |
+
except Exception as ex:
|
905 |
+
logging.warning(' ex info {}'.format(ex))
|
906 |
+
# assert False
|
907 |
+
|
908 |
+
|
909 |
+
|
910 |
+
def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
|
911 |
+
""" Padding the data into training data
|
912 |
+
|
913 |
+
Args:
|
914 |
+
data: Iterable[List[{key, feat, label}]]
|
915 |
+
|
916 |
+
Returns:
|
917 |
+
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
918 |
+
"""
|
919 |
+
for sample in data:
|
920 |
+
assert isinstance(sample, list)
|
921 |
+
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
|
922 |
+
dtype=torch.int32)
|
923 |
+
order = torch.argsort(speech_feat_len, descending=True)
|
924 |
+
|
925 |
+
# utts = [sample[i]['utt'] for i in order]
|
926 |
+
# speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
927 |
+
try:
|
928 |
+
speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
|
929 |
+
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
930 |
+
speech_token = pad_sequence(speech_token,
|
931 |
+
batch_first=True,
|
932 |
+
padding_value=0)
|
933 |
+
speech_feat = [sample[i]['speech_feat'] for i in order]
|
934 |
+
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
935 |
+
speech_feat = pad_sequence(speech_feat,
|
936 |
+
batch_first=True,
|
937 |
+
padding_value=0)
|
938 |
+
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
939 |
+
batch = {
|
940 |
+
"speech_token": speech_token,
|
941 |
+
"speech_token_len": speech_token_len,
|
942 |
+
"speech_feat": speech_feat,
|
943 |
+
"speech_feat_len": speech_feat_len,
|
944 |
+
"spk_embedding": spk_embedding,
|
945 |
+
}
|
946 |
+
if mode == 'inference':
|
947 |
+
tts_text = [sample[i]['tts_text'] for i in order]
|
948 |
+
tts_index = [sample[i]['tts_index'] for i in order]
|
949 |
+
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
|
950 |
+
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
|
951 |
+
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
|
952 |
+
batch.update({'tts_text': tts_text,
|
953 |
+
'tts_index': tts_index,
|
954 |
+
'tts_text_token': tts_text_token,
|
955 |
+
'tts_text_token_len': tts_text_token_len})
|
956 |
+
# if use_spk_embedding is True:
|
957 |
+
# batch["embedding"] = batch["spk_embedding"]
|
958 |
+
# else:
|
959 |
+
# batch["embedding"] = batch["utt_embedding"]
|
960 |
+
# batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
|
961 |
+
batch["embedding"] = batch["spk_embedding"]
|
962 |
+
yield batch
|
963 |
+
except Exception as ex:
|
964 |
+
logging.warning(' ex info {}'.format(ex))
|
965 |
+
# assert False
|
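Put together, these stages form the data_pipeline that Dataset (cosyvoice/dataset/dataset.py above) chains through Processor. A minimal sketch of a speech-token training pipeline under assumptions: the list file is assumed to contain *.vq0918-pool4.npy paths as expected by process_sft_vq0918_pool4, and the MelSpectrogram stand-in replaces whatever feature extractor the training YAML actually binds.

from functools import partial
import torchaudio
from torch.utils.data import DataLoader
from cosyvoice.dataset.dataset import Dataset
from cosyvoice.dataset import processor

# Stand-in mel extractor; compute_fbank only needs a callable mapping (1, T) -> (1, n_mels, frames).
feat_extractor = torchaudio.transforms.MelSpectrogram(
    sample_rate=22050, n_fft=1024, hop_length=256, n_mels=80)

data_pipeline = [
    processor.process_sft_vq0918_pool4,                    # wav + pooled speech-token .npy pairs
    partial(processor.filter_speech_token, max_length=2000),
    partial(processor.resample, resample_rate=22050),
    partial(processor.compute_fbank, feat_extractor=feat_extractor),
    partial(processor.shuffle, shuffle_size=1000),
    partial(processor.sort, sort_size=500),
    partial(processor.batch, batch_type='dynamic', max_frames_in_batch=12000),
    partial(processor.padding_speech_token, use_spk_embedding=False),
]

train_dataset = Dataset('data/train.data.list', data_pipeline, mode='train',
                        shuffle=True, partition=True)
# Batching and padding already happen inside the pipeline, so batch_size=None here.
train_loader = DataLoader(train_dataset, batch_size=None, num_workers=2)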
third_party/GLM-4-Voice/cosyvoice/flow/decoder.py
ADDED
@@ -0,0 +1,222 @@
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from einops import pack, rearrange, repeat
|
17 |
+
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
18 |
+
from matcha.models.components.transformer import BasicTransformerBlock
|
19 |
+
|
20 |
+
|
21 |
+
class ConditionalDecoder(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
in_channels,
|
25 |
+
out_channels,
|
26 |
+
channels=(256, 256),
|
27 |
+
dropout=0.05,
|
28 |
+
attention_head_dim=64,
|
29 |
+
n_blocks=1,
|
30 |
+
num_mid_blocks=2,
|
31 |
+
num_heads=4,
|
32 |
+
act_fn="snake",
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
This decoder requires an input with the same shape as the target. So, if your text content
|
36 |
+
is shorter or longer than the output, please resample it before feeding it to the decoder.
|
37 |
+
"""
|
38 |
+
super().__init__()
|
39 |
+
channels = tuple(channels)
|
40 |
+
self.in_channels = in_channels
|
41 |
+
self.out_channels = out_channels
|
42 |
+
|
43 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
44 |
+
time_embed_dim = channels[0] * 4
|
45 |
+
self.time_mlp = TimestepEmbedding(
|
46 |
+
in_channels=in_channels,
|
47 |
+
time_embed_dim=time_embed_dim,
|
48 |
+
act_fn="silu",
|
49 |
+
)
|
50 |
+
self.down_blocks = nn.ModuleList([])
|
51 |
+
self.mid_blocks = nn.ModuleList([])
|
52 |
+
self.up_blocks = nn.ModuleList([])
|
53 |
+
|
54 |
+
output_channel = in_channels
|
55 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
56 |
+
input_channel = output_channel
|
57 |
+
output_channel = channels[i]
|
58 |
+
is_last = i == len(channels) - 1
|
59 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
60 |
+
transformer_blocks = nn.ModuleList(
|
61 |
+
[
|
62 |
+
BasicTransformerBlock(
|
63 |
+
dim=output_channel,
|
64 |
+
num_attention_heads=num_heads,
|
65 |
+
attention_head_dim=attention_head_dim,
|
66 |
+
dropout=dropout,
|
67 |
+
activation_fn=act_fn,
|
68 |
+
)
|
69 |
+
for _ in range(n_blocks)
|
70 |
+
]
|
71 |
+
)
|
72 |
+
downsample = (
|
73 |
+
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
74 |
+
)
|
75 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
76 |
+
|
77 |
+
for i in range(num_mid_blocks):
|
78 |
+
input_channel = channels[-1]
|
79 |
+
out_channels = channels[-1]
|
80 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
81 |
+
|
82 |
+
transformer_blocks = nn.ModuleList(
|
83 |
+
[
|
84 |
+
BasicTransformerBlock(
|
85 |
+
dim=output_channel,
|
86 |
+
num_attention_heads=num_heads,
|
87 |
+
attention_head_dim=attention_head_dim,
|
88 |
+
dropout=dropout,
|
89 |
+
activation_fn=act_fn,
|
90 |
+
)
|
91 |
+
for _ in range(n_blocks)
|
92 |
+
]
|
93 |
+
)
|
94 |
+
|
95 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
96 |
+
|
97 |
+
channels = channels[::-1] + (channels[0],)
|
98 |
+
for i in range(len(channels) - 1):
|
99 |
+
input_channel = channels[i] * 2
|
100 |
+
output_channel = channels[i + 1]
|
101 |
+
is_last = i == len(channels) - 2
|
102 |
+
resnet = ResnetBlock1D(
|
103 |
+
dim=input_channel,
|
104 |
+
dim_out=output_channel,
|
105 |
+
time_emb_dim=time_embed_dim,
|
106 |
+
)
|
107 |
+
transformer_blocks = nn.ModuleList(
|
108 |
+
[
|
109 |
+
BasicTransformerBlock(
|
110 |
+
dim=output_channel,
|
111 |
+
num_attention_heads=num_heads,
|
112 |
+
attention_head_dim=attention_head_dim,
|
113 |
+
dropout=dropout,
|
114 |
+
activation_fn=act_fn,
|
115 |
+
)
|
116 |
+
for _ in range(n_blocks)
|
117 |
+
]
|
118 |
+
)
|
119 |
+
upsample = (
|
120 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
121 |
+
if not is_last
|
122 |
+
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
123 |
+
)
|
124 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
125 |
+
self.final_block = Block1D(channels[-1], channels[-1])
|
126 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
127 |
+
self.initialize_weights()
|
128 |
+
|
129 |
+
|
130 |
+
def initialize_weights(self):
|
131 |
+
for m in self.modules():
|
132 |
+
if isinstance(m, nn.Conv1d):
|
133 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
134 |
+
if m.bias is not None:
|
135 |
+
nn.init.constant_(m.bias, 0)
|
136 |
+
elif isinstance(m, nn.GroupNorm):
|
137 |
+
nn.init.constant_(m.weight, 1)
|
138 |
+
nn.init.constant_(m.bias, 0)
|
139 |
+
elif isinstance(m, nn.Linear):
|
140 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
141 |
+
if m.bias is not None:
|
142 |
+
nn.init.constant_(m.bias, 0)
|
143 |
+
|
144 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
145 |
+
"""Forward pass of the UNet1DConditional model.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
149 |
+
mask (_type_): shape (batch_size, 1, time)
|
150 |
+
t (_type_): shape (batch_size)
|
151 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
152 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
153 |
+
|
154 |
+
Raises:
|
155 |
+
ValueError: _description_
|
156 |
+
ValueError: _description_
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
_type_: _description_
|
160 |
+
"""
|
161 |
+
|
162 |
+
t = self.time_embeddings(t)
|
163 |
+
t = self.time_mlp(t)
|
164 |
+
|
165 |
+
x = pack([x, mu], "b * t")[0]
|
166 |
+
|
167 |
+
if spks is not None:
|
168 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
169 |
+
x = pack([x, spks], "b * t")[0]
|
170 |
+
if cond is not None:
|
171 |
+
x = pack([x, cond], "b * t")[0]
|
172 |
+
|
173 |
+
hiddens = []
|
174 |
+
masks = [mask]
|
175 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
176 |
+
mask_down = masks[-1]
|
177 |
+
x = resnet(x, mask_down, t)
|
178 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
179 |
+
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
|
180 |
+
for transformer_block in transformer_blocks:
|
181 |
+
x = transformer_block(
|
182 |
+
hidden_states=x,
|
183 |
+
attention_mask=attn_mask,
|
184 |
+
timestep=t,
|
185 |
+
)
|
186 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
187 |
+
hiddens.append(x) # Save hidden states for skip connections
|
188 |
+
x = downsample(x * mask_down)
|
189 |
+
masks.append(mask_down[:, :, ::2])
|
190 |
+
masks = masks[:-1]
|
191 |
+
mask_mid = masks[-1]
|
192 |
+
|
193 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
194 |
+
x = resnet(x, mask_mid, t)
|
195 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
196 |
+
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
|
197 |
+
for transformer_block in transformer_blocks:
|
198 |
+
x = transformer_block(
|
199 |
+
hidden_states=x,
|
200 |
+
attention_mask=attn_mask,
|
201 |
+
timestep=t,
|
202 |
+
)
|
203 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
204 |
+
|
205 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
206 |
+
mask_up = masks.pop()
|
207 |
+
skip = hiddens.pop()
|
208 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
209 |
+
x = resnet(x, mask_up, t)
|
210 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
211 |
+
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
|
212 |
+
for transformer_block in transformer_blocks:
|
213 |
+
x = transformer_block(
|
214 |
+
hidden_states=x,
|
215 |
+
attention_mask=attn_mask,
|
216 |
+
timestep=t,
|
217 |
+
)
|
218 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
219 |
+
x = upsample(x * mask_up)
|
220 |
+
x = self.final_block(x, mask_up)
|
221 |
+
output = self.final_proj(x * mask_up)
|
222 |
+
return output * mask
|
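(For reference, a minimal usage sketch of the ConditionalDecoder added above. The channel sizes and tensor shapes are illustrative assumptions rather than values taken from this repository's configs, and running it requires the matcha-tts components imported at the top of decoder.py.)

import torch
from cosyvoice.flow.decoder import ConditionalDecoder

# in_channels must match the packed channel count seen by the first block:
# here 80 (noisy features) + 80 (mu) + 80 (speaker embedding) = 240 with cond=None.
decoder = ConditionalDecoder(in_channels=240, out_channels=80)

B, T = 2, 100                  # batch size and an even number of frames
x = torch.randn(B, 80, T)      # noisy target features at the current flow step
mu = torch.randn(B, 80, T)     # conditioning features, same shape as x
spks = torch.randn(B, 80)      # speaker embedding, repeated over time inside forward()
mask = torch.ones(B, 1, T)     # 1 = valid frame, 0 = padding
t = torch.rand(B)              # per-sample flow-matching timestep

out = decoder(x, mask, mu, t, spks=spks)
print(out.shape)               # torch.Size([2, 80, 100])

The U-Net halves the time axis once per non-final down block and restores it on the way up, so an even frame count keeps the skip connections aligned without truncation.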