shenyunhang committed
Commit 52e4f53 · 0 Parent(s)
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +164 -0
  2. LICENSE +180 -0
  3. README.md +278 -0
  4. app.py +378 -0
  5. configs/sts_finetune_stage1.yaml +273 -0
  6. configs/sts_finetune_stage2.yaml +273 -0
  7. evaluation/compute-acc-of-contain.py +85 -0
  8. evaluation/compute-cer.py +559 -0
  9. evaluation/compute-wer.py +553 -0
  10. evaluation/evaluate_asr.py +379 -0
  11. evaluation/evaluate_libritts.py +384 -0
  12. evaluation/evaluate_seedtts.py +394 -0
  13. evaluation/evaluate_sqa.py +451 -0
  14. evaluation/get_chat_template.py +59 -0
  15. requirements.txt +1 -0
  16. requirements_ds_gpu.txt +44 -0
  17. scripts/deepspeed/ds_config_zero1.json +61 -0
  18. scripts/deepspeed/ds_config_zero2.json +61 -0
  19. scripts/deepspeed/ds_config_zero2_no_optimizer.json +52 -0
  20. scripts/deepspeed/ds_config_zero2_offload.json +61 -0
  21. scripts/deepspeed/ds_config_zero3.json +63 -0
  22. scripts/deepspeed/ds_config_zero3_offload.json +75 -0
  23. scripts/deepspeed/evaluate_sts.sh +348 -0
  24. scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh +137 -0
  25. scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh +137 -0
  26. scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh +137 -0
  27. scripts/deepspeed/sts_qwen25/finetune_glm4voice_stage1.sh +136 -0
  28. scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage1.sh +138 -0
  29. scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage2.sh +138 -0
  30. scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp1_stage1.sh +138 -0
  31. scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_stage1.sh +137 -0
  32. scripts/set_env_ds_gpu.sh +53 -0
  33. setup.py +12 -0
  34. third_party/GLM-4-Voice/.gitignore +4 -0
  35. third_party/GLM-4-Voice/.gitmodules +3 -0
  36. third_party/GLM-4-Voice/LICENSE +201 -0
  37. third_party/GLM-4-Voice/README.md +159 -0
  38. third_party/GLM-4-Voice/README_en.md +148 -0
  39. third_party/GLM-4-Voice/audio_process.py +93 -0
  40. third_party/GLM-4-Voice/cosyvoice/__init__.py +0 -0
  41. third_party/GLM-4-Voice/cosyvoice/bin/inference.py +114 -0
  42. third_party/GLM-4-Voice/cosyvoice/bin/train.py +140 -0
  43. third_party/GLM-4-Voice/cosyvoice/cli/__init__.py +0 -0
  44. third_party/GLM-4-Voice/cosyvoice/cli/cosyvoice.py +83 -0
  45. third_party/GLM-4-Voice/cosyvoice/cli/frontend.py +168 -0
  46. third_party/GLM-4-Voice/cosyvoice/cli/model.py +95 -0
  47. third_party/GLM-4-Voice/cosyvoice/dataset/__init__.py +0 -0
  48. third_party/GLM-4-Voice/cosyvoice/dataset/dataset.py +160 -0
  49. third_party/GLM-4-Voice/cosyvoice/dataset/processor.py +965 -0
  50. third_party/GLM-4-Voice/cosyvoice/flow/decoder.py +222 -0
.gitignore ADDED
@@ -0,0 +1,164 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+ #
+ *.sw*
LICENSE ADDED
@@ -0,0 +1,180 @@
+ Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
+
+ License Terms of the VITA1.5:
+ --------------------------------------------------------------------
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ - You agree to use the VITA1.5 only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
+
+ - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ For avoidance of doubts, "Software" means the VITA1.5 model inference-enabling code, and weights made available under this license.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+ Other dependencies and licenses:
+
+
+ Open Source Model Licensed under the Apache License Version 2.0:
+ The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"), as model weights provided for the VITA1.5 Project hereunder is fine-tuned with the assistance of below model.
+
+ All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+ --------------------------------------------------------------------
+ 1. Qwen2-7B-Instruct
+ Copyright 2024 Alibaba Cloud
+
+ Terms of the Apache License Version 2.0:
+ --------------------------------------------------------------------
+ Apache License
+
+ Version 2.0, January 2004
+
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
+
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
+
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
+
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
+
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+
+ Open Source Model/Software Licensed under the Apache License Version 2.0:
+ The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
+ --------------------------------------------------------------------
+ 1. ModelLink
+ Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+
+ A copy of the Apache License Version 2.0 is included in this file.
+
+
+ Open Source Model/Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
+ --------------------------------------------------------------------
+ 1. opencv
+ Copyright (C) 2000-2022, Intel Corporation, all rights reserved.
+ Copyright (C) 2009-2011, Willow Garage Inc., all rights reserved.
+ Copyright (C) 2009-2016, NVIDIA Corporation, all rights reserved.
+ Copyright (C) 2010-2013, Advanced Micro Devices, Inc., all rights reserved.
+ Copyright (C) 2015-2023, OpenCV Foundation, all rights reserved.
+ Copyright (C) 2008-2016, Itseez Inc., all rights reserved.
+ Copyright (C) 2019-2023, Xperience AI, all rights reserved.
+ Copyright (C) 2019-2022, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
+ Copyright (C) 2022-2023, Southern University of Science And Technology, all rights reserved.
+
+ A copy of the Apache 2.0 is included in this file.
+
+ For the license of other third party components, please refer to the following URL:
+ https://github.com/opencv/opencv/tree/4.10.0/3rdparty
+
+
+ Open Source Model/Software Licensed under the BSD 3-Clause License:
+ --------------------------------------------------------------------
+ 1. flask
+ Copyright 2010 Pallets
+
+ 2. flask-restful
+ Copyright (c) 2013, Twilio, Inc.
+ All rights reserved.
+
+
+ Terms of the BSD 3-Clause License:
+ --------------------------------------------------------------------
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+ Open Source Model/Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
+ --------------------------------------------------------------------
+ 1. Megatron-LM
+ Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+
+
+ A copy of the BSD 3-Clause is included in this file.
+
+ For the license of other third party components, please refer to the following URL:
+ https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE
+
+
+ Open Source Model/Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
+ --------------------------------------------------------------------
+ 1. MindSpeed
+ Copyright (c) 2024, Bytedance Inc.
+ Copyright (c) 2023, Huawei Technologies Co., Ltd
+ Copyright (c) 2022, NVIDIA CORPORATION.
+ All rights reserved.
+
+
+ A copy of the BSD 3-Clause is included in this file.
+
+ For the license of other third party components, please refer to the following URL:
+ https://gitee.com/ascend/MindSpeed/blob/master/LICENSE
+
+
+ Open Source Model/Software Licensed under the MIT License:
+ --------------------------------------------------------------------
+ 1. natsort
+ Copyright (c) 2012-2023 Seth M. Morton
+
+
+ Terms of the MIT License:
+ --------------------------------------------------------------------
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md ADDED
@@ -0,0 +1,278 @@
+ # VITA-Audio: Fast Interleaved Audio-Text Token Generation for Efficient Large Speech-Language Model
+
+ <p align="center">
+ <img src="asset/VITA_audio_logos.png" width="50%" height="50%">
+ </p>
+
+ <p align="center">
+ <a href="https://arxiv.org/abs/2502.05177" target="_blank"><img src="https://img.shields.io/badge/VITA%20Audio-Report-b5212f.svg?logo=arxiv" /></a>
+ <a href="https://huggingface.co/collections/VITA-MLLM/vita-audio-680f036c174441e7cdf02575" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-ffc107?color=ffc107&logoColor=white" /></a>
+ </p>
+
+
+ ## :fire: News
+
+
+
+ * **`2025.05.06`** 🌟 We are proud to launch VITA-Audio, an end-to-end large speech model with fast audio-text token generation.
+
+
+ ## 📄 Contents <!-- omit in toc -->
+
+
+ - [Highlights](#-highlights)
+ - [Exhibition](#-exhibition)
+ - [Models](#-models)
+ - [Experimental Results](#-experimental-results)
+ - [Training](#-training)
+ - [Inference](#-inference)
+ - [Evaluation](#-evaluation)
+
+
+ ## ✨ Highlights
+
+ - **Low Latency**. VITA-Audio is the first end-to-end speech model capable of generating audio during the initial forward pass. By utilizing a set of 32 prefill tokens, VITA-Audio reduces the time required to generate the first audio token chunk from 217 ms to just 47 ms.
+ - **Fast Inference**. VITA-Audio achieves an inference speedup of 3-5x at the 7B parameter scale.
+ - **Open Source**. VITA-Audio is trained on **open-source data** only, consisting of 200k hours of publicly available audio.
+ - **Strong Performance**. VITA-Audio achieves competitive results on ASR, TTS, and SQA benchmarks among cutting-edge models under 7B parameters.
+
+
+
+ ## 📌 Exhibition
+
+ ### Inference Acceleration
+ Model inference speed under different inference modes.
+
+ <p align="center">
+ <img src="./asset/qa_speed.gif" alt="demogif" width="48%" style="display: inline-block; margin-right: 2%;">
+ <img src="./asset/tts_speed.gif" alt="second_gif" width="48%" style="display: inline-block;">
+ </p>
+
+ ### Time to Generate the First Audio Segment in Streaming Inference
+ <div align="center">
+ <img width="400" alt="first audio generate time" src="https://github.com/user-attachments/assets/165f943e-ac53-443f-abba-e5eb1e0c0f40" />
+ </div>
+
+
+ ### Generated Audio Case
+
+
+
+ > 打南边来了个哑巴,腰里别了个喇叭;打北边来了个喇嘛,手里提了个獭犸。
+ > 提着獭犸的喇嘛要拿獭犸换别着喇叭的哑巴的喇叭;别着喇叭的哑巴不愿拿喇叭换提着獭玛的喇嘛的獭犸。
+ > 不知是别着喇叭的哑巴打了提着獭玛的喇嘛一喇叭;还是提着獭玛的喇嘛打了别着喇叭的哑巴一獭玛。
+ > 喇嘛回家炖獭犸;哑巴嘀嘀哒哒吹喇叭。
+
+ https://github.com/user-attachments/assets/38da791f-5d72-4d9c-a9b2-cec97c2f2b2b
+
+
+ ---
+
+ > To be or not to be--to live intensely and richly,
+ > merely to exist, that depends on ourselves. Let widen and intensify our relations.
+ > While we live, let live!
+
+ https://github.com/user-attachments/assets/fd478065-4041-4eb8-b331-0c03b304d853
+
+
+ ---
+
+ > The hair has been so little, don't think about it, go to bed early, for your hair. Good night!
+
+ https://github.com/user-attachments/assets/4cfe4742-e237-42bd-9f17-7935b2285799
+
+
+ ---
+ > 两个黄鹂鸣翠柳,
+ > 一行白鹭上青天。
+ > 窗含西岭千秋雪,
+ > 门泊东吴万里船。
+
+ https://github.com/user-attachments/assets/382620ee-bb2a-488e-9e00-71afd2342b56
+
+
+ ---
+ ## 🔔 Models
+
+ | Model | LLM Size | Huggingface Weights |
+ |-------------------------|----------|---------------------------------------------------------------|
+ | VITA-Audio-Boost | 7B | https://huggingface.co/VITA-MLLM/VITA-Audio-Boost |
+ | VITA-Audio-Balance | 7B | https://huggingface.co/VITA-MLLM/VITA-Audio-Balance |
+ | VITA-Audio-Plus-Vanilla | 7B | https://huggingface.co/VITA-MLLM/VITA-Audio-Plus-Vanilla |
+
+
+
+ ## 📈 Experimental Results
+ - **Comparison of Spoken Question Answering**.
+
+ ![Clipboard_Screenshot_1746531780](https://github.com/user-attachments/assets/3adcad15-0333-4b92-bfdf-b753b330a3e2)
+
+
+ - **Comparison of Text to Speech**.
+
+ ![image](https://github.com/user-attachments/assets/09cf8fd3-d7a5-4b77-be49-5a0ace308f3f)
+
+
+ - **Comparison of Automatic Speech Recognition**.
+
+ ![Clipboard_Screenshot_1746532039](https://github.com/user-attachments/assets/d950cae0-c065-4da9-b37a-a471d28158a0)
+
+ ![Clipboard_Screenshot_1746532022](https://github.com/user-attachments/assets/929f45cd-693a-4ff6-af73-ceec6e875706)
+
+
+
+ - **Effectiveness of Inference Acceleration**.
+
+
+ ![Clipboard_Screenshot_1746532167](https://github.com/user-attachments/assets/ad8b9e90-cd3c-4968-8653-998811a50006)
+
+ ![Image](https://github.com/user-attachments/assets/4aa5db8c-362d-4152-8090-92292b9a84c0)
+
+
+
+ ## 📔 Requirements and Installation
+
+ ### Prepare Environment
+ ```
+ docker pull shenyunhang/pytorch:24.11-py3_2024-1224
+ ```
+
+ ### Get the Code
+ ```
+ git clone https://github.com/VITA-MLLM/VITA-Audio.git
+ cd VITA-Audio
+ pip install -r requirements_ds_gpu.txt
+ pip install -e .
+ ```
+
+ ### Prepare Pre-trained Weight
+
+ #### LLM
+
+ - Download the LLM from https://huggingface.co/Qwen/Qwen2.5-7B-Instruct.
+ - Put it into '../models/Qwen/Qwen2.5-7B-Instruct/'
+
+ #### Audio Encoder and Audio Decoder
+
+ - Download the Audio Encoder from https://huggingface.co/THUDM/glm-4-voice-tokenizer.
+ - Put it into '../models/THUDM/glm-4-voice-tokenizer'
+
+ - Download the Audio Decoder from https://huggingface.co/THUDM/glm-4-voice-decoder.
+ - Put it into '../models/THUDM/glm-4-voice-decoder'
+
+
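A minimal sketch of fetching the same checkpoints with `huggingface_hub` (assumed to be installed; `snapshot_download` is its standard API), using the target directories listed above:

```python
# Hedged sketch: download the three checkpoints into the directories the README expects.
from huggingface_hub import snapshot_download

for repo_id, local_dir in [
    ("Qwen/Qwen2.5-7B-Instruct", "../models/Qwen/Qwen2.5-7B-Instruct/"),
    ("THUDM/glm-4-voice-tokenizer", "../models/THUDM/glm-4-voice-tokenizer"),
    ("THUDM/glm-4-voice-decoder", "../models/THUDM/glm-4-voice-decoder"),
]:
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
```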
+ ### Data Format
+ #### **Speech QA Interleaved Data Format**
+
+ > This format shows how text and audio sequences are interleaved in a structured JSON conversation between a user and an assistant.
+
+ ```jsonc
+ {
+     "messages": [
+         {
+             "role": "user",
+             "content": "<|begin_of_audio|> audio_sequence <|end_of_audio|>"
+         },
+         {
+             "role": "assistant",
+             "content": "text_sequence_1 <|begin_of_audio|> audio_sequence_1 <|end_of_audio|> text_sequence_2 <|begin_of_audio|> audio_sequence_2 <|end_of_audio|>"
+         }
+     ]
+ }
+ ```
+
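A minimal sketch of writing one record in this format to a JSONL file; the placeholders (`audio_sequence`, `text_sequence_*`) stand in for real token sequences, and the `datasets/jsonl/...` layout assumed here mirrors the paths referenced by the training configs in this commit:

```python
# Hedged sketch: append one interleaved-format conversation to a JSONL training file.
import json

record = {
    "messages": [
        {"role": "user", "content": "<|begin_of_audio|> audio_sequence <|end_of_audio|>"},
        {
            "role": "assistant",
            "content": "text_sequence_1 <|begin_of_audio|> audio_sequence_1 <|end_of_audio|> "
                       "text_sequence_2 <|begin_of_audio|> audio_sequence_2 <|end_of_audio|>",
        },
    ]
}

# One JSON object per line, matching the datasets/jsonl/<name>/<split>.jsonl convention.
with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```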
+ ## 🎲 Training
+
+
+ The following tutorial will take `VITA-Audio-Boost` as an example.
+
+ - To train `VITA-Audio-Balance` and other variants, you should modify the `text-audio-interval-ratio`.
+
+ VITA-Audio-Boost:
+ ```
+ --text-audio-interval-ratio 1 10 4 10 \
+ ```
+
+ VITA-Audio-Balance:
+ ```
+ --text-audio-interval-ratio 1 4 3 8 4 10 \
+ ```
+
+ - To train `VITA-Audio-Plus-*`, you should use a script like `scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice...`
+
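The exact semantics of `--text-audio-interval-ratio` live in the training scripts; the sketch below is only one plausible reading (an assumption, not the project's code), where the list alternates text/audio chunk sizes and the final pair repeats for the rest of the response:

```python
# Hedged sketch: expand a text-audio-interval-ratio list into (text_chunk, audio_chunk) pairs.
def interleave_schedule(ratio, num_chunks=6):
    # Read the flat list as consecutive (text, audio) pairs.
    pairs = [(ratio[i], ratio[i + 1]) for i in range(0, len(ratio), 2)]
    # After the listed pairs are used up, keep repeating the last one.
    return [pairs[min(k, len(pairs) - 1)] for k in range(num_chunks)]

print(interleave_schedule([1, 10, 4, 10]))        # VITA-Audio-Boost
print(interleave_schedule([1, 4, 3, 8, 4, 10]))   # VITA-Audio-Balance
```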
+ ### Stage-1 (Audio-Text Alignment)
+
+ ```
+ bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_stage1.sh 8192 `date +'%Y%m%d_%H%M%S'`
+ ```
+
+ The above script may need some adjustments.
+
+ - Set `ROOT_PATH` to your code root folder.
+ - Set `LOCAL_ROOT_PATH` to a temporary code root folder.
+ - Modify other variables as needed for your environment.
+
+ ### Stage-2 (Single MCTP Module Training)
+
+ ```
+ bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh 8192 `date +'%Y%m%d_%H%M%S'`
+ ```
+
+ The above script may need some adjustments.
+
+ - Set `ROOT_PATH` to your code root folder.
+ - Set `LOCAL_ROOT_PATH` to a temporary code root folder.
+ - Set `MODEL_NAME_OR_PATH` to the path of the model trained in Stage 1.
+ - Modify other variables as needed for your environment.
+
+ ### Stage-3 (Multiple MCTP Modules Training)
+
+ ```
+ bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh 8192 `date +'%Y%m%d_%H%M%S'`
+ ```
+
+ The above script may need some adjustments.
+
+ - Set `ROOT_PATH` to your code root folder.
+ - Set `LOCAL_ROOT_PATH` to a temporary code root folder.
+ - Set `MODEL_NAME_OR_PATH` to the path of the model trained in Stage 2.
+ - Modify other variables as needed for your environment.
+
+ ### Stage-4 (Supervised Fine-tuning)
+
+ ```
+ bash scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh 2048 `date +'%Y%m%d_%H%M%S'`
+ ```
+
+ The above script may need some adjustments.
+
+ - Set `ROOT_PATH` to your code root folder.
+ - Set `LOCAL_ROOT_PATH` to a temporary code root folder.
+ - Set `MODEL_NAME_OR_PATH` to the path of the model trained in Stage 3.
+ - Modify other variables as needed for your environment.
+
+
+
+ ## 📐 Inference
+
+ Here we implement a simple script for inference.
+
+ It includes examples of speech-to-speech, ASR, and TTS tasks, as well as inference speed testing.
+
+ ```
+ python tools/inference_sts.py
+ ```
+
+ - Set `model_name_or_path` to VITA-Audio weights.
+ - Set `audio_tokenizer_path` to the path of the audio encoder.
+ - Set `flow_path` to the path of the audio decoder.
+
+
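For reference, a minimal sketch of what such an inference run looks like, distilled from `app.py` in this same commit; `tools/inference_sts.py` may differ in details, the local paths follow the weight-preparation section, and the generation-config settings used in `app.py` are omitted here for brevity:

```python
# Hedged sketch based on app.py: load the model plus GLM-4-Voice codecs and run a TTS query.
import sys
import torch

sys.path.append("third_party/GLM-4-Voice/")
sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

from transformers import AutoModelForCausalLM, AutoTokenizer
from vita_audio.tokenizer import get_audio_tokenizer
from evaluation.get_chat_template import qwen2_chat_template as chat_template

model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla"        # VITA-Audio weights
audio_tokenizer_path = "../models/THUDM/glm-4-voice-tokenizer"  # audio encoder
flow_path = "../models/THUDM/glm-4-voice-decoder"               # audio decoder

audio_tokenizer = get_audio_tokenizer(
    audio_tokenizer_path, "sensevoice_glm4voice", flow_path=flow_path, rank=0
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path, trust_remote_code=True, chat_template=chat_template
)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, trust_remote_code=True,
    torch_dtype=torch.bfloat16, device_map="cuda:0",
).eval()

# TTS example: ask the model to speak a sentence, then keep only the audio-codec tokens,
# splitting the output exactly as app.py does.
messages = [{"role": "user", "content": "Convert the text to speech.\nHello, VITA-Audio."}]
input_ids = torch.tensor(
    [tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)]
).to("cuda")
outputs = model.generate(input_ids)

audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
audio_tokens = [t - audio_offset for t in outputs[0][input_ids.shape[1]:] if t >= audio_offset]
waveform = audio_tokenizer.decode(audio_tokens)  # 22050 Hz speech, per app.py
```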
+ ## 🔎 Evaluation
+
+ Evaluate SQA, ASR, and TTS benchmarks:
+ ```
+ bash scripts/deepspeed/evaluate_sts.sh
+ ```
+
+
app.py ADDED
@@ -0,0 +1,378 @@
+ import torch
+ import os
+ import numpy as np
+ import copy
+ import gradio as gr
+ import sys
+ from vita_audio.tokenizer import get_audio_tokenizer
+ from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
+
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoConfig
+ from transformers.generation import GenerationConfig
+
+
+
+ PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
+
+
+ import math
+ from numba import jit
+
+ @jit
+ def float_to_int16(audio: np.ndarray) -> np.ndarray:
+     am = int(math.ceil(float(np.abs(audio).max())) * 32768)
+     am = 32767 * 32768 // am
+     return np.multiply(audio, am).astype(np.int16)
+
+
+ def is_wav(file_path):
+     wav_extensions = {'.wav'}
+     _, ext = os.path.splitext(file_path)
+     return ext.lower() in wav_extensions
+
+
+
+ def _parse_text(text):
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split("`")
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = "<br></code></pre>"
+         else:
+             if i > 0 and count % 2 == 1:
+                 line = line.replace("`", r"\`")
+                 line = line.replace("<", "&lt;")
+                 line = line.replace(">", "&gt;")
+                 line = line.replace(" ", "&nbsp;")
+                 line = line.replace("*", "&ast;")
+                 line = line.replace("_", "&lowbar;")
+                 line = line.replace("-", "&#45;")
+                 line = line.replace(".", "&#46;")
+                 line = line.replace("!", "&#33;")
+                 line = line.replace("(", "&#40;")
+                 line = line.replace(")", "&#41;")
+                 line = line.replace("$", "&#36;")
+             lines[i] = "<br>" + line
+
+     return "".join(lines)
+
+
+
+ def _launch_demo(model, tokenizer, audio_tokenizer):
+     def predict(_chatbot, task_history, task):
+         chat_query = task_history[-1][0]
+         print(task_history)
+
+         messages = []
+
+         audio_path_list = []
+         # Build the chat messages for the selected task (Spoken QA / TTS / ASR),
+         # collecting any input wav paths for the audio tokenizer.
+         if task == 'Spoken QA':
+             messages = [
+                 {
+                     "role": "system",
+                     #"content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
+                     # "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.",
+                     "content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.",
+                 },
+             ]
+             for i, (q, a) in enumerate(task_history):
+
+                 if isinstance(q, (tuple, list)) and is_wav(q[0]):
+                     audio_path_list.append(q[0])
+                     messages = messages + [
+                         {
+                             "role": "user",
+                             "content": f"\n<|audio|>",
+                         },
+                     ]
+                 else:
+                     messages = messages + [
+                         {
+                             "role": "user",
+                             "content": q,
+                         },
+                     ]
+                 if a != None:
+                     messages = messages + [
+                         {
+                             "role": "assistant",
+                             "content": a,
+                         },
+                     ]
+             model.generation_config.do_sample = False
+
+         elif task == 'TTS':
+             for i, (q, a) in enumerate(task_history):
+
+                 if isinstance(q, (tuple, list)) and is_wav(q[0]):
+                     audio_path_list.append(q[0])
+                     messages = messages + [
+                         {
+                             "role": "user",
+                             "content": f"\n<|audio|>",
+                         },
+                     ]
+                 else:
+                     messages = messages + [
+                         {
+                             "role": "user",
+                             "content": f'Convert the text to speech.\n{q}',
+                         },
+                     ]
+                 if a != None:
+                     messages = messages + [
+                         {
+                             "role": "assistant",
+                             "content": a,
+                         },
+                     ]
+             model.generation_config.do_sample = True
+         elif task == 'ASR':
+             for i, (q, a) in enumerate(task_history):
+
+                 if isinstance(q, (tuple, list)) and is_wav(q[0]):
+                     audio_path_list.append(q[0])
+                     messages = messages + [
+                         {
+                             "role": "user",
+                             "content": f"Convert the speech to text.\n<|audio|>",
+                         },
+                     ]
+                 else:
+                     messages = messages + [
+                         {
+                             "role": "user",
+                             "content": f"{q}",
+                         },
+                     ]
+                 if a != None:
+                     messages = messages + [
+                         {
+                             "role": "assistant",
+                             "content": a,
+                         },
+                     ]
+             model.generation_config.do_sample = False
+
+
+
+         add_generation_prompt = True
+         input_ids = tokenizer.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=add_generation_prompt,
+             # return_tensors="pt",
+         )
+
+         input_ids, audios, audio_indices = add_audio_input_contiguous(
+             input_ids, audio_path_list, tokenizer, audio_tokenizer
+         )
+
+         input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
+
+         # print("input", tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)
+
+         if audio_path_list == []:
+             audios = None
+             audio_indices = None
+
+         outputs = model.generate(
+             input_ids,
+             audios=audios,
+             audio_indices=audio_indices,
+         )
+
+         output = tokenizer.decode(outputs[0], skip_special_tokens=False)
+         # print(f"{output=}", flush=True)
+
+         audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
+         begin_of_audio = tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
+         end_of_audio = tokenizer.convert_tokens_to_ids("<|end_of_audio|>")
+         im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
+         response = outputs[0][len(input_ids[0]):]
+
+         # Split the generated ids into audio-codec tokens (>= audio_offset) and plain text tokens.
+         audio_tokens = []
+         text_tokens = []
+         for token_id in response:
+             if token_id >= audio_offset:
+                 audio_tokens.append(token_id - audio_offset)
+             elif (token_id.item() != begin_of_audio) and (token_id.item() != end_of_audio) and (token_id.item() != im_end):
+                 text_tokens.append(token_id)
+
+         # Decode audio tokens to a waveform and package it for gr.Audio at 22050 Hz.
+         if len(audio_tokens) > 0:
+             tts_speech = audio_tokenizer.decode(audio_tokens)
+             audio_np = float_to_int16(tts_speech.cpu().numpy())
+             tts_speech = (22050, audio_np)
+         else:
+             tts_speech = None
+
+         # import pdb;pdb.set_trace()
+         history_response = tokenizer.decode(text_tokens)
+         task_history[-1] = (chat_query, history_response)
+
+         _chatbot[-1] = (chat_query, history_response)
+         # print("query",chat_query)
+         # print("task_history",task_history)
+         # print(_chatbot)
+         # print("answer: ",outputs)
+         return _chatbot, tts_speech
+
+
+     def add_text(history, task_history, text):
+         task_text = text
+         # import pdb;pdb.set_trace()
+         if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
+             task_text = text[:-1]
+         history = history + [(_parse_text(text), None)]
+         task_history = task_history + [(task_text, None)]
+         return history, task_history, ""
+
+
+     def add_audio(history, task_history, file):
+         print(file)
+         if file is None:
+             return history, task_history
+         history = history + [((file,), None)]
+         task_history = task_history + [((file,), None)]
+         return history, task_history
+
+
+
+
+     def reset_user_input():
+         # import pdb;pdb.set_trace()
+         return gr.update(value="")
+
+     def reset_state(task_history):
+         task_history.clear()
+         return []
+
+
+
+     with gr.Blocks(title="VITA-Audio-Plus-Vanilla") as demo:
+         gr.Markdown("""<center><font size=8>VITA-Audio-Plus-Vanilla</center>""")
+         gr.Markdown("""<center><font size=4>The deployment of the VITA-Audio-Plus-Vanilla model employs a non-streaming deployment approach. The currently deployed model is VITA-Audio-Plus-Vanilla. For the ASR and TTS tasks, only single-turn dialogues are supported. In the Spoken QA task, generated text is used as dialogue history to reduce the context length.</center>""")
+         chatbot = gr.Chatbot(label='VITA-Audio-Plus-Vanilla', elem_classes="control-height", height=500)
+         query = gr.Textbox(lines=2, label='Text Input')
+         task_history = gr.State([])
+         with gr.Row():
+             add_text_button = gr.Button("Submit Text (提交文本)")
+             add_audio_button = gr.Button("Submit Audio (提交音频)")
+             empty_bin = gr.Button("🧹 Clear History (清除历史)")
+             task = gr.Radio(
+                 choices=["ASR", "TTS", "Spoken QA"], label="TASK", value='Spoken QA'
+             )
+
+         with gr.Row(scale=1):
+
+             record_btn = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
+             audio_output = gr.Audio(label="Play", streaming=True,
+                                     autoplay=True, show_download_button=True)
+
+
+
+         add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
+             reset_user_input, [], [query]
+         ).then(
+             predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
+         )
+
+
+         empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
+
+
+         add_audio_button.click(add_audio, [chatbot, task_history, record_btn], [chatbot, task_history], show_progress=True).then(
+             predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
+         )
+
+
+     server_port = 18806
+     demo.launch(
+         share=False,
+         debug=True,
+         server_name="0.0.0.0",
+         server_port=server_port,
+         show_api=False,
+         show_error=False,
+
+     )
+
+ def main():
+
+     model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla"
+
+     device_map = "cuda:0"
+
+     sys.path.append("third_party/GLM-4-Voice/")
+     sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
+     sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")
+
+     from huggingface_hub import snapshot_download
+     audio_tokenizer_path = snapshot_download(repo_id="THUDM/glm-4-voice-tokenizer")
+     flow_path = snapshot_download(repo_id="THUDM/glm-4-voice-decoder")
+
+     audio_tokenizer_rank = 0
+     audio_tokenizer_type = "sensevoice_glm4voice"
+
+     torch_dtype = torch.bfloat16
+     audio_tokenizer = get_audio_tokenizer(
+         audio_tokenizer_path, audio_tokenizer_type, flow_path=flow_path, rank=audio_tokenizer_rank
+     )
+     from evaluation.get_chat_template import qwen2_chat_template as chat_template
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name_or_path,
+         trust_remote_code=True,
+         chat_template=chat_template,
+     )
+     # print(f"{tokenizer=}")
+     # print(f"{tokenizer.get_chat_template()=}")
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name_or_path,
+         trust_remote_code=True,
+         device_map=device_map,
+         torch_dtype=torch_dtype,
+         attn_implementation="flash_attention_2",
+     ).eval()
+
+     # print(f"{model.config.model_type=}")
+
+     model.generation_config = GenerationConfig.from_pretrained(
+         model_name_or_path, trust_remote_code=True
+     )
+
+     model.generation_config.max_new_tokens = 4096
+     model.generation_config.chat_format = "chatml"
+     model.generation_config.max_window_size = 8192
+     model.generation_config.use_cache = True
+     model.generation_config.do_sample = True
+     model.generation_config.temperature = 1.0
+     model.generation_config.top_k = 50
+     model.generation_config.top_p = 1.0
+     model.generation_config.num_beams = 1
+     model.generation_config.pad_token_id = tokenizer.pad_token_id
+     model.generation_config.mtp_inference_mode = [8192, 10]
+
+
+     _launch_demo(model, tokenizer, audio_tokenizer)
+
+
+
+
+ if __name__ == '__main__':
+
+     main()
configs/sts_finetune_stage1.yaml ADDED
@@ -0,0 +1,273 @@
1
+
2
+ xlsx_sample_num: 5
3
+
4
+ dataset:
5
+
6
+ wenet-e2e/wenetspeech:
7
+ ratio: 1.0
8
+ data_paths:
9
+ - datasets/jsonl/wenet-e2e/wenetspeech/L_fixed.jsonl
10
+ - datasets/jsonl/wenet-e2e/wenetspeech/DEV_fixed.jsonl
11
+
12
+ Wenetspeech4TTS/Wenetspeech4TTS:
13
+ ratio: 1.0
14
+ data_paths:
15
+ - datasets/jsonl/Wenetspeech4TTS/WenetSpeech4TTS/Basic.jsonl
16
+
17
+ fixie-ai/librispeech_asr:
18
+ ratio: 1.0
19
+ data_paths:
20
+ - datasets/jsonl/fixie-ai/librispeech_asr/train.100.clean.jsonl
21
+ - datasets/jsonl/fixie-ai/librispeech_asr/train.360.clean.jsonl
22
+ - datasets/jsonl/fixie-ai/librispeech_asr/train.500.other.jsonl
23
+
24
+ mythicinfinity/libritts:
25
+ ratio: 1.0
26
+ data_paths:
27
+ - datasets/jsonl/mythicinfinity/libritts/train.clean.100.jsonl
28
+ - datasets/jsonl/mythicinfinity/libritts/train.clean.360.jsonl
29
+ - datasets/jsonl/mythicinfinity/libritts/train.other.500.jsonl
30
+ - datasets/jsonl/mythicinfinity/libritts_r/train.clean.100.jsonl
31
+ - datasets/jsonl/mythicinfinity/libritts_r/train.clean.360.jsonl
32
+ - datasets/jsonl/mythicinfinity/libritts_r/train.other.500.jsonl
33
+
34
+ parler-tts/mls_eng:
35
+ ratio: 1.0
36
+ data_paths:
37
+ #- datasets/jsonl/parler-tts/mls_eng_10k/train.jsonl
38
+ - datasets/jsonl/parler-tts/mls_eng/train.jsonl
39
+
40
+ mozilla-foundation/common_voice_17_0:
41
+ ratio: 1.0
42
+ data_paths:
43
+ - datasets/jsonl/mozilla-foundation/common_voice_17_0/en/train.jsonl
44
+ - datasets/jsonl/mozilla-foundation/common_voice_17_0/zh-CN/train.jsonl
45
+
46
+ MushanW/GLOBE_V2:
47
+ ratio: 1.0
48
+ data_paths:
49
+ - datasets/jsonl/MushanW/GLOBE_V2/train.jsonl
50
+
51
+ amphion/Emilia-Dataset:
52
+ ratio: 0.5
53
+ data_paths:
54
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100.jsonl
55
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200.jsonl
56
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300.jsonl
57
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400.jsonl
58
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500.jsonl
59
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600.jsonl
60
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700.jsonl
61
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800.jsonl
62
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900.jsonl
63
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000.jsonl
64
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100.jsonl
65
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200.jsonl
66
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300.jsonl
67
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400.jsonl
68
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500.jsonl
69
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600.jsonl
70
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700.jsonl
71
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800.jsonl
72
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900.jsonl
73
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000.jsonl
74
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100.jsonl
75
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200.jsonl
76
+
77
+ amphion/Emilia-Dataset/speaker_prompt:
78
+ ratio: 0.5
79
+ data_paths:
80
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100_speak_prompt.jsonl
81
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200_speak_prompt.jsonl
82
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300_speak_prompt.jsonl
83
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400_speak_prompt.jsonl
84
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500_speak_prompt.jsonl
85
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600_speak_prompt.jsonl
86
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700_speak_prompt.jsonl
87
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800_speak_prompt.jsonl
88
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900_speak_prompt.jsonl
89
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000_speak_prompt.jsonl
90
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100_speak_prompt.jsonl
91
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200_speak_prompt.jsonl
92
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300_speak_prompt.jsonl
93
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400_speak_prompt.jsonl
94
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500_speak_prompt.jsonl
95
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600_speak_prompt.jsonl
96
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700_speak_prompt.jsonl
97
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800_speak_prompt.jsonl
98
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900_speak_prompt.jsonl
99
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000_speak_prompt.jsonl
100
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100_speak_prompt.jsonl
101
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200_speak_prompt.jsonl
102
+
103
+ openslr:
104
+ ratio: 1.0
105
+ data_paths:
106
+ - datasets/jsonl/openslr/SLR68/train.jsonl
107
+ - datasets/jsonl/openslr/SLR68/dev.jsonl
108
+
109
+ speechcolab/gigaspeech:
110
+ ratio: 1.0
111
+ data_paths:
112
+ - datasets/jsonl/speechcolab/gigaspeech/xl.jsonl
113
+ - datasets/jsonl/speechcolab/gigaspeech/dev.jsonl
114
+
115
+ MLCommons/peoples_speech:
116
+ ratio: 1.0
117
+ data_paths:
118
+ - datasets/jsonl/MLCommons/peoples_speech/clean.jsonl
119
+ - datasets/jsonl/MLCommons/peoples_speech/clean_sa.jsonl
120
+ - datasets/jsonl/MLCommons/peoples_speech/dirty.jsonl
121
+ - datasets/jsonl/MLCommons/peoples_speech/dirty_sa.jsonl
122
+ - datasets/jsonl/MLCommons/peoples_speech/validation.jsonl
123
+
124
+ facebook/voxpopuli:
125
+ ratio: 1.0
126
+ data_paths:
127
+ - datasets/jsonl/facebook/voxpopuli/en_train.jsonl
128
+ - datasets/jsonl/facebook/voxpopuli/en_accented_test.jsonl
129
+
130
+ shenyunhang:
131
+ ratio: 1.0
132
+ data_paths:
133
+ - datasets/jsonl/shenyunhang/AISHELL-1/train.jsonl
134
+ - datasets/jsonl/shenyunhang/AISHELL-1/dev.jsonl
135
+ - datasets/jsonl/shenyunhang/AISHELL-2/data.jsonl
136
+ - datasets/jsonl/shenyunhang/AISHELL-3/data.jsonl
137
+ - datasets/jsonl/shenyunhang/AISHELL-4/data.jsonl
138
+
139
+ gpt-omni/VoiceAssistant-400K:
140
+ ratio: 0.0
141
+ data_paths:
142
+ - datasets/jsonl/gpt-omni/VoiceAssistant-400K/data.jsonl
143
+
144
+ VITA-MLLM/AudioQA-1M:
145
+ ratio: 0.0
146
+ data_paths:
147
+ - datasets/jsonl/VITA-MLLM/AudioQA-1M/data.jsonl
148
+
149
+ BAAI/Infinity-Instruct:
150
+ ratio: 1.0
151
+ data_paths:
152
+ #- datasets/jsonl/BAAI/Infinity-Instruct/3M.jsonl
153
+ #- datasets/jsonl/BAAI/Infinity-Instruct/7M.jsonl
154
+ #- datasets/jsonl/BAAI/Infinity-Instruct/7M_domains.jsonl
155
+ - datasets/jsonl/BAAI/Infinity-Instruct/0625.jsonl
156
+ #- datasets/jsonl/BAAI/Infinity-Instruct/Gen.jsonl
157
+
158
+ OpenHermes:
159
+ ratio: 1.0
160
+ data_paths:
161
+ - datasets/jsonl/teknium/OpenHermes-2.5/openhermes2_5.jsonl
162
+
163
+ lima:
164
+ ratio: 1.0
165
+ data_paths:
166
+ - datasets/jsonl/GAIR/lima/train.jsonl
167
+
168
+ databricks-dolly-15k:
169
+ ratio: 1.0
170
+ data_paths:
171
+ - datasets/jsonl/databricks/databricks-dolly-15k/databricks-dolly-15k.jsonl
172
+
173
+ MetaMathQA:
174
+ ratio: 1.0
175
+ data_paths:
176
+ - datasets/jsonl/meta-math/MetaMathQA/MetaMathQA-395K.jsonl
177
+
178
+ MathInstruct:
179
+ ratio: 1.0
180
+ data_paths:
181
+ - datasets/jsonl/TIGER-Lab/MathInstruct/MathInstruct.jsonl
182
+
183
+ orca-math-word-problems-200k:
184
+ ratio: 1.0
185
+ data_paths:
186
+ - datasets/jsonl/microsoft/orca-math-word-problems-200k/data.jsonl
187
+
188
+ atlas-math-sets:
189
+ ratio: 1.0
190
+ num: 100000
191
+ data_paths:
192
+ - datasets/jsonl/AtlasUnified/atlas-math-sets/train.jsonl
193
+
194
+ goat:
195
+ ratio: 1.0
196
+ num: 30000
197
+ data_paths:
198
+ - datasets/jsonl/tiedong/goat/dataset.jsonl
199
+
200
+ camel-ai:
201
+ ratio: 1.0
202
+ data_paths:
203
+ - datasets/jsonl/camel-ai/math/math.jsonl
204
+
205
+ Long-Instruction-with-Paraphrasing:
206
+ ratio: 1.0
207
+ data_paths:
208
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_en.jsonl
209
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_en_paraphrasing.jsonl
210
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_en.jsonl
211
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_alpaca_en.jsonl
212
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_en_paraphrasing.jsonl
213
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/translation_en2zh.jsonl
214
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_zh.jsonl
215
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_zh_paraphrasing.jsonl
216
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_zh.jsonl
217
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_llama_chinese.jsonl
218
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_zh_paraphrasing.jsonl
219
+
220
+ Long:
221
+ ratio: 1.0
222
+ data_paths:
223
+ - datasets/jsonl/akoksal/LongForm/data.jsonl
224
+ - datasets/jsonl/THUDM/LongAlign-10k/long.jsonl
225
+ - datasets/jsonl/THUDM/LongCite-45k/long.jsonl
226
+ - datasets/jsonl/THUDM/LongWriter-6k/long.jsonl
227
+ - datasets/jsonl/YeungNLP/LongQLoRA-Dataset/LongQLoRA-SFT-Data-39k.jsonl
228
+ - datasets/jsonl/Yukang/LongAlpaca-12k/LongAlpaca-12k.jsonl
229
+ - datasets/jsonl/togethercomputer/Long-Data-Collections/natural_questions_10_200_docs.jsonl
230
+ - datasets/jsonl/togethercomputer/Long-Data-Collections/booksum.jsonl
231
+ - datasets/jsonl/KnutJaegersberg/longinstruct/longinstruct.jsonl
232
+
233
+ open-thoughts/OpenThoughts2-1M:
234
+ ratio: 0.0
235
+ num: 200000
236
+ data_paths:
237
+ - datasets/jsonl/open-thoughts/OpenThoughts2-1M/data.jsonl
238
+
239
+ nvidia/Llama-Nemotron-Post-Training-Dataset:
240
+ ratio: 0.0
241
+ num: 200000
242
+ data_paths:
243
+ - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_chat.jsonl
244
+ - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_code.jsonl
245
+ #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_math.jsonl
246
+ #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_safety.jsonl
247
+ - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_science.jsonl
248
+
249
+ glaiveai/reasoning-v1-20m:
250
+ ratio: 0.0
251
+ num: 200000
252
+ data_paths:
253
+ - datasets/jsonl/glaiveai/reasoning-v1-20m/data.jsonl
254
+
255
+ nvidia/OpenCodeReasoning:
256
+ ratio: 0.0
257
+ num: 200000
258
+ data_paths:
259
+ - datasets/jsonl/nvidia/OpenCodeReasoning/split_0.jsonl
260
+
261
+ Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT:
262
+ ratio: 0.0
263
+ num: 200000
264
+ data_paths:
265
+ - datasets/jsonl/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT/data.jsonl
266
+
267
+ open-r1/OpenR1-Math-220k:
268
+ ratio: 0.0
269
+ num: 200000
270
+ data_paths:
271
+ #- datasets/jsonl/open-r1/OpenR1-Math-220k/default.jsonl
272
+ - datasets/jsonl/open-r1/OpenR1-Math-220k/all.jsonl
273
+ #- datasets/jsonl/open-r1/OpenR1-Math-220k/extended.jsonl
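Each dataset entry in the two finetune configs pairs a sampling `ratio` (and an optional `num` cap) with a list of JSONL shards; entries with `ratio: 0.0` are effectively disabled. The sketch below shows one plausible way such a section could be consumed. It is not the repository's actual loader: the function name, the use of PyYAML, and the exact semantics of `ratio`/`num` are assumptions.

import json
import random

import yaml  # assumed dependency; the real training code may differ


def load_weighted_datasets(config_path):
    # Read the YAML config and build a flat sample pool, weighting each
    # dataset by its `ratio` and optionally capping it with `num`.
    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    pool = []
    for name, spec in cfg["dataset"].items():
        ratio = spec.get("ratio", 1.0)
        if ratio <= 0:
            continue  # ratio 0.0 -> dataset is switched off
        records = []
        for path in spec["data_paths"]:
            with open(path, encoding="utf-8") as fp:
                records.extend(json.loads(line) for line in fp if line.strip())
        if "num" in spec:
            records = records[: spec["num"]]  # assumed: `num` caps the pool size
        # ratio < 1 subsamples; ratio > 1 (as in stage 2) repeats samples.
        keep = int(len(records) * ratio)
        repeats, remainder = divmod(keep, max(len(records), 1))
        pool.extend(records * repeats + random.sample(records, remainder))
    return pool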
configs/sts_finetune_stage2.yaml ADDED
@@ -0,0 +1,273 @@
1
+
2
+ xlsx_sample_num: 5
3
+
4
+ dataset:
5
+
6
+ wenet-e2e/wenetspeech:
7
+ ratio: 0.05
8
+ data_paths:
9
+ - datasets/jsonl/wenet-e2e/wenetspeech/L_fixed.jsonl
10
+ - datasets/jsonl/wenet-e2e/wenetspeech/DEV_fixed.jsonl
11
+
12
+ Wenetspeech4TTS/Wenetspeech4TTS:
13
+ ratio: 0.05
14
+ data_paths:
15
+ - datasets/jsonl/Wenetspeech4TTS/WenetSpeech4TTS/Basic.jsonl
16
+
17
+ fixie-ai/librispeech_asr:
18
+ ratio: 0.05
19
+ data_paths:
20
+ - datasets/jsonl/fixie-ai/librispeech_asr/train.100.clean.jsonl
21
+ - datasets/jsonl/fixie-ai/librispeech_asr/train.360.clean.jsonl
22
+ - datasets/jsonl/fixie-ai/librispeech_asr/train.500.other.jsonl
23
+
24
+ mythicinfinity/libritts:
25
+ ratio: 0.05
26
+ data_paths:
27
+ - datasets/jsonl/mythicinfinity/libritts/train.clean.100.jsonl
28
+ - datasets/jsonl/mythicinfinity/libritts/train.clean.360.jsonl
29
+ - datasets/jsonl/mythicinfinity/libritts/train.other.500.jsonl
30
+ - datasets/jsonl/mythicinfinity/libritts_r/train.clean.100.jsonl
31
+ - datasets/jsonl/mythicinfinity/libritts_r/train.clean.360.jsonl
32
+ - datasets/jsonl/mythicinfinity/libritts_r/train.other.500.jsonl
33
+
34
+ parler-tts/mls_eng:
35
+ ratio: 0.05
36
+ data_paths:
37
+ #- datasets/jsonl/parler-tts/mls_eng_10k/train.jsonl
38
+ - datasets/jsonl/parler-tts/mls_eng/train.jsonl
39
+
40
+ mozilla-foundation/common_voice_17_0:
41
+ ratio: 0.05
42
+ data_paths:
43
+ - datasets/jsonl/mozilla-foundation/common_voice_17_0/en/train.jsonl
44
+ - datasets/jsonl/mozilla-foundation/common_voice_17_0/zh-CN/train.jsonl
45
+
46
+ MushanW/GLOBE_V2:
47
+ ratio: 0.05
48
+ data_paths:
49
+ - datasets/jsonl/MushanW/GLOBE_V2/train.jsonl
50
+
51
+ amphion/Emilia-Dataset:
52
+ ratio: 0.025
53
+ data_paths:
54
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100.jsonl
55
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200.jsonl
56
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300.jsonl
57
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400.jsonl
58
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500.jsonl
59
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600.jsonl
60
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700.jsonl
61
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800.jsonl
62
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900.jsonl
63
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000.jsonl
64
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100.jsonl
65
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200.jsonl
66
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300.jsonl
67
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400.jsonl
68
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500.jsonl
69
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600.jsonl
70
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700.jsonl
71
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800.jsonl
72
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900.jsonl
73
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000.jsonl
74
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100.jsonl
75
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200.jsonl
76
+
77
+ amphion/Emilia-Dataset/speaker_prompt:
78
+ ratio: 0.025
79
+ data_paths:
80
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000000_B000100_speak_prompt.jsonl
81
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000100_B000200_speak_prompt.jsonl
82
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000200_B000300_speak_prompt.jsonl
83
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000300_B000400_speak_prompt.jsonl
84
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000400_B000500_speak_prompt.jsonl
85
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000500_B000600_speak_prompt.jsonl
86
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000600_B000700_speak_prompt.jsonl
87
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000700_B000800_speak_prompt.jsonl
88
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000800_B000900_speak_prompt.jsonl
89
+ - datasets/jsonl/amphion/Emilia-Dataset/ZH_B000900_B001000_speak_prompt.jsonl
90
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000000_B000100_speak_prompt.jsonl
91
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000100_B000200_speak_prompt.jsonl
92
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000200_B000300_speak_prompt.jsonl
93
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000300_B000400_speak_prompt.jsonl
94
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000400_B000500_speak_prompt.jsonl
95
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000500_B000600_speak_prompt.jsonl
96
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000600_B000700_speak_prompt.jsonl
97
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000700_B000800_speak_prompt.jsonl
98
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000800_B000900_speak_prompt.jsonl
99
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B000900_B001000_speak_prompt.jsonl
100
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001000_B001100_speak_prompt.jsonl
101
+ - datasets/jsonl/amphion/Emilia-Dataset/EN_B001100_B001200_speak_prompt.jsonl
102
+
103
+ openslr:
104
+ ratio: 0.05
105
+ data_paths:
106
+ - datasets/jsonl/openslr/SLR68/train.jsonl
107
+ - datasets/jsonl/openslr/SLR68/dev.jsonl
108
+
109
+ speechcolab/gigaspeech:
110
+ ratio: 0.05
111
+ data_paths:
112
+ - datasets/jsonl/speechcolab/gigaspeech/xl.jsonl
113
+ - datasets/jsonl/speechcolab/gigaspeech/dev.jsonl
114
+
115
+ MLCommons/peoples_speech:
116
+ ratio: 0.05
117
+ data_paths:
118
+ - datasets/jsonl/MLCommons/peoples_speech/clean.jsonl
119
+ - datasets/jsonl/MLCommons/peoples_speech/clean_sa.jsonl
120
+ - datasets/jsonl/MLCommons/peoples_speech/dirty.jsonl
121
+ - datasets/jsonl/MLCommons/peoples_speech/dirty_sa.jsonl
122
+ - datasets/jsonl/MLCommons/peoples_speech/validation.jsonl
123
+
124
+ facebook/voxpopuli:
125
+ ratio: 0.05
126
+ data_paths:
127
+ - datasets/jsonl/facebook/voxpopuli/en_train.jsonl
128
+ - datasets/jsonl/facebook/voxpopuli/en_accented_test.jsonl
129
+
130
+ shenyunhang:
131
+ ratio: 0.05
132
+ data_paths:
133
+ - datasets/jsonl/shenyunhang/AISHELL-1/train.jsonl
134
+ - datasets/jsonl/shenyunhang/AISHELL-1/dev.jsonl
135
+ - datasets/jsonl/shenyunhang/AISHELL-2/data.jsonl
136
+ - datasets/jsonl/shenyunhang/AISHELL-3/data.jsonl
137
+ - datasets/jsonl/shenyunhang/AISHELL-4/data.jsonl
138
+
139
+ gpt-omni/VoiceAssistant-400K:
140
+ ratio: 2.0
141
+ data_paths:
142
+ - datasets/jsonl/gpt-omni/VoiceAssistant-400K/data.jsonl
143
+
144
+ VITA-MLLM/AudioQA-1M:
145
+ ratio: 2.0
146
+ data_paths:
147
+ - datasets/jsonl/VITA-MLLM/AudioQA-1M/data.jsonl
148
+
149
+ BAAI/Infinity-Instruct:
150
+ ratio: 0.05
151
+ data_paths:
152
+ #- datasets/jsonl/BAAI/Infinity-Instruct/3M.jsonl
153
+ #- datasets/jsonl/BAAI/Infinity-Instruct/7M.jsonl
154
+ #- datasets/jsonl/BAAI/Infinity-Instruct/7M_domains.jsonl
155
+ - datasets/jsonl/BAAI/Infinity-Instruct/0625.jsonl
156
+ #- datasets/jsonl/BAAI/Infinity-Instruct/Gen.jsonl
157
+
158
+ OpenHermes:
159
+ ratio: 0.05
160
+ data_paths:
161
+ - datasets/jsonl/teknium/OpenHermes-2.5/openhermes2_5.jsonl
162
+
163
+ lima:
164
+ ratio: 0.05
165
+ data_paths:
166
+ - datasets/jsonl/GAIR/lima/train.jsonl
167
+
168
+ databricks-dolly-15k:
169
+ ratio: 0.05
170
+ data_paths:
171
+ - datasets/jsonl/databricks/databricks-dolly-15k/databricks-dolly-15k.jsonl
172
+
173
+ MetaMathQA:
174
+ ratio: 0.05
175
+ data_paths:
176
+ - datasets/jsonl/meta-math/MetaMathQA/MetaMathQA-395K.jsonl
177
+
178
+ MathInstruct:
179
+ ratio: 0.05
180
+ data_paths:
181
+ - datasets/jsonl/TIGER-Lab/MathInstruct/MathInstruct.jsonl
182
+
183
+ orca-math-word-problems-200k:
184
+ ratio: 0.05
185
+ data_paths:
186
+ - datasets/jsonl/microsoft/orca-math-word-problems-200k/data.jsonl
187
+
188
+ atlas-math-sets:
189
+ ratio: 0.05
190
+ num: 100000
191
+ data_paths:
192
+ - datasets/jsonl/AtlasUnified/atlas-math-sets/train.jsonl
193
+
194
+ goat:
195
+ ratio: 0.05
196
+ num: 30000
197
+ data_paths:
198
+ - datasets/jsonl/tiedong/goat/dataset.jsonl
199
+
200
+ camel-ai:
201
+ ratio: 0.05
202
+ data_paths:
203
+ - datasets/jsonl/camel-ai/math/math.jsonl
204
+
205
+ Long-Instruction-with-Paraphrasing:
206
+ ratio: 0.05
207
+ data_paths:
208
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_en.jsonl
209
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_en_paraphrasing.jsonl
210
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_en.jsonl
211
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_alpaca_en.jsonl
212
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_en_paraphrasing.jsonl
213
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/translation_en2zh.jsonl
214
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/booksum_zh.jsonl
215
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/multi_doc_qa_zh_paraphrasing.jsonl
216
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/sharegpt_zh.jsonl
217
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/short_instruction_from_llama_chinese.jsonl
218
+ - datasets/jsonl/yuyijiong/Long-Instruction-with-Paraphrasing/single_doc_qa_zh_paraphrasing.jsonl
219
+
220
+ Long:
221
+ ratio: 0.05
222
+ data_paths:
223
+ - datasets/jsonl/akoksal/LongForm/data.jsonl
224
+ - datasets/jsonl/THUDM/LongAlign-10k/long.jsonl
225
+ - datasets/jsonl/THUDM/LongCite-45k/long.jsonl
226
+ - datasets/jsonl/THUDM/LongWriter-6k/long.jsonl
227
+ - datasets/jsonl/YeungNLP/LongQLoRA-Dataset/LongQLoRA-SFT-Data-39k.jsonl
228
+ - datasets/jsonl/Yukang/LongAlpaca-12k/LongAlpaca-12k.jsonl
229
+ - datasets/jsonl/togethercomputer/Long-Data-Collections/natural_questions_10_200_docs.jsonl
230
+ - datasets/jsonl/togethercomputer/Long-Data-Collections/booksum.jsonl
231
+ - datasets/jsonl/KnutJaegersberg/longinstruct/longinstruct.jsonl
232
+
233
+ open-thoughts/OpenThoughts2-1M:
234
+ ratio: 0.0
235
+ num: 10000
236
+ data_paths:
237
+ - datasets/jsonl/open-thoughts/OpenThoughts2-1M/data.jsonl
238
+
239
+ nvidia/Llama-Nemotron-Post-Training-Dataset:
240
+ ratio: 0.0
241
+ num: 10000
242
+ data_paths:
243
+ - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_chat.jsonl
244
+ - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_code.jsonl
245
+ #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_math.jsonl
246
+ #- datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_safety.jsonl
247
+ - datasets/jsonl/nvidia/Llama-Nemotron-Post-Training-Dataset/SFT_science.jsonl
248
+
249
+ glaiveai/reasoning-v1-20m:
250
+ ratio: 0.0
251
+ num: 10000
252
+ data_paths:
253
+ - datasets/jsonl/glaiveai/reasoning-v1-20m/data.jsonl
254
+
255
+ nvidia/OpenCodeReasoning:
256
+ ratio: 0.0
257
+ num: 10000
258
+ data_paths:
259
+ - datasets/jsonl/nvidia/OpenCodeReasoning/split_0.jsonl
260
+
261
+ Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT:
262
+ ratio: 0.0
263
+ num: 10000
264
+ data_paths:
265
+ - datasets/jsonl/Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT/data.jsonl
266
+
267
+ open-r1/OpenR1-Math-220k:
268
+ ratio: 0.0
269
+ num: 10000
270
+ data_paths:
271
+ #- datasets/jsonl/open-r1/OpenR1-Math-220k/default.jsonl
272
+ - datasets/jsonl/open-r1/OpenR1-Math-220k/all.jsonl
273
+ #- datasets/jsonl/open-r1/OpenR1-Math-220k/extended.jsonl
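Relative to stage 1, this stage-2 config downweights the ASR/TTS corpora (ratios of 0.05 or 0.025 instead of 1.0) and oversamples the spoken-QA sets gpt-omni/VoiceAssistant-400K and VITA-MLLM/AudioQA-1M (ratio 2.0 instead of 0.0), shifting the mixture toward speech-to-speech dialogue data. Because both configs reference many JSONL shards, a quick pre-flight check that every listed path exists can catch typos before a long run; the helper below is a hypothetical sketch, not part of the repository.

import os

import yaml  # assumed dependency


def check_data_paths(config_path):
    # Report any JSONL shard referenced by the config that is missing on disk.
    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    missing = []
    for name, spec in cfg.get("dataset", {}).items():
        for path in spec.get("data_paths", []):
            if not os.path.exists(path):
                missing.append((name, path))

    for name, path in missing:
        print(f"[{name}] missing shard: {path}")
    return len(missing) == 0


# Example:
# check_data_paths("configs/sts_finetune_stage2.yaml")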
evaluation/compute-acc-of-contain.py ADDED
@@ -0,0 +1,85 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import json
5
+ import re
6
+ import string
7
+ import sys
8
+ import unicodedata
9
+
10
+ from word2number import w2n
11
+
12
+ # def is_list_in_string(text, candidate):
13
+ # return any([all(xx in text for xx in x.split(" ")) if isinstance(x, str) else all([xx in text for xx in x]) for x in candidate])
14
+
15
+
16
+ def is_string_in_string(text, candidate):
17
+ return all(x in text for x in candidate.split(" "))
18
+
19
+
20
+ def is_list_in_string(text, candidate):
21
+ return any(
22
+ [
23
+ is_string_in_string(text, x) if isinstance(x, str) else is_list_in_string(text, x)
24
+ for x in candidate
25
+ ]
26
+ )
27
+
28
+
29
+ def clean_punctuation(value):
30
+ punctuation = string.punctuation
31
+ punctuation = punctuation.replace("'", "")
32
+ value = re.sub(f"[{punctuation}]", " ", value)
33
+ return value
34
+
35
+
36
+ if __name__ == "__main__":
37
+
38
+ pred_gt_json_file = sys.argv[1]
39
+
40
+ with open(pred_gt_json_file, "r") as f:
41
+ pred_gt = json.load(f)
42
+
43
+ acc = 0
44
+ for line in pred_gt:
45
+
46
+ pred = line[0]
47
+ gt = line[1]
48
+
49
+ # pred = clean_punctuation(pred)
50
+ pred = pred.lower()
51
+
52
+ if isinstance(gt, list):
53
+ pass
54
+ else:
55
+ gt = [
56
+ gt,
57
+ ]
58
+ gt = [clean_punctuation(x) for x in gt]
59
+ gt = [x.lower().strip() for x in gt]
60
+
61
+ try:
62
+ gt_number = [str(w2n.word_to_num(x.lower())) for x in gt]
63
+ except Exception:
64
+ gt_number = gt
65
+ pass
66
+
67
+ if is_list_in_string(pred, gt):
68
+ acc += 1
69
+ elif is_list_in_string(pred, gt_number):
70
+ acc += 1
71
+ else:
72
+ print("======================================================")
73
+ print(f"{line[0]=}")
74
+ print(f"{line[1]=}")
75
+
76
+ print("======================================================")
77
+ print(f"{acc=}")
78
+ print(f"{len(pred_gt)=}")
79
+ print("======================================================")
80
+
81
+ acc = acc / len(pred_gt) * 100
82
+
83
+ print("======================================================")
84
+ print(f"{acc=}")
85
+ print("======================================================")
evaluation/compute-cer.py ADDED
@@ -0,0 +1,559 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import sys
5
+ import unicodedata
6
+ import codecs
7
+
8
+ remove_tag = True
9
+ spacelist = [' ', '\t', '\r', '\n']
10
+ puncts = [
11
+ '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
12
+ '《', '》'
13
+ ]
14
+
15
+
16
+ def characterize(string):
17
+ res = []
18
+ i = 0
19
+ while i < len(string):
20
+ char = string[i]
21
+ if char in puncts:
22
+ i += 1
23
+ continue
24
+ cat1 = unicodedata.category(char)
25
+ # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
26
+ if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
27
+ i += 1
28
+ continue
29
+ if cat1 == 'Lo': # letter-other
30
+ res.append(char)
31
+ i += 1
32
+ else:
33
+ # some input looks like: <unk><noise>, we want to separate it to two words.
34
+ sep = ' '
35
+ if char == '<':
36
+ sep = '>'
37
+ j = i + 1
38
+ while j < len(string):
39
+ c = string[j]
40
+ if ord(c) >= 128 or (c in spacelist) or (c == sep):
41
+ break
42
+ j += 1
43
+ if j < len(string) and string[j] == '>':
44
+ j += 1
45
+ res.append(string[i:j])
46
+ i = j
47
+ return res
48
+
49
+
50
+ def stripoff_tags(x):
51
+ if not x:
52
+ return ''
53
+ chars = []
54
+ i = 0
55
+ T = len(x)
56
+ while i < T:
57
+ if x[i] == '<':
58
+ while i < T and x[i] != '>':
59
+ i += 1
60
+ i += 1
61
+ else:
62
+ chars.append(x[i])
63
+ i += 1
64
+ return ''.join(chars)
65
+
66
+
67
+ def normalize(sentence, ignore_words, cs, split=None):
68
+ """ sentence, ignore_words are both in unicode
69
+ """
70
+ new_sentence = []
71
+ for token in sentence:
72
+ x = token
73
+ if not cs:
74
+ x = x.upper()
75
+ if x in ignore_words:
76
+ continue
77
+ if remove_tag:
78
+ x = stripoff_tags(x)
79
+ if not x:
80
+ continue
81
+ if split and x in split:
82
+ new_sentence += split[x]
83
+ elif x.isalnum():
84
+ for k in x:
85
+ new_sentence.append(k)
86
+ else:
87
+ new_sentence.append(x)
88
+ return new_sentence
89
+
90
+
91
+ class Calculator:
92
+
93
+ def __init__(self):
94
+ self.data = {}
95
+ self.space = []
96
+ self.cost = {}
97
+ self.cost['cor'] = 0
98
+ self.cost['sub'] = 1
99
+ self.cost['del'] = 1
100
+ self.cost['ins'] = 1
101
+
102
+ def calculate(self, lab, rec):
103
+ # Initialization
104
+ lab.insert(0, '')
105
+ rec.insert(0, '')
106
+ while len(self.space) < len(lab):
107
+ self.space.append([])
108
+ for row in self.space:
109
+ for element in row:
110
+ element['dist'] = 0
111
+ element['error'] = 'non'
112
+ while len(row) < len(rec):
113
+ row.append({'dist': 0, 'error': 'non'})
114
+ for i in range(len(lab)):
115
+ self.space[i][0]['dist'] = i
116
+ self.space[i][0]['error'] = 'del'
117
+ for j in range(len(rec)):
118
+ self.space[0][j]['dist'] = j
119
+ self.space[0][j]['error'] = 'ins'
120
+ self.space[0][0]['error'] = 'non'
121
+ for token in lab:
122
+ if token not in self.data and len(token) > 0:
123
+ self.data[token] = {
124
+ 'all': 0,
125
+ 'cor': 0,
126
+ 'sub': 0,
127
+ 'ins': 0,
128
+ 'del': 0
129
+ }
130
+ for token in rec:
131
+ if token not in self.data and len(token) > 0:
132
+ self.data[token] = {
133
+ 'all': 0,
134
+ 'cor': 0,
135
+ 'sub': 0,
136
+ 'ins': 0,
137
+ 'del': 0
138
+ }
139
+ # Computing edit distance
140
+ for i, lab_token in enumerate(lab):
141
+ for j, rec_token in enumerate(rec):
142
+ if i == 0 or j == 0:
143
+ continue
144
+ min_dist = sys.maxsize
145
+ min_error = 'none'
146
+ dist = self.space[i - 1][j]['dist'] + self.cost['del']
147
+ error = 'del'
148
+ if dist < min_dist:
149
+ min_dist = dist
150
+ min_error = error
151
+ dist = self.space[i][j - 1]['dist'] + self.cost['ins']
152
+ error = 'ins'
153
+ if dist < min_dist:
154
+ min_dist = dist
155
+ min_error = error
156
+ if lab_token == rec_token:
157
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
158
+ error = 'cor'
159
+ else:
160
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
161
+ error = 'sub'
162
+ if dist < min_dist:
163
+ min_dist = dist
164
+ min_error = error
165
+ self.space[i][j]['dist'] = min_dist
166
+ self.space[i][j]['error'] = min_error
167
+ # Tracing back
168
+ result = {
169
+ 'lab': [],
170
+ 'rec': [],
171
+ 'all': 0,
172
+ 'cor': 0,
173
+ 'sub': 0,
174
+ 'ins': 0,
175
+ 'del': 0
176
+ }
177
+ i = len(lab) - 1
178
+ j = len(rec) - 1
179
+ while True:
180
+ if self.space[i][j]['error'] == 'cor': # correct
181
+ if len(lab[i]) > 0:
182
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
183
+ self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
184
+ result['all'] = result['all'] + 1
185
+ result['cor'] = result['cor'] + 1
186
+ result['lab'].insert(0, lab[i])
187
+ result['rec'].insert(0, rec[j])
188
+ i = i - 1
189
+ j = j - 1
190
+ elif self.space[i][j]['error'] == 'sub': # substitution
191
+ if len(lab[i]) > 0:
192
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
193
+ self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
194
+ result['all'] = result['all'] + 1
195
+ result['sub'] = result['sub'] + 1
196
+ result['lab'].insert(0, lab[i])
197
+ result['rec'].insert(0, rec[j])
198
+ i = i - 1
199
+ j = j - 1
200
+ elif self.space[i][j]['error'] == 'del': # deletion
201
+ if len(lab[i]) > 0:
202
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
203
+ self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
204
+ result['all'] = result['all'] + 1
205
+ result['del'] = result['del'] + 1
206
+ result['lab'].insert(0, lab[i])
207
+ result['rec'].insert(0, "")
208
+ i = i - 1
209
+ elif self.space[i][j]['error'] == 'ins': # insertion
210
+ if len(rec[j]) > 0:
211
+ self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
212
+ result['ins'] = result['ins'] + 1
213
+ result['lab'].insert(0, "")
214
+ result['rec'].insert(0, rec[j])
215
+ j = j - 1
216
+ elif self.space[i][j]['error'] == 'non': # starting point
217
+ break
218
+ else: # shouldn't reach here
219
+ print('this should not happen , i={i} , j={j} , \
220
+ error={error}'.format(i=i,
221
+ j=j,
222
+ error=self.space[i][j]['error']))
223
+ return result
224
+
225
+ def overall(self):
226
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
227
+ for token in self.data:
228
+ result['all'] = result['all'] + self.data[token]['all']
229
+ result['cor'] = result['cor'] + self.data[token]['cor']
230
+ result['sub'] = result['sub'] + self.data[token]['sub']
231
+ result['ins'] = result['ins'] + self.data[token]['ins']
232
+ result['del'] = result['del'] + self.data[token]['del']
233
+ return result
234
+
235
+ def cluster(self, data):
236
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
237
+ for token in data:
238
+ if token in self.data:
239
+ result['all'] = result['all'] + self.data[token]['all']
240
+ result['cor'] = result['cor'] + self.data[token]['cor']
241
+ result['sub'] = result['sub'] + self.data[token]['sub']
242
+ result['ins'] = result['ins'] + self.data[token]['ins']
243
+ result['del'] = result['del'] + self.data[token]['del']
244
+ return result
245
+
246
+ def keys(self):
247
+ return list(self.data.keys())
248
+
249
+
250
+ def width(string):
251
+ return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
252
+
253
+
254
+ def default_cluster(word):
255
+ unicode_names = [unicodedata.name(char) for char in word]
256
+ for i in reversed(range(len(unicode_names))):
257
+ if unicode_names[i].startswith('DIGIT'): # 1
258
+ unicode_names[i] = 'Number' # 'DIGIT'
259
+ elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH')
260
+ or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
261
+ # 明 / 郎
262
+ unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
263
+ elif (unicode_names[i].startswith('LATIN CAPITAL LETTER')
264
+ or unicode_names[i].startswith('LATIN SMALL LETTER')):
265
+ # A / a
266
+ unicode_names[i] = 'English' # 'LATIN LETTER'
267
+ elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
268
+ unicode_names[i] = 'Japanese' # 'GANA LETTER'
269
+ elif (unicode_names[i].startswith('AMPERSAND')
270
+ or unicode_names[i].startswith('APOSTROPHE')
271
+ or unicode_names[i].startswith('COMMERCIAL AT')
272
+ or unicode_names[i].startswith('DEGREE CELSIUS')
273
+ or unicode_names[i].startswith('EQUALS SIGN')
274
+ or unicode_names[i].startswith('FULL STOP')
275
+ or unicode_names[i].startswith('HYPHEN-MINUS')
276
+ or unicode_names[i].startswith('LOW LINE')
277
+ or unicode_names[i].startswith('NUMBER SIGN')
278
+ or unicode_names[i].startswith('PLUS SIGN')
279
+ or unicode_names[i].startswith('SEMICOLON')):
280
+ # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
281
+ del unicode_names[i]
282
+ else:
283
+ return 'Other'
284
+ if len(unicode_names) == 0:
285
+ return 'Other'
286
+ if len(unicode_names) == 1:
287
+ return unicode_names[0]
288
+ for i in range(len(unicode_names) - 1):
289
+ if unicode_names[i] != unicode_names[i + 1]:
290
+ return 'Other'
291
+ return unicode_names[0]
292
+
293
+
294
+ def usage():
295
+ print("compute-wer.py : compute word error rate (WER) \
296
+ and align recognition results and references.")
297
+ print(" usage : python compute-wer.py [--cs={0,1}] \
298
+ [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] \
299
+ [--padding-symbol={space,underline}] test.ref test.hyp > test.cer")
300
+
301
+
302
+ if __name__ == '__main__':
303
+ if len(sys.argv) == 1:
304
+ usage()
305
+ sys.exit(0)
306
+ calculator = Calculator()
307
+ cluster_file = ''
308
+ ignore_words = set()
309
+ tochar = False
310
+ verbose = 1
311
+ padding_symbol = ' '
312
+ case_sensitive = False
313
+ max_words_per_line = sys.maxsize
314
+ split = None
315
+ while len(sys.argv) > 3:
316
+ a = '--maxw='
317
+ if sys.argv[1].startswith(a):
318
+ b = sys.argv[1][len(a):]
319
+ del sys.argv[1]
320
+ max_words_per_line = int(b)
321
+ continue
322
+ a = '--rt='
323
+ if sys.argv[1].startswith(a):
324
+ b = sys.argv[1][len(a):].lower()
325
+ del sys.argv[1]
326
+ remove_tag = (b == 'true') or (b != '0')
327
+ continue
328
+ a = '--cs='
329
+ if sys.argv[1].startswith(a):
330
+ b = sys.argv[1][len(a):].lower()
331
+ del sys.argv[1]
332
+ case_sensitive = (b == 'true') or (b != '0')
333
+ continue
334
+ a = '--cluster='
335
+ if sys.argv[1].startswith(a):
336
+ cluster_file = sys.argv[1][len(a):]
337
+ del sys.argv[1]
338
+ continue
339
+ a = '--splitfile='
340
+ if sys.argv[1].startswith(a):
341
+ split_file = sys.argv[1][len(a):]
342
+ del sys.argv[1]
343
+ split = dict()
344
+ with codecs.open(split_file, 'r', 'utf-8') as fh:
345
+ for line in fh: # line in unicode
346
+ words = line.strip().split()
347
+ if len(words) >= 2:
348
+ split[words[0]] = words[1:]
349
+ continue
350
+ a = '--ig='
351
+ if sys.argv[1].startswith(a):
352
+ ignore_file = sys.argv[1][len(a):]
353
+ del sys.argv[1]
354
+ with codecs.open(ignore_file, 'r', 'utf-8') as fh:
355
+ for line in fh: # line in unicode
356
+ line = line.strip()
357
+ if len(line) > 0:
358
+ ignore_words.add(line)
359
+ continue
360
+ a = '--char='
361
+ if sys.argv[1].startswith(a):
362
+ b = sys.argv[1][len(a):].lower()
363
+ del sys.argv[1]
364
+ tochar = (b == 'true') or (b != '0')
365
+ continue
366
+ a = '--v='
367
+ if sys.argv[1].startswith(a):
368
+ b = sys.argv[1][len(a):].lower()
369
+ del sys.argv[1]
370
+ verbose = 0
371
+ try:
372
+ verbose = int(b)
373
+ except Exception:
374
+ if b == 'true' or b != '0':
375
+ verbose = 1
376
+ continue
377
+ a = '--padding-symbol='
378
+ if sys.argv[1].startswith(a):
379
+ b = sys.argv[1][len(a):].lower()
380
+ del sys.argv[1]
381
+ if b == 'space':
382
+ padding_symbol = ' '
383
+ elif b == 'underline':
384
+ padding_symbol = '_'
385
+ continue
386
+ if True or sys.argv[1].startswith('-'):
387
+ # ignore invalid switch
388
+ del sys.argv[1]
389
+ continue
390
+
391
+ if not case_sensitive:
392
+ ig = set([w.upper() for w in ignore_words])
393
+ ignore_words = ig
394
+
395
+ default_clusters = {}
396
+ default_words = {}
397
+
398
+ ref_file = sys.argv[1]
399
+ hyp_file = sys.argv[2]
400
+ rec_set = {}
401
+ if split and not case_sensitive:
402
+ newsplit = dict()
403
+ for w in split:
404
+ words = split[w]
405
+ for i in range(len(words)):
406
+ words[i] = words[i].upper()
407
+ newsplit[w.upper()] = words
408
+ split = newsplit
409
+
410
+ with codecs.open(hyp_file, 'r', 'utf-8') as fh:
411
+ for line in fh:
412
+ if tochar:
413
+ array = characterize(line)
414
+ else:
415
+ array = line.strip().split()
416
+ if len(array) == 0:
417
+ continue
418
+ fid = array[0]
419
+ rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
420
+ split)
421
+
422
+ # compute error rate on the interaction of reference file and hyp file
423
+ for line in open(ref_file, 'r', encoding='utf-8'):
424
+ if tochar:
425
+ array = characterize(line)
426
+ else:
427
+ array = line.rstrip('\n').split()
428
+ if len(array) == 0:
429
+ continue
430
+ fid = array[0]
431
+ if fid not in rec_set:
432
+ continue
433
+ lab = normalize(array[1:], ignore_words, case_sensitive, split)
434
+ rec = rec_set[fid]
435
+ if verbose:
436
+ print('\nutt: %s' % fid)
437
+
438
+ for word in rec + lab:
439
+ if word not in default_words:
440
+ default_cluster_name = default_cluster(word)
441
+ if default_cluster_name not in default_clusters:
442
+ default_clusters[default_cluster_name] = {}
443
+ if word not in default_clusters[default_cluster_name]:
444
+ default_clusters[default_cluster_name][word] = 1
445
+ default_words[word] = default_cluster_name
446
+
447
+ result = calculator.calculate(lab, rec)
448
+ if verbose:
449
+ if result['all'] != 0:
450
+ wer = float(result['ins'] + result['sub'] +
451
+ result['del']) * 100.0 / result['all']
452
+ else:
453
+ wer = 0.0
454
+ print('WER: %4.2f %%' % wer, end=' ')
455
+ print('N=%d C=%d S=%d D=%d I=%d' %
456
+ (result['all'], result['cor'], result['sub'], result['del'],
457
+ result['ins']))
458
+ space = {}
459
+ space['lab'] = []
460
+ space['rec'] = []
461
+ for idx in range(len(result['lab'])):
462
+ len_lab = width(result['lab'][idx])
463
+ len_rec = width(result['rec'][idx])
464
+ length = max(len_lab, len_rec)
465
+ space['lab'].append(length - len_lab)
466
+ space['rec'].append(length - len_rec)
467
+ upper_lab = len(result['lab'])
468
+ upper_rec = len(result['rec'])
469
+ lab1, rec1 = 0, 0
470
+ while lab1 < upper_lab or rec1 < upper_rec:
471
+ if verbose > 1:
472
+ print('lab(%s):' % fid.encode('utf-8'), end=' ')
473
+ else:
474
+ print('lab:', end=' ')
475
+ lab2 = min(upper_lab, lab1 + max_words_per_line)
476
+ for idx in range(lab1, lab2):
477
+ token = result['lab'][idx]
478
+ print('{token}'.format(token=token), end='')
479
+ for n in range(space['lab'][idx]):
480
+ print(padding_symbol, end='')
481
+ print(' ', end='')
482
+ print()
483
+ if verbose > 1:
484
+ print('rec(%s):' % fid.encode('utf-8'), end=' ')
485
+ else:
486
+ print('rec:', end=' ')
487
+ rec2 = min(upper_rec, rec1 + max_words_per_line)
488
+ for idx in range(rec1, rec2):
489
+ token = result['rec'][idx]
490
+ print('{token}'.format(token=token), end='')
491
+ for n in range(space['rec'][idx]):
492
+ print(padding_symbol, end='')
493
+ print(' ', end='')
494
+ print('\n', end='\n')
495
+ lab1 = lab2
496
+ rec1 = rec2
497
+
498
+ if verbose:
499
+ print('==================================================='
500
+ '========================')
501
+ print()
502
+
503
+ result = calculator.overall()
504
+ if result['all'] != 0:
505
+ wer = float(result['ins'] + result['sub'] +
506
+ result['del']) * 100.0 / result['all']
507
+ else:
508
+ wer = 0.0
509
+ print('Overall -> %4.2f %%' % wer, end=' ')
510
+ print('N=%d C=%d S=%d D=%d I=%d' %
511
+ (result['all'], result['cor'], result['sub'], result['del'],
512
+ result['ins']))
513
+ if not verbose:
514
+ print()
515
+
516
+ if verbose:
517
+ for cluster_id in default_clusters:
518
+ result = calculator.cluster(k
519
+ for k in default_clusters[cluster_id])
520
+ if result['all'] != 0:
521
+ wer = float(result['ins'] + result['sub'] +
522
+ result['del']) * 100.0 / result['all']
523
+ else:
524
+ wer = 0.0
525
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
526
+ print('N=%d C=%d S=%d D=%d I=%d' %
527
+ (result['all'], result['cor'], result['sub'], result['del'],
528
+ result['ins']))
529
+ if len(cluster_file) > 0: # compute separated WERs for word clusters
530
+ cluster_id = ''
531
+ cluster = []
532
+ for line in open(cluster_file, 'r', encoding='utf-8'):
533
+ for token in line.rstrip('\n').split():
534
+ # end of cluster reached, like </Keyword>
535
+ if token[0:2] == '</' and token[len(token) - 1] == '>' and \
536
+ token.lstrip('</').rstrip('>') == cluster_id :
537
+ result = calculator.cluster(cluster)
538
+ if result['all'] != 0:
539
+ wer = float(result['ins'] + result['sub'] +
540
+ result['del']) * 100.0 / result['all']
541
+ else:
542
+ wer = 0.0
543
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
544
+ print('N=%d C=%d S=%d D=%d I=%d' %
545
+ (result['all'], result['cor'], result['sub'],
546
+ result['del'], result['ins']))
547
+ cluster_id = ''
548
+ cluster = []
549
+ # begin of cluster reached, like <Keyword>
550
+ elif (token[0] == '<' and token[len(token) - 1] == '>'
551
+ and cluster_id == ''):
552
+ cluster_id = token.lstrip('<').rstrip('>')
553
+ cluster = []
554
+ # general terms, like WEATHER / CAR / ...
555
+ else:
556
+ cluster.append(token)
557
+ print()
558
+ print('======================================='
559
+ '====================================')
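The Calculator class above performs a standard dynamic-programming alignment with unit costs for substitutions, deletions, and insertions, and the reported rate is (S + D + I) / N x 100 with N the number of reference characters. A hand-worked example of that formula (strings invented for illustration):

# ref: 今天天气不错  (6 characters)
# hyp: 今天气很不错  (one optimal alignment: delete the second 天, insert 很)
N, S, D, I = 6, 0, 1, 1
cer = (S + D + I) * 100.0 / N
print(f"CER = {cer:.2f}%")  # CER = 33.33%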
evaluation/compute-wer.py ADDED
@@ -0,0 +1,553 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import re, sys, unicodedata
5
+ import codecs
6
+
7
+ remove_tag = True
8
+ spacelist = [' ', '\t', '\r', '\n']
9
+ puncts = [
10
+ '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
11
+ '《', '》'
12
+ ]
13
+
14
+
15
+ def characterize(string):
16
+ res = []
17
+ i = 0
18
+ while i < len(string):
19
+ char = string[i]
20
+ if char in puncts:
21
+ i += 1
22
+ continue
23
+ cat1 = unicodedata.category(char)
24
+ #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
25
+ if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
26
+ i += 1
27
+ continue
28
+ if cat1 == 'Lo': # letter-other
29
+ res.append(char)
30
+ i += 1
31
+ else:
32
+ # some input looks like: <unk><noise>, we want to separate it to two words.
33
+ sep = ' '
34
+ if char == '<': sep = '>'
35
+ j = i + 1
36
+ while j < len(string):
37
+ c = string[j]
38
+ if ord(c) >= 128 or (c in spacelist) or (c == sep):
39
+ break
40
+ j += 1
41
+ if j < len(string) and string[j] == '>':
42
+ j += 1
43
+ res.append(string[i:j])
44
+ i = j
45
+ return res
46
+
47
+
48
+ def stripoff_tags(x):
49
+ if not x: return ''
50
+ chars = []
51
+ i = 0
52
+ T = len(x)
53
+ while i < T:
54
+ if x[i] == '<':
55
+ while i < T and x[i] != '>':
56
+ i += 1
57
+ i += 1
58
+ else:
59
+ chars.append(x[i])
60
+ i += 1
61
+ return ''.join(chars)
62
+
63
+
64
+ def normalize(sentence, ignore_words, cs, split=None):
65
+ """ sentence, ignore_words are both in unicode
66
+ """
67
+ new_sentence = []
68
+ for token in sentence:
69
+ x = token
70
+ if not cs:
71
+ x = x.upper()
72
+ if x in ignore_words:
73
+ continue
74
+ if remove_tag:
75
+ x = stripoff_tags(x)
76
+ if not x:
77
+ continue
78
+ if split and x in split:
79
+ new_sentence += split[x]
80
+ else:
81
+ new_sentence.append(x)
82
+ return new_sentence
83
+
84
+
85
+ class Calculator:
86
+
87
+ def __init__(self):
88
+ self.data = {}
89
+ self.space = []
90
+ self.cost = {}
91
+ self.cost['cor'] = 0
92
+ self.cost['sub'] = 1
93
+ self.cost['del'] = 1
94
+ self.cost['ins'] = 1
95
+
96
+ def calculate(self, lab, rec):
97
+ # Initialization
98
+ lab.insert(0, '')
99
+ rec.insert(0, '')
100
+ while len(self.space) < len(lab):
101
+ self.space.append([])
102
+ for row in self.space:
103
+ for element in row:
104
+ element['dist'] = 0
105
+ element['error'] = 'non'
106
+ while len(row) < len(rec):
107
+ row.append({'dist': 0, 'error': 'non'})
108
+ for i in range(len(lab)):
109
+ self.space[i][0]['dist'] = i
110
+ self.space[i][0]['error'] = 'del'
111
+ for j in range(len(rec)):
112
+ self.space[0][j]['dist'] = j
113
+ self.space[0][j]['error'] = 'ins'
114
+ self.space[0][0]['error'] = 'non'
115
+ for token in lab:
116
+ if token not in self.data and len(token) > 0:
117
+ self.data[token] = {
118
+ 'all': 0,
119
+ 'cor': 0,
120
+ 'sub': 0,
121
+ 'ins': 0,
122
+ 'del': 0
123
+ }
124
+ for token in rec:
125
+ if token not in self.data and len(token) > 0:
126
+ self.data[token] = {
127
+ 'all': 0,
128
+ 'cor': 0,
129
+ 'sub': 0,
130
+ 'ins': 0,
131
+ 'del': 0
132
+ }
133
+ # Computing edit distance
134
+ for i, lab_token in enumerate(lab):
135
+ for j, rec_token in enumerate(rec):
136
+ if i == 0 or j == 0:
137
+ continue
138
+ min_dist = sys.maxsize
139
+ min_error = 'none'
140
+ dist = self.space[i - 1][j]['dist'] + self.cost['del']
141
+ error = 'del'
142
+ if dist < min_dist:
143
+ min_dist = dist
144
+ min_error = error
145
+ dist = self.space[i][j - 1]['dist'] + self.cost['ins']
146
+ error = 'ins'
147
+ if dist < min_dist:
148
+ min_dist = dist
149
+ min_error = error
150
+ if lab_token == rec_token:
151
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
152
+ error = 'cor'
153
+ else:
154
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
155
+ error = 'sub'
156
+ if dist < min_dist:
157
+ min_dist = dist
158
+ min_error = error
159
+ self.space[i][j]['dist'] = min_dist
160
+ self.space[i][j]['error'] = min_error
161
+ # Tracing back
162
+ result = {
163
+ 'lab': [],
164
+ 'rec': [],
165
+ 'all': 0,
166
+ 'cor': 0,
167
+ 'sub': 0,
168
+ 'ins': 0,
169
+ 'del': 0
170
+ }
171
+ i = len(lab) - 1
172
+ j = len(rec) - 1
173
+ while True:
174
+ if self.space[i][j]['error'] == 'cor': # correct
175
+ if len(lab[i]) > 0:
176
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
177
+ self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
178
+ result['all'] = result['all'] + 1
179
+ result['cor'] = result['cor'] + 1
180
+ result['lab'].insert(0, lab[i])
181
+ result['rec'].insert(0, rec[j])
182
+ i = i - 1
183
+ j = j - 1
184
+ elif self.space[i][j]['error'] == 'sub': # substitution
185
+ if len(lab[i]) > 0:
186
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
187
+ self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
188
+ result['all'] = result['all'] + 1
189
+ result['sub'] = result['sub'] + 1
190
+ result['lab'].insert(0, lab[i])
191
+ result['rec'].insert(0, rec[j])
192
+ i = i - 1
193
+ j = j - 1
194
+ elif self.space[i][j]['error'] == 'del': # deletion
195
+ if len(lab[i]) > 0:
196
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
197
+ self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
198
+ result['all'] = result['all'] + 1
199
+ result['del'] = result['del'] + 1
200
+ result['lab'].insert(0, lab[i])
201
+ result['rec'].insert(0, "")
202
+ i = i - 1
203
+ elif self.space[i][j]['error'] == 'ins': # insertion
204
+ if len(rec[j]) > 0:
205
+ self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
206
+ result['ins'] = result['ins'] + 1
207
+ result['lab'].insert(0, "")
208
+ result['rec'].insert(0, rec[j])
209
+ j = j - 1
210
+ elif self.space[i][j]['error'] == 'non': # starting point
211
+ break
212
+ else: # shouldn't reach here
213
+ print(
214
+ 'this should not happen , i = {i} , j = {j} , error = {error}'
215
+ .format(i=i, j=j, error=self.space[i][j]['error']))
216
+ return result
217
+
218
+ def overall(self):
219
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
220
+ for token in self.data:
221
+ result['all'] = result['all'] + self.data[token]['all']
222
+ result['cor'] = result['cor'] + self.data[token]['cor']
223
+ result['sub'] = result['sub'] + self.data[token]['sub']
224
+ result['ins'] = result['ins'] + self.data[token]['ins']
225
+ result['del'] = result['del'] + self.data[token]['del']
226
+ return result
227
+
228
+ def cluster(self, data):
229
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
230
+ for token in data:
231
+ if token in self.data:
232
+ result['all'] = result['all'] + self.data[token]['all']
233
+ result['cor'] = result['cor'] + self.data[token]['cor']
234
+ result['sub'] = result['sub'] + self.data[token]['sub']
235
+ result['ins'] = result['ins'] + self.data[token]['ins']
236
+ result['del'] = result['del'] + self.data[token]['del']
237
+ return result
238
+
239
+ def keys(self):
240
+ return list(self.data.keys())
241
+
242
+
243
+ def width(string):
244
+ return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
245
+
246
+
247
+ def default_cluster(word):
248
+ unicode_names = [unicodedata.name(char) for char in word]
249
+ for i in reversed(range(len(unicode_names))):
250
+ if unicode_names[i].startswith('DIGIT'): # 1
251
+ unicode_names[i] = 'Number' # 'DIGIT'
252
+ elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH')
253
+ or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
254
+ # 明 / 郎
255
+ unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
256
+ elif (unicode_names[i].startswith('LATIN CAPITAL LETTER')
257
+ or unicode_names[i].startswith('LATIN SMALL LETTER')):
258
+ # A / a
259
+ unicode_names[i] = 'English' # 'LATIN LETTER'
260
+ elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
261
+ unicode_names[i] = 'Japanese' # 'GANA LETTER'
262
+ elif (unicode_names[i].startswith('AMPERSAND')
263
+ or unicode_names[i].startswith('APOSTROPHE')
264
+ or unicode_names[i].startswith('COMMERCIAL AT')
265
+ or unicode_names[i].startswith('DEGREE CELSIUS')
266
+ or unicode_names[i].startswith('EQUALS SIGN')
267
+ or unicode_names[i].startswith('FULL STOP')
268
+ or unicode_names[i].startswith('HYPHEN-MINUS')
269
+ or unicode_names[i].startswith('LOW LINE')
270
+ or unicode_names[i].startswith('NUMBER SIGN')
271
+ or unicode_names[i].startswith('PLUS SIGN')
272
+ or unicode_names[i].startswith('SEMICOLON')):
273
+ # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
274
+ del unicode_names[i]
275
+ else:
276
+ return 'Other'
277
+ if len(unicode_names) == 0:
278
+ return 'Other'
279
+ if len(unicode_names) == 1:
280
+ return unicode_names[0]
281
+ for i in range(len(unicode_names) - 1):
282
+ if unicode_names[i] != unicode_names[i + 1]:
283
+ return 'Other'
284
+ return unicode_names[0]
285
+
286
+
287
+ def usage():
288
+ print(
289
+ "compute-wer.py : compute word error rate (WER) and align recognition results and references."
290
+ )
291
+ print(
292
+ " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
293
+ )
294
+
295
+
296
+ if __name__ == '__main__':
297
+ if len(sys.argv) == 1:
298
+ usage()
299
+ sys.exit(0)
300
+ calculator = Calculator()
301
+ cluster_file = ''
302
+ ignore_words = set()
303
+ tochar = False
304
+ verbose = 1
305
+ padding_symbol = ' '
306
+ case_sensitive = False
307
+ max_words_per_line = sys.maxsize
308
+ split = None
309
+ while len(sys.argv) > 3:
310
+ a = '--maxw='
311
+ if sys.argv[1].startswith(a):
312
+ b = sys.argv[1][len(a):]
313
+ del sys.argv[1]
314
+ max_words_per_line = int(b)
315
+ continue
316
+ a = '--rt='
317
+ if sys.argv[1].startswith(a):
318
+ b = sys.argv[1][len(a):].lower()
319
+ del sys.argv[1]
320
+ remove_tag = (b == 'true') or (b != '0')
321
+ continue
322
+ a = '--cs='
323
+ if sys.argv[1].startswith(a):
324
+ b = sys.argv[1][len(a):].lower()
325
+ del sys.argv[1]
326
+ case_sensitive = (b == 'true') or (b != '0')
327
+ continue
328
+ a = '--cluster='
329
+ if sys.argv[1].startswith(a):
330
+ cluster_file = sys.argv[1][len(a):]
331
+ del sys.argv[1]
332
+ continue
333
+ a = '--splitfile='
334
+ if sys.argv[1].startswith(a):
335
+ split_file = sys.argv[1][len(a):]
336
+ del sys.argv[1]
337
+ split = dict()
338
+ with codecs.open(split_file, 'r', 'utf-8') as fh:
339
+ for line in fh: # line in unicode
340
+ words = line.strip().split()
341
+ if len(words) >= 2:
342
+ split[words[0]] = words[1:]
343
+ continue
344
+ a = '--ig='
345
+ if sys.argv[1].startswith(a):
346
+ ignore_file = sys.argv[1][len(a):]
347
+ del sys.argv[1]
348
+ with codecs.open(ignore_file, 'r', 'utf-8') as fh:
349
+ for line in fh: # line in unicode
350
+ line = line.strip()
351
+ if len(line) > 0:
352
+ ignore_words.add(line)
353
+ continue
354
+ a = '--char='
355
+ if sys.argv[1].startswith(a):
356
+ b = sys.argv[1][len(a):].lower()
357
+ del sys.argv[1]
358
+ tochar = (b == 'true') or (b != '0')
359
+ continue
360
+ a = '--v='
361
+ if sys.argv[1].startswith(a):
362
+ b = sys.argv[1][len(a):].lower()
363
+ del sys.argv[1]
364
+ verbose = 0
365
+ try:
366
+ verbose = int(b)
367
+ except Exception:
368
+ if b == 'true' or b != '0':
369
+ verbose = 1
370
+ continue
371
+ a = '--padding-symbol='
372
+ if sys.argv[1].startswith(a):
373
+ b = sys.argv[1][len(a):].lower()
374
+ del sys.argv[1]
375
+ if b == 'space':
376
+ padding_symbol = ' '
377
+ elif b == 'underline':
378
+ padding_symbol = '_'
379
+ continue
380
+ if True or sys.argv[1].startswith('-'):
381
+ #ignore invalid switch
382
+ del sys.argv[1]
383
+ continue
384
+
385
+ if not case_sensitive:
386
+ ig = set([w.upper() for w in ignore_words])
387
+ ignore_words = ig
388
+
389
+ default_clusters = {}
390
+ default_words = {}
391
+
392
+ ref_file = sys.argv[1]
393
+ hyp_file = sys.argv[2]
394
+ rec_set = {}
395
+ if split and not case_sensitive:
396
+ newsplit = dict()
397
+ for w in split:
398
+ words = split[w]
399
+ for i in range(len(words)):
400
+ words[i] = words[i].upper()
401
+ newsplit[w.upper()] = words
402
+ split = newsplit
403
+
404
+ with codecs.open(hyp_file, 'r', 'utf-8') as fh:
405
+ for line in fh:
406
+ if tochar:
407
+ array = characterize(line)
408
+ else:
409
+ array = line.strip().split()
410
+ if len(array) == 0: continue
411
+ fid = array[0]
412
+ rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
413
+ split)
414
+
415
+ # compute error rate on the interaction of reference file and hyp file
416
+ for line in open(ref_file, 'r', encoding='utf-8'):
417
+ if tochar:
418
+ array = characterize(line)
419
+ else:
420
+ array = line.rstrip('\n').split()
421
+ if len(array) == 0: continue
422
+ fid = array[0]
423
+ if fid not in rec_set:
424
+ continue
425
+ lab = normalize(array[1:], ignore_words, case_sensitive, split)
426
+ rec = rec_set[fid]
427
+ if verbose:
428
+ print('\nutt: %s' % fid)
429
+
430
+ for word in rec + lab:
431
+ if word not in default_words:
432
+ default_cluster_name = default_cluster(word)
433
+ if default_cluster_name not in default_clusters:
434
+ default_clusters[default_cluster_name] = {}
435
+ if word not in default_clusters[default_cluster_name]:
436
+ default_clusters[default_cluster_name][word] = 1
437
+ default_words[word] = default_cluster_name
438
+
439
+ result = calculator.calculate(lab, rec)
440
+ if verbose:
441
+ if result['all'] != 0:
442
+ wer = float(result['ins'] + result['sub'] +
443
+ result['del']) * 100.0 / result['all']
444
+ else:
445
+ wer = 0.0
446
+ print('WER: %4.2f %%' % wer, end=' ')
447
+ print('N=%d C=%d S=%d D=%d I=%d' %
448
+ (result['all'], result['cor'], result['sub'], result['del'],
449
+ result['ins']))
450
+ space = {}
451
+ space['lab'] = []
452
+ space['rec'] = []
453
+ for idx in range(len(result['lab'])):
454
+ len_lab = width(result['lab'][idx])
455
+ len_rec = width(result['rec'][idx])
456
+ length = max(len_lab, len_rec)
457
+ space['lab'].append(length - len_lab)
458
+ space['rec'].append(length - len_rec)
459
+ upper_lab = len(result['lab'])
460
+ upper_rec = len(result['rec'])
461
+ lab1, rec1 = 0, 0
462
+ while lab1 < upper_lab or rec1 < upper_rec:
463
+ if verbose > 1:
464
+ print('lab(%s):' % fid.encode('utf-8'), end=' ')
465
+ else:
466
+ print('lab:', end=' ')
467
+ lab2 = min(upper_lab, lab1 + max_words_per_line)
468
+ for idx in range(lab1, lab2):
469
+ token = result['lab'][idx]
470
+ print('{token}'.format(token=token), end='')
471
+ for n in range(space['lab'][idx]):
472
+ print(padding_symbol, end='')
473
+ print(' ', end='')
474
+ print()
475
+ if verbose > 1:
476
+ print('rec(%s):' % fid.encode('utf-8'), end=' ')
477
+ else:
478
+ print('rec:', end=' ')
479
+ rec2 = min(upper_rec, rec1 + max_words_per_line)
480
+ for idx in range(rec1, rec2):
481
+ token = result['rec'][idx]
482
+ print('{token}'.format(token=token), end='')
483
+ for n in range(space['rec'][idx]):
484
+ print(padding_symbol, end='')
485
+ print(' ', end='')
486
+ print('\n', end='\n')
487
+ lab1 = lab2
488
+ rec1 = rec2
489
+
490
+ if verbose:
491
+ print(
492
+ '==========================================================================='
493
+ )
494
+ print()
495
+
496
+ result = calculator.overall()
497
+ if result['all'] != 0:
498
+ wer = float(result['ins'] + result['sub'] +
499
+ result['del']) * 100.0 / result['all']
500
+ else:
501
+ wer = 0.0
502
+ print('Overall -> %4.2f %%' % wer, end=' ')
503
+ print('N=%d C=%d S=%d D=%d I=%d' %
504
+ (result['all'], result['cor'], result['sub'], result['del'],
505
+ result['ins']))
506
+ if not verbose:
507
+ print()
508
+
509
+ if verbose:
510
+ for cluster_id in default_clusters:
511
+ result = calculator.cluster(
512
+ [k for k in default_clusters[cluster_id]])
513
+ if result['all'] != 0:
514
+ wer = float(result['ins'] + result['sub'] +
515
+ result['del']) * 100.0 / result['all']
516
+ else:
517
+ wer = 0.0
518
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
519
+ print('N=%d C=%d S=%d D=%d I=%d' %
520
+ (result['all'], result['cor'], result['sub'], result['del'],
521
+ result['ins']))
522
+ if len(cluster_file) > 0: # compute separated WERs for word clusters
523
+ cluster_id = ''
524
+ cluster = []
525
+ for line in open(cluster_file, 'r', encoding='utf-8'):
526
+ for token in line.rstrip('\n').split():
527
+ # end of cluster reached, like </Keyword>
528
+ if token[0:2] == '</' and token[len(token)-1] == '>' and \
529
+ token.lstrip('</').rstrip('>') == cluster_id :
530
+ result = calculator.cluster(cluster)
531
+ if result['all'] != 0:
532
+ wer = float(result['ins'] + result['sub'] +
533
+ result['del']) * 100.0 / result['all']
534
+ else:
535
+ wer = 0.0
536
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
537
+ print('N=%d C=%d S=%d D=%d I=%d' %
538
+ (result['all'], result['cor'], result['sub'],
539
+ result['del'], result['ins']))
540
+ cluster_id = ''
541
+ cluster = []
542
+ # begin of cluster reached, like <Keyword>
543
+ elif token[0] == '<' and token[len(token)-1] == '>' and \
544
+ cluster_id == '' :
545
+ cluster_id = token.lstrip('<').rstrip('>')
546
+ cluster = []
547
+ # general terms, like WEATHER / CAR / ...
548
+ else:
549
+ cluster.append(token)
550
+ print()
551
+ print(
552
+ '==========================================================================='
553
+ )
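Both error-rate scripts read Kaldi-style transcript files in which each line starts with an utterance id followed by the tokens, and only utterance ids present in both the reference and the hypothesis file are scored. A hypothetical helper for producing such files (names and contents are illustrative):

def write_trn(path, utterances):
    # utterances: dict mapping utterance id -> transcript string
    with open(path, "w", encoding="utf-8") as f:
        for utt_id, text in utterances.items():
            f.write(f"{utt_id} {text}\n")


write_trn("test.ref", {"utt1": "good morning every one"})
write_trn("test.hyp", {"utt1": "good morning everyone"})
# then, e.g.: python evaluation/compute-wer.py --v=1 test.ref test.hyp > test.wer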
evaluation/evaluate_asr.py ADDED
@@ -0,0 +1,379 @@
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import sys
7
+ import uuid
8
+ from datetime import timedelta
9
+ from functools import partial
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ import tqdm
14
+ from datasets import load_dataset
15
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
16
+ from transformers.generation import GenerationConfig
17
+
18
+ import torchaudio
19
+ from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
20
+ from vita_audio.tokenizer import get_audio_tokenizer
21
+
22
+
23
+ def collate_fn(batches):
24
+ input_ids = [sample["input_ids"] for sample in batches]
25
+ audios = [sample["audios"] for sample in batches]
26
+ audio_indices = [sample["audio_indices"] for sample in batches]
27
+
28
+ refs = [sample["ref"] for sample in batches]
29
+
30
+ return input_ids, audios, audio_indices, refs
31
+
32
+
33
+ class ASRDataset(torch.utils.data.Dataset):
34
+ def __init__(
35
+ self,
36
+ json_path,
37
+ tokenizer,
38
+ audio_tokenizer,
39
+ default_system_message=None,
40
+ add_generation_prompt=True,
41
+ ):
42
+ data = load_dataset("json", data_files=json_path, keep_in_memory=False)
43
+ self.data = data["train"]
44
+
45
+ self.tokenizer = tokenizer
46
+ self.add_generation_prompt = add_generation_prompt
47
+
48
+ self.audio_tokenizer = audio_tokenizer
49
+ self.default_system_message = default_system_message
50
+
51
+ def __len__(self):
52
+ return len(self.data)
53
+
54
+ def __getitem__(self, idx):
55
+ sample = self.data[idx]
56
+ # print(f"sample {sample}")
57
+
58
+ audio_path = sample["audios"][0]
59
+
60
+ if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
61
+ # discrete codec
62
+ audio_tokens = self.audio_tokenizer.encode(audio_path)
63
+ audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
64
+ else:
65
+ audio_tokens = None
66
+
67
+ messages = []
68
+ if len(sample["messages"]) == 2:
69
+ assert len(sample["messages"]) == 2
70
+ assert sample["messages"][0]["role"] == "user"
71
+ assert sample["messages"][1]["role"] == "assistant"
72
+
73
+ if self.default_system_message is not None:
74
+ messages = self.default_system_message + messages
75
+
76
+ elif len(sample["messages"]) == 3:
77
+ assert len(sample["messages"]) == 3
78
+ assert sample["messages"][0]["role"] == "system"
79
+ assert sample["messages"][1]["role"] == "user"
80
+ assert sample["messages"][2]["role"] == "assistant"
81
+
82
+ else:
83
+ raise NotImplementedError
84
+
85
+ # print(sample)
86
+ for conv in sample["messages"][:-1]:
87
+ new_conv = {}
88
+ new_conv["role"] = conv["role"]
89
+
90
+ content = conv["content"]
91
+
92
+ if audio_tokens is not None:
93
+ content = content.replace(
94
+ "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
95
+ )
96
+
97
+ new_conv["content"] = content
98
+ messages.append(new_conv)
99
+
100
+ input_ids = self.tokenizer.apply_chat_template(
101
+ messages,
102
+ tokenize=True,
103
+ add_generation_prompt=self.add_generation_prompt,
104
+ # return_tensors="pt",
105
+ )
106
+
107
+ ref = sample["messages"][-1]["content"]
108
+
109
+ if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
110
+ # contiguous codec
111
+ input_ids, audios, audio_indices = add_audio_input_contiguous(
112
+ input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
113
+ )
114
+ else:
115
+ audios = None
116
+ audio_indices = None
117
+
118
+ input_ids = torch.tensor([input_ids], dtype=torch.long)
119
+ return {
120
+ "input_ids": input_ids,
121
+ "audios": audios,
122
+ "audio_indices": audio_indices,
123
+ "ref": ref,
124
+ }
125
+
126
+
127
+ class InferenceSampler(torch.utils.data.sampler.Sampler):
128
+ def __init__(self, size):
129
+ self._size = int(size)
130
+ assert size > 0
131
+ self._rank = torch.distributed.get_rank()
132
+ self._world_size = torch.distributed.get_world_size()
133
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
134
+
135
+ @staticmethod
136
+ def _get_local_indices(total_size, world_size, rank):
137
+ shard_size = total_size // world_size
138
+ left = total_size % world_size
139
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
140
+
141
+ begin = sum(shard_sizes[:rank])
142
+ end = min(sum(shard_sizes[: rank + 1]), total_size)
143
+ return range(begin, end)
144
+
145
+ def __iter__(self):
146
+ yield from self._local_indices
147
+
148
+ def __len__(self):
149
+ return len(self._local_indices)
150
+
151
+
152
+ def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
153
+
154
+ audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
155
+
156
+ outputs = []
157
+
158
+ for _, (batched_input_ids, batched_audios, batched_audio_indices, batched_ref) in enumerate(
159
+ tqdm.tqdm(dataloader)
160
+ ):
161
+
162
+ for input_ids, audios, audio_indices, ref in zip(
163
+ batched_input_ids, batched_audios, batched_audio_indices, batched_ref
164
+ ):
165
+ kwargs = {
166
+ # "temperature": 0.2,
167
+ # "top_p": 0.8,
168
+ # "do_sample": False,
169
+ # "temperature": 1.0,
170
+ "max_new_tokens": max([len(x) for x in batched_ref]) + 10,
171
+ "min_new_tokens": 1,
172
+ }
173
+ if audios is not None:
174
+ kwargs["audios"] = audios
175
+ kwargs["audio_indices"] = audio_indices
176
+
177
+ responses = model.generate(
178
+ input_ids=input_ids.cuda(),
179
+ **kwargs,
180
+ )
181
+
182
+ response = responses[0][len(input_ids[0]) :]
183
+
184
+ text_tokens = []
185
+ audio_tokens = []
186
+ for token_id in response:
187
+ if token_id >= audio_offset:
188
+ audio_tokens.append(token_id - audio_offset)
189
+ else:
190
+ text_tokens.append(token_id)
191
+
192
+ hyp = tokenizer.decode(text_tokens, skip_special_tokens=True)
193
+
194
+ outputs.append((hyp, ref))
195
+
196
+ print("")
197
+ print("=" * 100)
198
+ print(f"{hyp=}")
199
+ print(f"{ref=}")
200
+
201
+ return outputs
202
+
203
+
204
+ if __name__ == "__main__":
205
+ parser = argparse.ArgumentParser(
206
+ description="",
207
+ formatter_class=argparse.RawDescriptionHelpFormatter,
208
+ )
209
+
210
+ parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
211
+ parser.add_argument(
212
+ "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
213
+ )
214
+ parser.add_argument(
215
+ "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
216
+ )
217
+ parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
218
+
219
+ parser.add_argument("--json_path", type=str, required=True, help="json_path")
220
+ parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
221
+
222
+ parser.add_argument("--batch_size", type=int, default=1)
223
+ parser.add_argument("--num_workers", type=int, default=0)
224
+
225
+ args = parser.parse_args()
226
+
227
+ print(f"{args=}")
228
+
229
+ torch.distributed.init_process_group(
230
+ backend="nccl",
231
+ world_size=int(os.getenv("WORLD_SIZE", "1")),
232
+ rank=int(os.getenv("RANK", "0")),
233
+ timeout=timedelta(seconds=7200),
234
+ )
235
+
236
+ torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
237
+
238
+ random.seed(42)
239
+ torch.manual_seed(42)
240
+
241
+ config = AutoConfig.from_pretrained(
242
+ args.model_name_or_path,
243
+ trust_remote_code=True,
244
+ )
245
+
246
+ # ================================================================
247
+ if "glm" in config.model_type.lower():
248
+ from get_chat_template import glm4_chat_template as chat_template
249
+
250
+ add_generation_prompt = True
251
+
252
+ default_system_message = [
253
+ {
254
+ "role": "system",
255
+ "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
256
+ }
257
+ ]
258
+
259
+ if "qwen2" in config.model_type.lower():
260
+ from get_chat_template import qwen2_chat_template as chat_template
261
+
262
+ add_generation_prompt = True
263
+
264
+ default_system_message = []
265
+
266
+ if "hunyuan" in config.model_type.lower():
267
+ from get_chat_template import hunyuan_chat_template as chat_template
268
+
269
+ add_generation_prompt = False
270
+
271
+ default_system_message = [
272
+ {
273
+ "role": "system",
274
+ "content": "You are a helpful AI assistant.",
275
+ }
276
+ ]
277
+
278
+ # ================================================================
279
+ print("Loading model")
280
+ # device_map = "auto"
281
+ device_map = "cuda"
282
+ # torch_dtype=torch.float16
283
+ torch_dtype = torch.bfloat16
284
+
285
+ rank = torch.distributed.get_rank()
286
+
287
+ audio_tokenizer = get_audio_tokenizer(
288
+ args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
289
+ )
290
+
291
+ tokenizer = AutoTokenizer.from_pretrained(
292
+ args.model_name_or_path,
293
+ trust_remote_code=True,
294
+ chat_template=chat_template,
295
+ )
296
+ # print("tokenizer", tokenizer)
297
+
298
+ model = AutoModelForCausalLM.from_pretrained(
299
+ args.model_name_or_path,
300
+ trust_remote_code=True,
301
+ device_map=device_map,
302
+ torch_dtype=torch_dtype,
303
+ attn_implementation="flash_attention_2",
304
+ ).eval()
305
+ # print("model", model)
306
+
307
+ model.generation_config = GenerationConfig.from_pretrained(
308
+ args.model_name_or_path, trust_remote_code=True
309
+ )
310
+
311
+ model.generation_config.max_new_tokens = 4096
312
+ model.generation_config.chat_format = "chatml"
313
+ model.generation_config.max_window_size = 8192
314
+ model.generation_config.use_cache = True
315
+ model.generation_config.do_sample = False
316
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
317
+ if model.config.model_type == "hunyuan":
318
+ model.generation_config.eos_token_id = tokenizer.eos_id
319
+
320
+ # ================================================================
321
+ print("Loading data")
322
+ dataset = ASRDataset(
323
+ json_path=args.json_path,
324
+ tokenizer=tokenizer,
325
+ audio_tokenizer=audio_tokenizer,
326
+ default_system_message=default_system_message,
327
+ add_generation_prompt=add_generation_prompt,
328
+ )
329
+
330
+ dataloader = torch.utils.data.DataLoader(
331
+ dataset=dataset,
332
+ sampler=InferenceSampler(len(dataset)),
333
+ batch_size=args.batch_size,
334
+ num_workers=args.num_workers,
335
+ pin_memory=True,
336
+ drop_last=False,
337
+ collate_fn=partial(
338
+ collate_fn,
339
+ ),
340
+ )
341
+
342
+ # ================================================================
343
+ outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)
344
+
345
+ torch.distributed.barrier()
346
+
347
+ world_size = torch.distributed.get_world_size()
348
+ merged_outputs = [None for _ in range(world_size)]
349
+ torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
350
+
351
+ merged_outputs = [json.loads(_) for _ in merged_outputs]
352
+ merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
353
+
354
+ if torch.distributed.get_rank() == 0:
355
+ # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
356
+ json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
357
+ hyp_path = os.path.join(args.output_dir, f"{json_name}_hyp.txt")
358
+ ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")
359
+
360
+ os.makedirs(os.path.dirname(ref_path), exist_ok=True)
361
+ os.makedirs(os.path.dirname(hyp_path), exist_ok=True)
362
+
363
+ hyp_file = open(hyp_path, "w")
364
+ ref_file = open(ref_path, "w")
365
+
366
+ for sample_idx, (hyp, ref) in enumerate(merged_outputs):
367
+ hyp_file.write(f"{sample_idx} {hyp}" + "\n")
368
+ ref_file.write(f"{sample_idx} {ref}" + "\n")
369
+
370
+ hyp_file.close()
371
+ ref_file.close()
372
+
373
+ hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref.json")
374
+ hyp_ref_file = open(hyp_ref_path, "w")
375
+ json.dump(merged_outputs, hyp_ref_file, indent=4)
376
+ hyp_ref_file.close()
377
+
378
+ torch.distributed.barrier()
379
+ print("Done.")
evaluation/evaluate_libritts.py ADDED
@@ -0,0 +1,384 @@
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import re
7
+ import sys
8
+ import uuid
9
+ from datetime import timedelta
10
+ from functools import partial
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ import tqdm
15
+ from datasets import load_dataset
16
+ from tn.english.normalizer import Normalizer as EnNormalizer
17
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
18
+ from transformers.generation import GenerationConfig
19
+
20
+ import torchaudio
21
+ from vita_audio.tokenizer import get_audio_tokenizer
22
+
23
+
24
+ def collate_fn(batches):
25
+ input_ids = [sample["input_ids"] for sample in batches]
26
+
27
+ refs = [sample["ref"] for sample in batches]
28
+ filenames = [sample["filename"] for sample in batches]
29
+
30
+ return input_ids, refs, filenames
31
+
32
+
33
+ class TTSDataset(torch.utils.data.Dataset):
34
+ def __init__(self, json_path, tokenizer, audio_tokenizer, default_system_message=None, add_generation_prompt=True):
35
+ data = load_dataset("json", data_files=json_path, keep_in_memory=False)
36
+ self.data = data["train"]
37
+
38
+ self.tokenizer = tokenizer
39
+ self.audio_tokenizer = audio_tokenizer
40
+ self.default_system_message = default_system_message
41
+ self.add_generation_prompt = add_generation_prompt
42
+
43
+ def __len__(self):
44
+ return len(self.data)
45
+
46
+ def __getitem__(self, idx):
47
+ sample = self.data[idx]
48
+
49
+ messages = []
50
+
51
+ if self.default_system_message is not None:
52
+ messages = self.default_system_message + messages
53
+
54
+ role = "user"
55
+ content = sample["messages"][0]["content"]
56
+ messages.append(
57
+ {
58
+ "role": role,
59
+ "content": content,
60
+ }
61
+ )
62
+
63
+ input_ids = self.tokenizer.apply_chat_template(
64
+ messages,
65
+ tokenize=True,
66
+ add_generation_prompt=self.add_generation_prompt,
67
+ return_tensors="pt",
68
+ )
69
+
70
+ ref = sample["messages"][0]["content"]
71
+ ref = ref.replace("Convert the text to speech.\n", "")
72
+ ref = ref.strip()
73
+
74
+ filepath = sample["audios"][0]
75
+ filename = os.path.basename(filepath)
76
+
77
+ return {
78
+ "input_ids": input_ids,
79
+ "ref": ref,
80
+ "filename": filename,
81
+ }
82
+
83
+
84
+ class InferenceSampler(torch.utils.data.sampler.Sampler):
85
+ def __init__(self, size):
86
+ self._size = int(size)
87
+ assert size > 0
88
+ self._rank = torch.distributed.get_rank()
89
+ self._world_size = torch.distributed.get_world_size()
90
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
91
+
92
+ @staticmethod
93
+ def _get_local_indices(total_size, world_size, rank):
94
+ shard_size = total_size // world_size
95
+ left = total_size % world_size
96
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
97
+
98
+ begin = sum(shard_sizes[:rank])
99
+ end = min(sum(shard_sizes[: rank + 1]), total_size)
100
+ return range(begin, end)
101
+
102
+ def __iter__(self):
103
+ yield from self._local_indices
104
+
105
+ def __len__(self):
106
+ return len(self._local_indices)
107
+
108
+
109
+ def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir, asr_model):
110
+
111
+ audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
112
+ en_tn_model = EnNormalizer(overwrite_cache=True)
113
+
114
+ outputs = []
115
+
116
+ for _, (
117
+ batched_input_ids,
118
+ batched_ref,
119
+ batched_filename,
120
+ ) in enumerate(tqdm.tqdm(dataloader)):
121
+ for input_ids, ref, filename in zip(
122
+ batched_input_ids, batched_ref, batched_filename
123
+ ):
124
+
125
+ responses = model.generate(
126
+ input_ids=input_ids.cuda(),
127
+ # temperature=0.2,
128
+ # top_p=0.8,
129
+ # do_sample=False,
130
+ # temperature=1.0,
131
+ max_new_tokens=1024,
132
+ min_new_tokens=1,
133
+ )
134
+
135
+ response = responses[0][len(input_ids[0]) :]
136
+
137
+ text_tokens = []
138
+ audio_tokens = []
139
+ for token_id in response:
140
+ if token_id >= audio_offset:
141
+ audio_tokens.append(token_id - audio_offset)
142
+ else:
143
+ text_tokens.append(token_id)
144
+
145
+ if len(audio_tokens) == 0:
146
+ continue
147
+
148
+ tts_speech = audio_tokenizer.decode(audio_tokens)
149
+
150
+ wav_dir = os.path.join(output_dir, "audio")
151
+ wav_path = os.path.join(wav_dir, filename + ".wav")
152
+ os.makedirs(os.path.dirname(wav_path), exist_ok=True)
153
+ torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
154
+
155
+ hyp = asr_model(wav_path, return_timestamps=True)["text"].strip()
156
+
157
+ hyp = en_tn_model.normalize(hyp)
158
+ ref = en_tn_model.normalize(ref)
159
+
160
+ hyp = re.sub(r"\W+", " ", hyp)
161
+ ref = re.sub(r"\W+", " ", ref)
162
+
163
+ outputs.append((hyp, ref))
164
+
165
+ print("")
166
+ print("=" * 100)
167
+ # print(f"{len(input_id)=}")
168
+ # print(f"{len(response)=}")
169
+ print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
170
+ print(f"{filename=}")
171
+
172
+ return outputs
173
+
174
+
175
+ def load_asr_model():
176
+ import torch
177
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
178
+
179
+ rank = torch.distributed.get_rank()
180
+ device = f"cuda:{rank}"
181
+ torch_dtype = torch.float16
182
+
183
+ model_id = "/data/models/openai/whisper-large-v3"
184
+
185
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
186
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
187
+ )
188
+ model.to(device)
189
+
190
+ processor = AutoProcessor.from_pretrained(model_id)
191
+
192
+ pipe = pipeline(
193
+ "automatic-speech-recognition",
194
+ model=model,
195
+ tokenizer=processor.tokenizer,
196
+ feature_extractor=processor.feature_extractor,
197
+ torch_dtype=torch_dtype,
198
+ device=device,
199
+ )
200
+
201
+ return pipe
202
+
203
+
204
+ if __name__ == "__main__":
205
+ parser = argparse.ArgumentParser(
206
+ description="",
207
+ formatter_class=argparse.RawDescriptionHelpFormatter,
208
+ )
209
+
210
+ parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
211
+ parser.add_argument(
212
+ "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
213
+ )
214
+ parser.add_argument(
215
+ "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
216
+ )
217
+ parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
218
+
219
+ parser.add_argument("--json_path", type=str, required=True, help="json_path")
220
+ parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
221
+
222
+ parser.add_argument("--batch_size", type=int, default=1)
223
+ parser.add_argument("--num_workers", type=int, default=0)
224
+
225
+ parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)
226
+
227
+ args = parser.parse_args()
228
+
229
+ print(f"{args=}")
230
+
231
+ torch.distributed.init_process_group(
232
+ backend="nccl",
233
+ world_size=int(os.getenv("WORLD_SIZE", "1")),
234
+ rank=int(os.getenv("RANK", "0")),
235
+ timeout=timedelta(seconds=7200),
236
+ )
237
+
238
+ torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
239
+
240
+ random.seed(42)
241
+ torch.manual_seed(42)
242
+
243
+ config = AutoConfig.from_pretrained(
244
+ args.model_name_or_path,
245
+ trust_remote_code=True,
246
+ )
247
+
248
+ # ================================================================
249
+ if "glm" in config.model_type.lower():
250
+ from get_chat_template import glm4_chat_template as chat_template
251
+
252
+ add_generation_prompt = True
253
+
254
+ default_system_message = [
255
+ {
256
+ "role": "system",
257
+ "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
258
+ }
259
+ ]
260
+
261
+ if "qwen2" in config.model_type.lower():
262
+ from get_chat_template import qwen2_chat_template as chat_template
263
+
264
+ add_generation_prompt = True
265
+
266
+ default_system_message = []
267
+
268
+ if "hunyuan" in config.model_type.lower():
269
+ from get_chat_template import hunyuan_chat_template as chat_template
270
+
271
+ add_generation_prompt = False
272
+
273
+ default_system_message = [
274
+ {
275
+ "role": "system",
276
+ "content": "You are a helpful AI assistant.",
277
+ }
278
+ ]
279
+
280
+ # ================================================================
281
+ print("Loading model")
282
+ device = "cuda"
283
+ # device_map = "auto"
284
+ device_map = "cuda"
285
+ # torch_dtype=torch.float16
286
+ torch_dtype = torch.bfloat16
287
+
288
+ rank = torch.distributed.get_rank()
289
+
290
+ audio_tokenizer = get_audio_tokenizer(
291
+ args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
292
+ )
293
+
294
+ tokenizer = AutoTokenizer.from_pretrained(
295
+ args.model_name_or_path,
296
+ trust_remote_code=True,
297
+ chat_template=chat_template,
298
+ )
299
+ # print("tokenizer", tokenizer)
300
+
301
+ model = AutoModelForCausalLM.from_pretrained(
302
+ args.model_name_or_path,
303
+ trust_remote_code=True,
304
+ device_map=device_map,
305
+ torch_dtype=torch_dtype,
306
+ attn_implementation="flash_attention_2",
307
+ ).eval()
308
+ # print("model", model)
309
+
310
+ model.generation_config = GenerationConfig.from_pretrained(
311
+ args.model_name_or_path, trust_remote_code=True
312
+ )
313
+
314
+ model.generation_config.max_new_tokens = 4096
315
+ model.generation_config.chat_format = "chatml"
316
+ model.generation_config.max_window_size = 8192
317
+ model.generation_config.use_cache = True
318
+ model.generation_config.do_sample = True
319
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
320
+ if model.config.model_type == "hunyuan":
321
+ model.generation_config.eos_token_id = tokenizer.eos_id
322
+
323
+ asr_model = load_asr_model()
324
+
325
+ # ================================================================
326
+ print("Loading data")
327
+ dataset = TTSDataset(
328
+ json_path=args.json_path,
329
+ tokenizer=tokenizer,
330
+ audio_tokenizer=audio_tokenizer,
331
+ default_system_message=default_system_message,
332
+ add_generation_prompt=add_generation_prompt,
333
+ )
334
+
335
+ dataloader = torch.utils.data.DataLoader(
336
+ dataset=dataset,
337
+ sampler=InferenceSampler(len(dataset)),
338
+ batch_size=args.batch_size,
339
+ num_workers=args.num_workers,
340
+ pin_memory=True,
341
+ drop_last=False,
342
+ collate_fn=partial(
343
+ collate_fn,
344
+ ),
345
+ )
346
+
347
+ # ================================================================
348
+ outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir, asr_model)
349
+
350
+ torch.distributed.barrier()
351
+
352
+ world_size = torch.distributed.get_world_size()
353
+ merged_outputs = [None for _ in range(world_size)]
354
+ torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
355
+
356
+ merged_outputs = [json.loads(_) for _ in merged_outputs]
357
+ merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
358
+
359
+ if torch.distributed.get_rank() == 0:
360
+ # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
361
+ json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
362
+ hyp_path = os.path.join(args.output_dir, f"{json_name}_hyp.txt")
363
+ ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")
364
+
365
+ os.makedirs(os.path.dirname(ref_path), exist_ok=True)
366
+ os.makedirs(os.path.dirname(hyp_path), exist_ok=True)
367
+
368
+ hyp_file = open(hyp_path, "w")
369
+ ref_file = open(ref_path, "w")
370
+
371
+ for sample_idx, (hyp, ref) in enumerate(merged_outputs):
372
+ hyp_file.write(f"{sample_idx} {hyp}" + "\n")
373
+ ref_file.write(f"{sample_idx} {ref}" + "\n")
374
+
375
+ hyp_file.close()
376
+ ref_file.close()
377
+
378
+ hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref.json")
379
+ hyp_ref_file = open(hyp_ref_path, "w")
380
+ json.dump(merged_outputs, hyp_ref_file, indent=4)
381
+ hyp_ref_file.close()
382
+
383
+ torch.distributed.barrier()
384
+ print("Done.")
evaluation/evaluate_seedtts.py ADDED
@@ -0,0 +1,394 @@
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import sys
7
+ import uuid
8
+ from datetime import timedelta
9
+ from functools import partial
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ import tqdm
14
+ from datasets import load_dataset
15
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
16
+ from transformers.generation import GenerationConfig
17
+
18
+ import torchaudio
19
+ from vita_audio.tokenizer import get_audio_tokenizer
20
+
21
+
22
+ def collate_fn(batches):
23
+ input_ids = [sample["input_ids"] for sample in batches]
24
+
25
+ refs = [sample["ref"] for sample in batches]
26
+ filenames = [sample["filename"] for sample in batches]
27
+ prompt_audio_path = [sample["prompt_audio_path"] for sample in batches]
28
+
29
+ return input_ids, refs, filenames, prompt_audio_path
30
+
31
+
32
+ class SeedTTSDataset(torch.utils.data.Dataset):
33
+ def __init__(
34
+ self,
35
+ data_path,
36
+ tokenizer,
37
+ audio_tokenizer,
38
+ default_system_message=None,
39
+ speaker_prompt=False,
40
+ add_generation_prompt=True,
41
+ ):
42
+ self.data = []
43
+
44
+ meta_path = os.path.join(data_path, f"seedtts_testset/zh/meta.lst")
45
+ with open(meta_path, "r") as f:
46
+ lines = f.readlines()
47
+
48
+ for line in lines:
49
+ line = line.strip().split("|")
50
+ filename = line[0]
51
+ prompt_text = line[1]
52
+ prompt_audio = line[2]
53
+ text = line[3]
54
+ self.data.append(["zh", filename, prompt_text, prompt_audio, text])
55
+
56
+ meta_path = os.path.join(data_path, f"seedtts_testset/zh/hardcase.lst")
57
+ with open(meta_path, "r") as f:
58
+ lines = f.readlines()
59
+
60
+ for line in lines:
61
+ line = line.strip().split("|")
62
+ filename = line[0]
63
+ prompt_text = line[1]
64
+ prompt_audio = line[2]
65
+ text = line[3]
66
+ self.data.append(["hardcase", filename, prompt_text, prompt_audio, text])
67
+
68
+ meta_path = os.path.join(data_path, f"seedtts_testset/en/meta.lst")
69
+ with open(meta_path, "r") as f:
70
+ lines = f.readlines()
71
+
72
+ for line in lines:
73
+ line = line.strip().split("|")
74
+ filename = line[0]
75
+ prompt_text = line[1]
76
+ prompt_audio = line[2]
77
+ text = line[3]
78
+ self.data.append(["en", filename, prompt_text, prompt_audio, text])
79
+
80
+ self.tokenizer = tokenizer
81
+ self.audio_tokenizer = audio_tokenizer
82
+ self.default_system_message = default_system_message
83
+ self.add_generation_prompt = add_generation_prompt
84
+
85
+ self.data_path = data_path
86
+ self.speaker_prompt = speaker_prompt
87
+
88
+ def __len__(self):
89
+ return len(self.data)
90
+
91
+ def __getitem__(self, idx):
92
+ sample = self.data[idx]
93
+
94
+ split, filename, prompt_text, prompt_audio, text = sample
95
+
96
+ messages = []
97
+
98
+ if self.default_system_message is not None:
99
+ messages = self.default_system_message + messages
100
+
101
+ if self.speaker_prompt:
102
+ if split == "hardcase":
103
+ prompt_audio_path = os.path.join(
104
+ self.data_path, "seedtts_testset", "zh", prompt_audio
105
+ )
106
+ else:
107
+ prompt_audio_path = os.path.join(
108
+ self.data_path, "seedtts_testset", split, prompt_audio
109
+ )
110
+
111
+ if self.audio_tokenizer.apply_to_role("system", is_discrete=True):
112
+ # discrete codec
113
+ prompt_audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
114
+ prompt_audio_tokens = "".join(f"<|audio_{i}|>" for i in prompt_audio_tokens)
115
+
116
+ prompt_text = f"Speaker Metadata:\nAudio: <|begin_of_audio|>{prompt_audio_tokens}<|end_of_audio|>\n"
117
+
118
+ if len(messages) > 0 and messages[0]["role"] == "system":
119
+ messages[0]["content"] += prompt_text
120
+
121
+ else:
122
+ messages.append(
123
+ {
124
+ "role": "system",
125
+ "content": prompt_text,
126
+ }
127
+ )
128
+ else:
129
+ prompt_audio_path = None
130
+
131
+ role = "user"
132
+ content = "Convert the text to speech.\n" + text
133
+ messages.append(
134
+ {
135
+ "role": role,
136
+ "content": content,
137
+ }
138
+ )
139
+
140
+ input_ids = self.tokenizer.apply_chat_template(
141
+ messages,
142
+ tokenize=True,
143
+ add_generation_prompt=self.add_generation_prompt,
144
+ return_tensors="pt",
145
+ )
146
+
147
+ ref = text
148
+
149
+ return {
150
+ "input_ids": input_ids,
151
+ "ref": ref,
152
+ "filename": split + "/" + filename,
153
+ "prompt_audio_path": prompt_audio_path,
154
+ }
155
+
156
+
157
+ class InferenceSampler(torch.utils.data.sampler.Sampler):
158
+ def __init__(self, size):
159
+ self._size = int(size)
160
+ assert size > 0
161
+ self._rank = torch.distributed.get_rank()
162
+ self._world_size = torch.distributed.get_world_size()
163
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
164
+
165
+ @staticmethod
166
+ def _get_local_indices(total_size, world_size, rank):
167
+ shard_size = total_size // world_size
168
+ left = total_size % world_size
169
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
170
+
171
+ begin = sum(shard_sizes[:rank])
172
+ end = min(sum(shard_sizes[: rank + 1]), total_size)
173
+ return range(begin, end)
174
+
175
+ def __iter__(self):
176
+ yield from self._local_indices
177
+
178
+ def __len__(self):
179
+ return len(self._local_indices)
180
+
181
+
182
+ def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir):
183
+
184
+ audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
185
+
186
+ outputs = []
187
+
188
+ for _, (
189
+ batched_input_ids,
190
+ batched_ref,
191
+ batched_filename,
192
+ batched_prompt_audio_path,
193
+ ) in enumerate(tqdm.tqdm(dataloader)):
194
+
195
+ for input_ids, ref, filename, prompt_audio_path in zip(
196
+ batched_input_ids, batched_ref, batched_filename, batched_prompt_audio_path
197
+ ):
198
+ responses = model.generate(
199
+ input_ids=input_ids.cuda(),
200
+ # temperature=0.2,
201
+ # top_p=0.8,
202
+ # do_sample=False,
203
+ # temperature=1.0,
204
+ max_new_tokens=1024,
205
+ min_new_tokens=1,
206
+ )
207
+
208
+ response = responses[0][len(input_ids[0]) :]
209
+
210
+ text_tokens = []
211
+ audio_tokens = []
212
+ for token_id in response:
213
+ if token_id >= audio_offset:
214
+ audio_tokens.append(token_id - audio_offset)
215
+ else:
216
+ text_tokens.append(token_id)
217
+
218
+ if len(audio_tokens) == 0:
219
+ continue
220
+
221
+ tts_speech = audio_tokenizer.decode(audio_tokens, source_speech_16k=prompt_audio_path)
222
+
223
+ wav_path = os.path.join(output_dir, filename + ".wav")
224
+ os.makedirs(os.path.dirname(wav_path), exist_ok=True)
225
+ torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
226
+
227
+ outputs.append((wav_path, filename))
228
+
229
+ print("")
230
+ print("=" * 100)
231
+ # print(f"{len(input_id)=}")
232
+ # print(f"{len(response)=}")
233
+ print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
234
+ print(f"{filename=}")
235
+
236
+ return outputs
237
+
238
+
239
+ if __name__ == "__main__":
240
+ parser = argparse.ArgumentParser(
241
+ description="",
242
+ formatter_class=argparse.RawDescriptionHelpFormatter,
243
+ )
244
+
245
+ parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
246
+ parser.add_argument(
247
+ "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
248
+ )
249
+ parser.add_argument(
250
+ "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
251
+ )
252
+ parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
253
+
254
+ parser.add_argument("--data_path", type=str, required=True, help="data_path")
255
+ parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
256
+
257
+ parser.add_argument("--batch_size", type=int, default=1)
258
+ parser.add_argument("--num_workers", type=int, default=0)
259
+
260
+ parser.add_argument("--speaker_prompt", action=argparse.BooleanOptionalAction, default=False)
261
+
262
+ args = parser.parse_args()
263
+
264
+ print(f"{args=}")
265
+
266
+ torch.distributed.init_process_group(
267
+ backend="nccl",
268
+ world_size=int(os.getenv("WORLD_SIZE", "1")),
269
+ rank=int(os.getenv("RANK", "0")),
270
+ timeout=timedelta(seconds=7200),
271
+ )
272
+
273
+ torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
274
+
275
+ random.seed(42)
276
+ torch.manual_seed(42)
277
+
278
+ config = AutoConfig.from_pretrained(
279
+ args.model_name_or_path,
280
+ trust_remote_code=True,
281
+ )
282
+
283
+ # ================================================================
284
+ if "glm" in config.model_type.lower():
285
+ from get_chat_template import glm4_chat_template as chat_template
286
+
287
+ add_generation_prompt = True
288
+
289
+ default_system_message = [
290
+ {
291
+ "role": "system",
292
+ "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
293
+ }
294
+ ]
295
+
296
+ if "qwen2" in config.model_type.lower():
297
+ from get_chat_template import qwen2_chat_template as chat_template
298
+
299
+ add_generation_prompt = True
300
+
301
+ default_system_message = []
302
+
303
+ if "hunyuan" in config.model_type.lower():
304
+ from get_chat_template import hunyuan_chat_template as chat_template
305
+
306
+ add_generation_prompt = False
307
+
308
+ default_system_message = [
309
+ {
310
+ "role": "system",
311
+ "content": "You are a helpful AI assistant.",
312
+ }
313
+ ]
314
+
315
+ # ================================================================
316
+ print("Loading model")
317
+ device = "cuda"
318
+ # device_map = "auto"
319
+ device_map = "cuda"
320
+ # torch_dtype=torch.float16
321
+ torch_dtype = torch.bfloat16
322
+
323
+ rank = torch.distributed.get_rank()
324
+
325
+ audio_tokenizer = get_audio_tokenizer(
326
+ args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
327
+ )
328
+
329
+ tokenizer = AutoTokenizer.from_pretrained(
330
+ args.model_name_or_path,
331
+ trust_remote_code=True,
332
+ chat_template=chat_template,
333
+ )
334
+ # print("tokenizer", tokenizer)
335
+
336
+ model = AutoModelForCausalLM.from_pretrained(
337
+ args.model_name_or_path,
338
+ trust_remote_code=True,
339
+ device_map=device_map,
340
+ torch_dtype=torch_dtype,
341
+ attn_implementation="flash_attention_2",
342
+ ).eval()
343
+ # print("model", model)
344
+
345
+ model.generation_config = GenerationConfig.from_pretrained(
346
+ args.model_name_or_path, trust_remote_code=True
347
+ )
348
+
349
+ model.generation_config.max_new_tokens = 4096
350
+ model.generation_config.chat_format = "chatml"
351
+ model.generation_config.max_window_size = 8192
352
+ model.generation_config.use_cache = True
353
+ model.generation_config.do_sample = True
354
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
355
+ if model.config.model_type == "hunyuan":
356
+ model.generation_config.eos_token_id = tokenizer.eos_id
357
+
358
+ # ================================================================
359
+ print("Loading data")
360
+ dataset = SeedTTSDataset(
361
+ data_path=args.data_path,
362
+ tokenizer=tokenizer,
363
+ audio_tokenizer=audio_tokenizer,
364
+ default_system_message=default_system_message,
365
+ speaker_prompt=args.speaker_prompt,
366
+ add_generation_prompt=add_generation_prompt,
367
+ )
368
+
369
+ dataloader = torch.utils.data.DataLoader(
370
+ dataset=dataset,
371
+ sampler=InferenceSampler(len(dataset)),
372
+ batch_size=args.batch_size,
373
+ num_workers=args.num_workers,
374
+ pin_memory=True,
375
+ drop_last=False,
376
+ collate_fn=partial(
377
+ collate_fn,
378
+ ),
379
+ )
380
+
381
+ # ================================================================
382
+ outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir)
383
+
384
+ torch.distributed.barrier()
385
+
386
+ world_size = torch.distributed.get_world_size()
387
+ merged_outputs = [None for _ in range(world_size)]
388
+ torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
389
+
390
+ merged_outputs = [json.loads(_) for _ in merged_outputs]
391
+ merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
392
+
393
+ torch.distributed.barrier()
394
+ print("Done.")
evaluation/evaluate_sqa.py ADDED
@@ -0,0 +1,451 @@
1
+ import argparse
2
+ import itertools
3
+ import json
4
+ import os
5
+ import random
6
+ import sys
7
+ import uuid
8
+ from datetime import timedelta
9
+ from functools import partial
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ import tqdm
14
+ from datasets import load_dataset
15
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
16
+ from transformers.generation import GenerationConfig
17
+
18
+ import torchaudio
19
+ from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
20
+ from vita_audio.tokenizer import get_audio_tokenizer
21
+
22
+
23
+ def collate_fn(batches):
24
+ input_ids = [sample["input_ids"] for sample in batches]
25
+ audios = [sample["audios"] for sample in batches]
26
+ audio_indices = [sample["audio_indices"] for sample in batches]
27
+
28
+ refs = [sample["ref"] for sample in batches]
29
+ filenames = [sample["filename"] for sample in batches]
30
+
31
+ return input_ids, audios, audio_indices, refs, filenames
32
+
33
+
34
+ class STSDataset(torch.utils.data.Dataset):
35
+ def __init__(self, json_path, tokenizer, audio_tokenizer, default_system_message=None, add_generation_prompt=True):
36
+ data = load_dataset("json", data_files=json_path, keep_in_memory=False)
37
+ self.data = data["train"]
38
+
39
+ self.tokenizer = tokenizer
40
+ self.add_generation_prompt = add_generation_prompt
41
+
42
+ self.audio_tokenizer = audio_tokenizer
43
+ self.default_system_message = default_system_message
44
+
45
+ def __len__(self):
46
+ return len(self.data)
47
+
48
+ def __getitem__(self, idx):
49
+ sample = self.data[idx]
50
+
51
+ assert len(sample["audios"]) == 1
52
+
53
+ audio_path = sample["audios"][0]
54
+
55
+ if self.audio_tokenizer.apply_to_role("user", is_discrete=True):
56
+ # discrete codec
57
+ audio_tokens = self.audio_tokenizer.encode(audio_path)
58
+ audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
59
+ else:
60
+ audio_tokens = None
61
+
62
+ messages = []
63
+
64
+ if len(sample["messages"]) == 2:
65
+ assert len(sample["messages"]) == 2
66
+ assert sample["messages"][0]["role"] == "user"
67
+ assert sample["messages"][1]["role"] == "assistant"
68
+
69
+ if self.default_system_message is not None:
70
+ messages = self.default_system_message + messages
71
+
72
+ elif len(sample["messages"]) == 3:
73
+ assert len(sample["messages"]) == 3
74
+ assert sample["messages"][0]["role"] == "system"
75
+ assert sample["messages"][1]["role"] == "user"
76
+ assert sample["messages"][2]["role"] == "assistant"
77
+
78
+ else:
79
+ raise NotImplementedError
80
+
81
+ for conv in sample["messages"][:-1]:
82
+ new_conv = {}
83
+ new_conv["role"] = conv["role"]
84
+
85
+ content = conv["content"]
86
+ if isinstance(content, list):
87
+ assert len(content) == 1
88
+ content = content[0]
89
+
90
+ if audio_tokens is not None:
91
+ content = content.replace(
92
+ "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
93
+ )
94
+
95
+ new_conv["content"] = content
96
+ messages.append(new_conv)
97
+
98
+ input_ids = self.tokenizer.apply_chat_template(
99
+ messages,
100
+ tokenize=True,
101
+ add_generation_prompt=self.add_generation_prompt,
102
+ # return_tensors="pt",
103
+ )
104
+
105
+ ref = sample["messages"][-1]["content"]
106
+
107
+ if self.audio_tokenizer.apply_to_role("user", is_contiguous=True):
108
+ # contiguous codec
109
+ input_ids, audios, audio_indices = add_audio_input_contiguous(
110
+ input_ids, [audio_path], self.tokenizer, self.audio_tokenizer
111
+ )
112
+ else:
113
+ audios = None
114
+ audio_indices = None
115
+
116
+ input_ids = torch.tensor([input_ids], dtype=torch.long)
117
+
118
+ filename = os.path.basename(audio_path)
119
+ filename = os.path.splitext(filename)[0]
120
+
121
+ return {
122
+ "input_ids": input_ids,
123
+ "audios": audios,
124
+ "audio_indices": audio_indices,
125
+ "ref": ref,
126
+ "filename": filename,
127
+ }
128
+
129
+
130
+ class InferenceSampler(torch.utils.data.sampler.Sampler):
131
+ def __init__(self, size):
132
+ self._size = int(size)
133
+ assert size > 0
134
+ self._rank = torch.distributed.get_rank()
135
+ self._world_size = torch.distributed.get_world_size()
136
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
137
+
138
+ @staticmethod
139
+ def _get_local_indices(total_size, world_size, rank):
140
+ shard_size = total_size // world_size
141
+ left = total_size % world_size
142
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
143
+
144
+ begin = sum(shard_sizes[:rank])
145
+ end = min(sum(shard_sizes[: rank + 1]), total_size)
146
+ return range(begin, end)
147
+
148
+ def __iter__(self):
149
+ yield from self._local_indices
150
+
151
+ def __len__(self):
152
+ return len(self._local_indices)
153
+
154
+
155
+ def inference(model, tokenizer, audio_tokenizer, dataloader, output_dir, asr_model):
156
+
157
+ audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
158
+
159
+ outputs = []
160
+
161
+ for _, (batched_input_ids, batched_audios, batched_audio_indices, batched_ref, batched_filename) in enumerate(
162
+ tqdm.tqdm(dataloader)
163
+ ):
164
+ for input_ids, audios, audio_indices, ref, filename in zip(
165
+ batched_input_ids, batched_audios, batched_audio_indices, batched_ref, batched_filename
166
+ ):
167
+
168
+ responses = model.generate(
169
+ input_ids=input_ids.cuda(),
170
+ audios=audios,
171
+ audio_indices=audio_indices,
172
+ # temperature=0.2,
173
+ # top_p=0.8,
174
+ # do_sample=False,
175
+ # temperature=1.0,
176
+ max_new_tokens=1024,
177
+ min_new_tokens=1,
178
+ )
179
+
180
+ response = responses[0][len(input_ids[0]) :]
181
+
182
+ text_tokens = []
183
+ audio_tokens = []
184
+ for token_id in response:
185
+ if token_id >= audio_offset:
186
+ audio_tokens.append(token_id - audio_offset)
187
+ else:
188
+ text_tokens.append(token_id)
189
+
190
+ hyp_text = tokenizer.decode(text_tokens, skip_special_tokens=True)
191
+
192
+ if len(audio_tokens) == 0:
193
+ continue
194
+
195
+ tts_speech = audio_tokenizer.decode(audio_tokens)
196
+
197
+ wav_dir = os.path.join(output_dir, "audio")
198
+ wav_path = os.path.join(wav_dir, filename + ".wav")
199
+ os.makedirs(os.path.dirname(wav_path), exist_ok=True)
200
+ torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
201
+
202
+ # hyp_speech = asr_model.transcribe(wav_path)["text"].strip()
203
+ hyp_speech = asr_model(wav_path, return_timestamps=True)["text"].strip()
204
+ # hyp_speech = ""
205
+
206
+ outputs.append((hyp_text, hyp_speech, ref))
207
+
208
+ print("")
209
+ print("=" * 100)
210
+ print(f"{tokenizer.decode(response, skip_special_tokens=False)}")
211
+ print(f" {hyp_text=}")
212
+ print(f"{hyp_speech=}")
213
+ print(f" {ref=}")
214
+ print(f"{filename=}")
215
+
216
+ return outputs
217
+
218
+
219
+ def load_asr_model():
220
+ import torch
221
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
222
+
223
+ rank = torch.distributed.get_rank()
224
+ device = f"cuda:{rank}"
225
+ torch_dtype = torch.float16
226
+
227
+ model_id = "/data/models/openai/whisper-large-v3"
228
+
229
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
230
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
231
+ )
232
+ model.to(device)
233
+
234
+ processor = AutoProcessor.from_pretrained(model_id)
235
+
236
+ pipe = pipeline(
237
+ "automatic-speech-recognition",
238
+ model=model,
239
+ tokenizer=processor.tokenizer,
240
+ feature_extractor=processor.feature_extractor,
241
+ torch_dtype=torch_dtype,
242
+ device=device,
243
+ )
244
+
245
+ return pipe
246
+
247
+
248
+ if __name__ == "__main__":
249
+ parser = argparse.ArgumentParser(
250
+ description="",
251
+ formatter_class=argparse.RawDescriptionHelpFormatter,
252
+ )
253
+
254
+ parser.add_argument("--model_name_or_path", type=str, required=True, help="model_name_or_path")
255
+ parser.add_argument(
256
+ "--audio_tokenizer_path", type=str, required=True, help="audio_tokenizer_path"
257
+ )
258
+ parser.add_argument(
259
+ "--audio_tokenizer_type", type=str, required=True, help="audio_tokenizer_type"
260
+ )
261
+ parser.add_argument("--flow_path", type=str, required=True, help="flow_path")
262
+
263
+ parser.add_argument("--json_path", type=str, required=True, help="json_path")
264
+ parser.add_argument("--output_dir", type=str, required=True, help="output_dir")
265
+
266
+ parser.add_argument("--batch_size", type=int, default=1)
267
+ parser.add_argument("--num_workers", type=int, default=0)
268
+
269
+ args = parser.parse_args()
270
+
271
+ print(f"{args=}")
272
+
273
+ torch.distributed.init_process_group(
274
+ backend="nccl",
275
+ world_size=int(os.getenv("WORLD_SIZE", "1")),
276
+ rank=int(os.getenv("RANK", "0")),
277
+ timeout=timedelta(seconds=7200),
278
+ )
279
+
280
+ torch.cuda.set_device(int(os.getenv("LOCAL_RANK", 0)))
281
+
282
+ random.seed(42)
283
+ torch.manual_seed(42)
284
+
285
+ config = AutoConfig.from_pretrained(
286
+ args.model_name_or_path,
287
+ trust_remote_code=True,
288
+ )
289
+
290
+ # ================================================================
291
+ if "glm" in config.model_type.lower():
292
+ from get_chat_template import glm4_chat_template as chat_template
293
+
294
+ add_generation_prompt = True
295
+
296
+ default_system_message = [
297
+ {
298
+ "role": "system",
299
+ "content": "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens.",
300
+ }
301
+ ]
302
+
303
+ if "qwen2" in config.model_type.lower():
304
+ from get_chat_template import qwen2_chat_template as chat_template
305
+
306
+ add_generation_prompt = True
307
+
308
+ default_system_message = []
309
+
310
+ if "hunyuan" in config.model_type.lower():
311
+ from get_chat_template import hunyuan_chat_template as chat_template
312
+
313
+ add_generation_prompt = False
314
+
315
+ default_system_message = [
316
+ {
317
+ "role": "system",
318
+ "content": "You are a helpful AI assistant.",
319
+ }
320
+ ]
321
+
322
+ default_system_message = [
323
+ {
324
+ "role": "system",
325
+ # "content": "Your Name: Luke\nYour Gender: male\nRespond in a text-audio interleaved manner.",
326
+ # "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.",
327
+ "content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.",
328
+ },
329
+ ]
330
+
331
+ # ================================================================
332
+ print("Loading model")
333
+ device = "cuda"
334
+ # device_map = "auto"
335
+ device_map = "cuda"
336
+ # torch_dtype=torch.float16
337
+ torch_dtype = torch.bfloat16
338
+
339
+ rank = torch.distributed.get_rank()
340
+
341
+ audio_tokenizer = get_audio_tokenizer(
342
+ args.audio_tokenizer_path, args.audio_tokenizer_type, flow_path=args.flow_path, rank=rank
343
+ )
344
+
345
+ tokenizer = AutoTokenizer.from_pretrained(
346
+ args.model_name_or_path,
347
+ trust_remote_code=True,
348
+ chat_template=chat_template,
349
+ )
350
+ # print("tokenizer", tokenizer)
351
+
352
+ model = AutoModelForCausalLM.from_pretrained(
353
+ args.model_name_or_path,
354
+ trust_remote_code=True,
355
+ device_map=device_map,
356
+ torch_dtype=torch_dtype,
357
+ attn_implementation="flash_attention_2",
358
+ ).eval()
359
+ # print("model", model)
360
+
361
+ model.generation_config = GenerationConfig.from_pretrained(
362
+ args.model_name_or_path, trust_remote_code=True
363
+ )
364
+
365
+ model.generation_config.max_new_tokens = 4096
366
+ model.generation_config.chat_format = "chatml"
367
+ model.generation_config.max_window_size = 8192
368
+ model.generation_config.use_cache = True
369
+ model.generation_config.do_sample = False
370
+ model.generation_config.temperature = None
371
+ model.generation_config.top_p = None
372
+ model.generation_config.top_k = None
373
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
374
+ if model.config.model_type == "hunyuan":
375
+ model.generation_config.eos_token_id = tokenizer.eos_id
376
+
377
+ asr_model = load_asr_model()
378
+
379
+ # ================================================================
380
+ print("Loading data")
381
+ dataset = STSDataset(
382
+ json_path=args.json_path,
383
+ tokenizer=tokenizer,
384
+ audio_tokenizer=audio_tokenizer,
385
+ default_system_message=default_system_message,
386
+ add_generation_prompt=add_generation_prompt,
387
+ )
388
+
389
+ dataloader = torch.utils.data.DataLoader(
390
+ dataset=dataset,
391
+ sampler=InferenceSampler(len(dataset)),
392
+ batch_size=args.batch_size,
393
+ num_workers=args.num_workers,
394
+ pin_memory=True,
395
+ drop_last=False,
396
+ collate_fn=partial(
397
+ collate_fn,
398
+ ),
399
+ )
400
+
401
+ # ================================================================
402
+ outputs = inference(model, tokenizer, audio_tokenizer, dataloader, args.output_dir, asr_model)
403
+
404
+ torch.distributed.barrier()
405
+
406
+ world_size = torch.distributed.get_world_size()
407
+ merged_outputs = [None for _ in range(world_size)]
408
+ torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
409
+
410
+ merged_outputs = [json.loads(_) for _ in merged_outputs]
411
+ merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
412
+
413
+ if torch.distributed.get_rank() == 0:
414
+ # json_name = Path("_".join(os.path.normpath(args.json_path).split(os.sep)[-2:])).stem
415
+ json_name = Path(os.path.normpath(args.json_path).split(os.sep)[-1]).stem
416
+ hyp_text_path = os.path.join(args.output_dir, f"{json_name}_hyp_text.txt")
417
+ hyp_speech_path = os.path.join(args.output_dir, f"{json_name}_hyp_speech.txt")
418
+ ref_path = os.path.join(args.output_dir, f"{json_name}_ref.txt")
419
+
420
+ os.makedirs(os.path.dirname(ref_path), exist_ok=True)
421
+ os.makedirs(os.path.dirname(hyp_text_path), exist_ok=True)
422
+ os.makedirs(os.path.dirname(hyp_speech_path), exist_ok=True)
423
+
424
+ hyp_text_file = open(hyp_text_path, "w")
425
+ hyp_speech_file = open(hyp_speech_path, "w")
426
+ ref_file = open(ref_path, "w")
427
+
428
+ for sample_idx, (hyp_text, hyp_speech, ref) in enumerate(merged_outputs):
429
+ hyp_text_file.write(f"{sample_idx} {hyp_text}" + "\n")
430
+ hyp_speech_file.write(f"{sample_idx} {hyp_speech}" + "\n")
431
+ ref_file.write(f"{sample_idx} {ref}" + "\n")
432
+
433
+ hyp_text_file.close()
434
+ hyp_speech_file.close()
435
+ ref_file.close()
436
+
437
+ outputs_speech = [[x[1], x[2]] for x in merged_outputs]
438
+ outputs_text = [[x[0], x[2]] for x in merged_outputs]
439
+
440
+ hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_text.json")
441
+ hyp_ref_file = open(hyp_ref_path, "w")
442
+ json.dump(outputs_text, hyp_ref_file, indent=4)
443
+ hyp_ref_file.close()
444
+
445
+ hyp_ref_path = os.path.join(args.output_dir, f"{json_name}_hyp_ref_speech.json")
446
+ hyp_ref_file = open(hyp_ref_path, "w")
447
+ json.dump(outputs_speech, hyp_ref_file, indent=4)
448
+ hyp_ref_file.close()
449
+
450
+ torch.distributed.barrier()
451
+ print("Done.")
evaluation/get_chat_template.py ADDED
@@ -0,0 +1,59 @@
1
+ qwen2_chat_template = """
2
+ {%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n
3
+ """
4
+
5
+ qwen3_chat_template = """
6
+ "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false 
%}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
7
+ """
8
+
9
+ hunyuan_chat_template = """
10
+ {% set context = {'has_head': true} %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = message['content'] %}{% if loop.index0 == 0 %}{% if content == '' %}{% set _ = context.update({'has_head': false}) %}{% else %}{% set content = '<|startoftext|>' + content + '<|extra_4|>' %}{% endif %}{% endif %}{% if message['role'] == 'user' %}{% if loop.index0 == 1 and not context.has_head %}{% set content = '<|startoftext|>' + content %}{% endif %}{% if loop.index0 == 1 and context.has_head %}{% set content = content + '<|extra_0|>' %}{% else %}{% set content = '<|startoftext|>' + content + '<|extra_0|>' %}{% endif %}{% elif message['role'] == 'assistant' %}{% set content = content + '<|eos|>' %}{% endif %}{{ content }}{% endfor %}
11
+ """
12
+
13
+ glm4_chat_template = """
14
+ {%- for message in messages %} {%- if (message.role == "system") %} {{- '<|system|>' + '\n' + message.content }} {%- elif (message.role == "user") %} {{- '<|user|>' + '\n' + message.content }} {%- elif message.role == "assistant" %} {{- '<|assistant|>' }} {%- if message.content %} {{- 'streaming_transcription\n' + message.content }} {%- endif %} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} {{- '<|assistant|>streaming_transcription\n' }} {%- endif %}
15
+ """
16
+
17
+ if __name__ == "__main__":
18
+ from transformers import AutoTokenizer
19
+
20
+ chat = [
21
+ {"role": "system", "content": "You are a helpful assistant."},
22
+ {"role": "user", "content": "Hello, how are you?"},
23
+ {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
24
+ {"role": "user", "content": "I'd like to show off how chat templating works!"},
25
+ ]
26
+
27
+ # print("=" * 100)
28
+ # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")
29
+ # print(tokenizer.get_chat_template())
30
+ # message = tokenizer.apply_chat_template(chat, tokenize=False)
31
+ # print(message)
32
+
33
+ print("=" * 100)
34
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-14B-Instruct")
35
+ print(tokenizer.get_chat_template())
36
+ message = tokenizer.apply_chat_template(chat, tokenize=False)
37
+ print(message)
38
+
39
+ print("=" * 100)
40
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-70B-Instruct")
41
+ print(tokenizer.get_chat_template())
42
+ message = tokenizer.apply_chat_template(chat, tokenize=False)
43
+ print(message)
44
+
45
+ print("=" * 100)
46
+ tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
47
+ print(tokenizer.get_chat_template())
48
+ message = tokenizer.apply_chat_template(chat, tokenize=False)
49
+ print(message)
50
+ message = tokenizer.apply_chat_template(chat, tokenize=True)
51
+ print(message)
52
+
53
+ print("=" * 100)
54
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
55
+ print(tokenizer.get_chat_template())
56
+ message = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, enable_thinking=True)
57
+ print(message)
58
+ message = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True, enable_thinking=False)
59
+ print(message)
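
Editor's note (not part of the commit): the glm4_chat_template defined above is meant to be attached to a Hugging Face tokenizer as a custom Jinja template. A minimal sketch, assuming the module import path resolves from the repo root and using Qwen/Qwen2.5-7B-Instruct purely as an illustrative checkpoint:

```python
# Minimal sketch: attach the glm4_chat_template defined above to a tokenizer.
# Assumptions: run from the repo root so the import below resolves, and the
# checkpoint name is illustrative only.
from transformers import AutoTokenizer

from evaluation.get_chat_template import glm4_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
tokenizer.chat_template = glm4_chat_template.strip()

chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello, how are you?"},
]
# Renders roughly "<|system|>\n...<|user|>\n..." followed by the
# "<|assistant|>streaming_transcription\n" generation prompt.
print(tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True))
```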
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ -r requirements_ds_gpu.txt
requirements_ds_gpu.txt ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ expecttest
2
+ peft
3
+ xlsxwriter
4
+ termcolor
5
+ tabulate
6
+ tiktoken
7
+ matplotlib
8
+ datasets
9
+ einops
10
+ pybind11
11
+ tensorboardX
12
+ pyarrow
13
+ transformers==4.48.3
14
+ deepspeed
15
+ accelerate>=1.1.1
16
+ timm
17
+ flask
18
+ flask_restful
19
+ decord
20
+ natsort
21
+ # setuptools==69.5.1
22
+ setuptools
23
+
24
+ # cosyvoice2
25
+ pyworld
26
+ evaluate
27
+ hyperpyyaml
28
+ diffusers
29
+ conformer
30
+ hydra-core
31
+ lightning
32
+ gdown
33
+ wget
34
+ funasr
35
+ zhconv
36
+ jiwer
37
+ zhon
38
+ WeTextProcessing
39
+ inflect
40
+ openai-whisper
41
+ onnxruntime
42
+ modelscope
43
+ word2number
44
+
scripts/deepspeed/ds_config_zero1.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+
23
+ "scheduler": {
24
+ "type": "WarmupCosineLR",
25
+ "params": {
26
+ "total_num_steps": "auto",
27
+ "warmup_min_ratio": 0,
28
+ "warmup_num_steps": "auto",
29
+ "cos_min_ratio": 0.1
30
+ }
31
+ },
32
+
33
+ "zero_optimization": {
34
+ "stage": 1,
35
+ "offload_optimizer": {
36
+ "device": "none",
37
+ "pin_memory": true
38
+ },
39
+ "offload_param": {
40
+ "device": "none",
41
+ "pin_memory": true
42
+ },
43
+ "allgather_partitions": true,
44
+ "allgather_bucket_size": 5e8,
45
+ "overlap_comm": true,
46
+ "reduce_scatter": true,
47
+ "reduce_bucket_size": 5e8,
48
+ "contiguous_gradients": true,
49
+ "round_robin_gradients": true,
50
+ "sub_group_size": 1e12
51
+ },
52
+
53
+ "gradient_accumulation_steps": "auto",
54
+ "gradient_clipping": "auto",
55
+ "steps_per_print": 100,
56
+ "train_batch_size": "auto",
57
+ "train_micro_batch_size_per_gpu": "auto",
58
+ "wall_clock_breakdown": false,
59
+ "dump_state": false
60
+
61
+ }
scripts/deepspeed/ds_config_zero2.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+
23
+ "scheduler": {
24
+ "type": "WarmupCosineLR",
25
+ "params": {
26
+ "total_num_steps": "auto",
27
+ "warmup_min_ratio": 0,
28
+ "warmup_num_steps": "auto",
29
+ "cos_min_ratio": 0.1
30
+ }
31
+ },
32
+
33
+ "zero_optimization": {
34
+ "stage": 2,
35
+ "offload_optimizer": {
36
+ "device": "none",
37
+ "pin_memory": true
38
+ },
39
+ "offload_param": {
40
+ "device": "none",
41
+ "pin_memory": true
42
+ },
43
+ "allgather_partitions": true,
44
+ "allgather_bucket_size": 5e8,
45
+ "overlap_comm": true,
46
+ "reduce_scatter": true,
47
+ "reduce_bucket_size": 5e8,
48
+ "contiguous_gradients": true,
49
+ "round_robin_gradients": true,
50
+ "sub_group_size": 1e12
51
+ },
52
+
53
+ "gradient_accumulation_steps": "auto",
54
+ "gradient_clipping": "auto",
55
+ "steps_per_print": 100,
56
+ "train_batch_size": "auto",
57
+ "train_micro_batch_size_per_gpu": "auto",
58
+ "wall_clock_breakdown": false,
59
+ "dump_state": false
60
+
61
+ }
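
Editor's note (not part of the commit): every "auto" entry in this ZeRO-2 config is filled in by the Hugging Face Trainer from its own TrainingArguments when the JSON path is passed via --deepspeed, which is how the finetune scripts later in this commit consume it. A minimal sketch with illustrative values:

```python
# Minimal sketch (illustrative values only): how the "auto" placeholders in
# ds_config_zero2.json get resolved. Passing the JSON through
# TrainingArguments(deepspeed=...) lets the HF Trainer substitute its own
# learning rate, precision flags, and batch sizes before the DeepSpeed engine
# is built.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    deepspeed="scripts/deepspeed/ds_config_zero2.json",  # this file
    per_device_train_batch_size=1,   # fills train_micro_batch_size_per_gpu
    gradient_accumulation_steps=16,  # fills gradient_accumulation_steps
    learning_rate=1e-3,              # fills optimizer.params.lr
    weight_decay=0.0,                # fills optimizer.params.weight_decay
    bf16=True,                       # fills bf16.enabled
)
# Trainer(model=..., args=args, train_dataset=...).train() then launches
# DeepSpeed with the merged configuration; see the finetune_*.sh scripts
# below for the actual command lines used in this repo.
```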
scripts/deepspeed/ds_config_zero2_no_optimizer.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+
14
+ "scheduler": {
15
+ "type": "WarmupCosineLR",
16
+ "params": {
17
+ "total_num_steps": "auto",
18
+ "warmup_min_ratio": 0,
19
+ "warmup_num_steps": "auto",
20
+ "cos_min_ratio": 0.1
21
+ }
22
+ },
23
+
24
+ "zero_optimization": {
25
+ "stage": 2,
26
+ "offload_optimizer": {
27
+ "device": "none",
28
+ "pin_memory": true
29
+ },
30
+ "offload_param": {
31
+ "device": "none",
32
+ "pin_memory": true
33
+ },
34
+ "allgather_partitions": true,
35
+ "allgather_bucket_size": 5e8,
36
+ "overlap_comm": true,
37
+ "reduce_scatter": true,
38
+ "reduce_bucket_size": 5e8,
39
+ "contiguous_gradients": true,
40
+ "round_robin_gradients": true,
41
+ "sub_group_size": 1e12
42
+ },
43
+
44
+ "gradient_accumulation_steps": "auto",
45
+ "gradient_clipping": "auto",
46
+ "steps_per_print": 100,
47
+ "train_batch_size": "auto",
48
+ "train_micro_batch_size_per_gpu": "auto",
49
+ "wall_clock_breakdown": false,
50
+ "dump_state": false
51
+
52
+ }
scripts/deepspeed/ds_config_zero2_offload.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+
23
+ "scheduler": {
24
+ "type": "WarmupCosineLR",
25
+ "params": {
26
+ "total_num_steps": "auto",
27
+ "warmup_min_ratio": 0,
28
+ "warmup_num_steps": "auto",
29
+ "cos_min_ratio": 0.1
30
+ }
31
+ },
32
+
33
+ "zero_optimization": {
34
+ "stage": 2,
35
+ "offload_optimizer": {
36
+ "device": "cpu",
37
+ "pin_memory": true
38
+ },
39
+ "offload_param": {
40
+ "device": "cpu",
41
+ "pin_memory": true
42
+ },
43
+ "allgather_partitions": true,
44
+ "allgather_bucket_size": 5e8,
45
+ "overlap_comm": true,
46
+ "reduce_scatter": true,
47
+ "reduce_bucket_size": 5e8,
48
+ "contiguous_gradients": true,
49
+ "round_robin_gradients": true,
50
+ "sub_group_size": 1e12
51
+ },
52
+
53
+ "gradient_accumulation_steps": "auto",
54
+ "gradient_clipping": "auto",
55
+ "steps_per_print": 100,
56
+ "train_batch_size": "auto",
57
+ "train_micro_batch_size_per_gpu": "auto",
58
+ "wall_clock_breakdown": false,
59
+ "dump_state": false
60
+
61
+ }
scripts/deepspeed/ds_config_zero3.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+
23
+ "scheduler": {
24
+ "type": "WarmupCosineLR",
25
+ "params": {
26
+ "total_num_steps": "auto",
27
+ "warmup_min_ratio": 0,
28
+ "warmup_num_steps": "auto",
29
+ "cos_min_ratio": 0.1
30
+ }
31
+ },
32
+
33
+ "zero_optimization": {
34
+ "stage": 3,
35
+ "offload_optimizer": {
36
+ "device": "none",
37
+ "pin_memory": true
38
+ },
39
+ "offload_param": {
40
+ "device": "none",
41
+ "pin_memory": true
42
+ },
43
+ "allgather_partitions": true,
44
+ "allgather_bucket_size": 2e8,
45
+ "overlap_comm": true,
46
+ "reduce_scatter": true,
47
+ "reduce_bucket_size": 1e9,
48
+ "contiguous_gradients": true,
49
+ "sub_group_size": 1e9,
50
+ "stage3_prefetch_bucket_size": 1e9,
51
+ "stage3_param_persistence_threshold": 1e9,
52
+ "stage3_max_live_parameters": 1e9,
53
+ "stage3_max_reuse_distance": 1e9,
54
+ "stage3_gather_16bit_weights_on_model_save": true
55
+ },
56
+
57
+ "gradient_accumulation_steps": "auto",
58
+ "gradient_clipping": "auto",
59
+ "steps_per_print": 100,
60
+ "train_batch_size": "auto",
61
+ "train_micro_batch_size_per_gpu": "auto",
62
+ "wall_clock_breakdown": false
63
+ }
scripts/deepspeed/ds_config_zero3_offload.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+
23
+ "scheduler": {
24
+ "type": "WarmupCosineLR",
25
+ "params": {
26
+ "total_num_steps": "auto",
27
+ "warmup_min_ratio": 0,
28
+ "warmup_num_steps": "auto",
29
+ "cos_min_ratio": 0.1
30
+ }
31
+ },
32
+
33
+ "zero_optimization": {
34
+ "stage": 3,
35
+ "offload_optimizer": {
36
+ "device": "cpu",
37
+ "pin_memory": true
38
+ },
39
+ "offload_param": {
40
+ "device": "cpu",
41
+ "pin_memory": true
42
+ },
43
+ "allgather_partitions": true,
44
+ "allgather_bucket_size": 5e8,
45
+ "overlap_comm": true,
46
+ "reduce_scatter": true,
47
+ "reduce_bucket_size": 5e8,
48
+ "contiguous_gradients": true,
49
+ "round_robin_gradients": true,
50
+ "sub_group_size": 1e12,
51
+ "stage3_prefetch_bucket_size": 5e8,
52
+ "stage3_param_persistence_threshold": 1e5,
53
+ "stage3_max_live_parameters": 1e9,
54
+ "stage3_max_reuse_distance": 1e9,
55
+ "stage3_gather_16bit_weights_on_model_save": true
56
+ },
57
+
58
+ "flops_profiler": {
59
+ "enabled": false,
60
+ "profile_step": 1,
61
+ "module_depth": -1,
62
+ "top_modules": 1,
63
+ "detailed": true,
64
+ "output_file": null
65
+ },
66
+
67
+ "gradient_accumulation_steps": "auto",
68
+ "gradient_clipping": "auto",
69
+ "steps_per_print": 100,
70
+ "train_batch_size": "auto",
71
+ "train_micro_batch_size_per_gpu": "auto",
72
+ "wall_clock_breakdown": false,
73
+ "dump_state": false
74
+
75
+ }
scripts/deepspeed/evaluate_sts.sh ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H%M%S'`
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt install -y rsync
28
+ mkdir -p ${LOCAL_CODE_PATH}
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME/"
48
+ mkdir -p ${HF_HOME}
49
+ export HF_ENDPOINT=https://hf-mirror.com
50
+
51
+ export MODELSCOPE_CACHE="${ROOT_PATH}/data/MODELSCOPE_CACHE/"
52
+ mkdir -p ${MODELSCOPE_CACHE}
53
+
54
+ export LC_ALL="en_US.utf8"
55
+
56
+ ######################################################################
57
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
58
+ exec &> >(tee -a "$LOG")
59
+ echo Logging output to "$LOG"
60
+
61
+
62
+ ######################################################################
63
+ if true
64
+ #if false
65
+ then
66
+ MODEL_NAME_OR_PATH="/data/output/LM/scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh/VITA-Audio-Boost/"
67
+ MODEL_NAME_OR_PATH="/data/output/LM/scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh/VITA-Audio-Balance/"
68
+
69
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
70
+ FLOW_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-decoder
71
+ AUDIO_TOKENIZER_TYPE="glm4voice"
72
+
73
+ export PYTHONPATH=${PYTHONPATH}:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/cosyvoice/:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
74
+
75
+ fi
76
+
77
+ ######################################################################
78
+ DISTRIBUTED_ARGS="
79
+ --nproc_per_node $NPROC_PER_NODE \
80
+ --nnodes $NNODES \
81
+ --node_rank $NODE_RANK \
82
+ --master_addr $MASTER_ADDR \
83
+ --master_port $MASTER_PORT
84
+ "
85
+
86
+ ######################################################################
87
+ if true
88
+ #if false
89
+ then
90
+ apt-get update && apt install -y ffmpeg
91
+
92
+ JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/llama-questions/test.jsonl
93
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_sqa.py \
94
+ --json_path ${JSON_PATH} \
95
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
96
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
97
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
98
+ --flow_path ${FLOW_PATH} \
99
+ --output_dir ${OUTPUT_DIR}/llama-questions/
100
+
101
+ python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/llama-questions/test_hyp_ref_text.json
102
+ echo "copypaste ACC: ${JSON_PATH}"
103
+ python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/llama-questions/test_hyp_ref_speech.json
104
+ echo "copypaste ACC: ${JSON_PATH}"
105
+
106
+
107
+ JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/trivia_qa-audio/validation.jsonl
108
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_sqa.py \
109
+ --json_path ${JSON_PATH} \
110
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
111
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
112
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
113
+ --flow_path ${FLOW_PATH} \
114
+ --output_dir ${OUTPUT_DIR}/trivia_qa-audio/
115
+
116
+ python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/trivia_qa-audio/validation_hyp_ref_text.json
117
+ echo "copypaste ACC: ${JSON_PATH}"
118
+ python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/trivia_qa-audio/validation_hyp_ref_speech.json
119
+ echo "copypaste ACC: ${JSON_PATH}"
120
+
121
+
122
+ JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/spoken-web-questions/test.jsonl
123
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_sqa.py \
124
+ --json_path ${JSON_PATH} \
125
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
126
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
127
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
128
+ --flow_path ${FLOW_PATH} \
129
+ --output_dir ${OUTPUT_DIR}/spoken-web-questions/
130
+
131
+ python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/spoken-web-questions/test_hyp_ref_text.json
132
+ echo "copypaste ACC: ${JSON_PATH}"
133
+ python evaluation/compute-acc-of-contain.py ${OUTPUT_DIR}/spoken-web-questions/test_hyp_ref_speech.json
134
+ echo "copypaste ACC: ${JSON_PATH}"
135
+
136
+ fi
137
+
138
+
139
+ ######################################################################
140
+ if true
141
+ #if false
142
+ then
143
+ JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/validation.clean.jsonl
144
+
145
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
146
+ --json_path ${JSON_PATH} \
147
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
148
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
149
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
150
+ --flow_path ${FLOW_PATH} \
151
+ --output_dir ${OUTPUT_DIR}/librispeech_asr/
152
+
153
+ #python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.clean_hyp.txt
154
+ #echo "copypaste CER: ${JSON_PATH}"
155
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.clean_hyp.txt
156
+ echo "copypaste WER: ${JSON_PATH}"
157
+
158
+ JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/validation.other.jsonl
159
+
160
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
161
+ --json_path ${JSON_PATH} \
162
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
163
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
164
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
165
+ --flow_path ${FLOW_PATH} \
166
+ --output_dir ${OUTPUT_DIR}/librispeech_asr/
167
+
168
+ #python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.other_hyp.txt
169
+ #echo "copypaste CER: ${JSON_PATH}"
170
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/validation.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/validation.other_hyp.txt
171
+ echo "copypaste WER: ${JSON_PATH}"
172
+
173
+
174
+ JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/test.clean.jsonl
175
+
176
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
177
+ --json_path ${JSON_PATH} \
178
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
179
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
180
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
181
+ --flow_path ${FLOW_PATH} \
182
+ --output_dir ${OUTPUT_DIR}/librispeech_asr/
183
+
184
+ #python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.clean_hyp.txt
185
+ #echo "copypaste CER: ${JSON_PATH}"
186
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.clean_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.clean_hyp.txt
187
+ echo "copypaste WER: ${JSON_PATH}"
188
+
189
+ JSON_PATH=${ROOT_PATH}/data/jsonl/fixie-ai/librispeech_asr/test.other.jsonl
190
+
191
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
192
+ --json_path ${JSON_PATH} \
193
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
194
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
195
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
196
+ --flow_path ${FLOW_PATH} \
197
+ --output_dir ${OUTPUT_DIR}/librispeech_asr/
198
+
199
+ #python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.other_hyp.txt
200
+ #echo "copypaste CER: ${JSON_PATH}"
201
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/librispeech_asr/test.other_ref.txt ${OUTPUT_DIR}/librispeech_asr/test.other_hyp.txt
202
+ echo "copypaste WER: ${JSON_PATH}"
203
+
204
+ fi
205
+
206
+
207
+ ######################################################################
208
+ if true
209
+ #if false
210
+ then
211
+ JSON_PATH=${ROOT_PATH}/data/jsonl/wenet-e2e/wenetspeech/TEST_MEETING.jsonl
212
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
213
+ --json_path ${JSON_PATH} \
214
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
215
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
216
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
217
+ --flow_path ${FLOW_PATH} \
218
+ --output_dir ${OUTPUT_DIR}/wenetspeech/
219
+
220
+ python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_hyp.txt
221
+ echo "copypaste CER: ${JSON_PATH}"
222
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_MEETING_hyp.txt
223
+ echo "copypaste WER: ${JSON_PATH}"
224
+
225
+ JSON_PATH=${ROOT_PATH}/data/jsonl/wenet-e2e/wenetspeech/TEST_NET.jsonl
226
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
227
+ --json_path ${JSON_PATH} \
228
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
229
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
230
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
231
+ --flow_path ${FLOW_PATH} \
232
+ --output_dir ${OUTPUT_DIR}/wenetspeech/
233
+
234
+ python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_NET_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_NET_hyp.txt
235
+ echo "copypaste CER: ${JSON_PATH}"
236
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/wenetspeech/TEST_NET_ref.txt ${OUTPUT_DIR}/wenetspeech/TEST_NET_hyp.txt
237
+ echo "copypaste WER: ${JSON_PATH}"
238
+ fi
239
+
240
+
241
+ ######################################################################
242
+ if true
243
+ #if false
244
+ then
245
+ JSON_PATH=${ROOT_PATH}/data/jsonl/shenyunhang/AISHELL-1/test.jsonl
246
+
247
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_asr.py \
248
+ --json_path ${JSON_PATH} \
249
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
250
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
251
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
252
+ --flow_path ${FLOW_PATH} \
253
+ --output_dir ${OUTPUT_DIR}/AISHELL-1/
254
+
255
+ #python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/AISHELL-1/_test.clean_ref.txt ${OUTPUT_DIR}/AISHELL-1/test.clean_hyp.txt
256
+ #echo "copypaste CER: ${JSON_PATH}"
257
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/AISHELL-1/test_ref.txt ${OUTPUT_DIR}/AISHELL-1/test_hyp.txt
258
+ echo "copypaste WER: ${JSON_PATH}"
259
+
260
+
261
+ fi
262
+
263
+
264
+ ######################################################################
265
+ if true
266
+ #if false
267
+ then
268
+ JSON_PATH=${ROOT_PATH}/data/jsonl/mythicinfinity/libritts/test.clean.jsonl
269
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_libritts.py \
270
+ --json_path ${JSON_PATH} \
271
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
272
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
273
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
274
+ --flow_path ${FLOW_PATH} \
275
+ --output_dir ${OUTPUT_DIR}/libritts/ \
276
+
277
+ #python evaluation/compute-cer.py --char=1 --v=1 ${OUTPUT_DIR}/libritts/test.clean_ref.txt ${OUTPUT_DIR}/libritts/test.clean_hyp.txt
278
+ #echo "copypaste CER: ${JSON_PATH}"
279
+ python evaluation/compute-wer.py --char=1 --v=1 ${OUTPUT_DIR}/libritts/test.clean_ref.txt ${OUTPUT_DIR}/libritts/test.clean_hyp.txt
280
+ echo "copypaste WER: ${JSON_PATH}"
281
+ fi
282
+
283
+
284
+ ######################################################################
285
+ if true
286
+ #if false
287
+ then
288
+
289
+ DATA_PATH=${ROOT_PATH}/data/BytedanceSpeech/seed-tts-eval/
290
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_seedtts.py \
291
+ --data_path ${DATA_PATH} \
292
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
293
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
294
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
295
+ --flow_path ${FLOW_PATH} \
296
+ --output_dir ${OUTPUT_DIR}/seed-tts/ \
297
+ --speaker_prompt \
298
+
299
+ export ARNOLD_WORKER_GPU=${NPROC_PER_NODE}
300
+ cd ${LOCAL_CODE_PATH}/third_party/seed-tts-eval
301
+
302
+ bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/meta.lst ${OUTPUT_DIR}/seed-tts/zh/ zh
303
+ echo "copypaste WER: ${DATA_PATH} zh"
304
+ bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/hardcase.lst ${OUTPUT_DIR}/seed-tts/hardcase/ zh
305
+ echo "copypaste WER: ${DATA_PATH} hardcase"
306
+ bash cal_wer.sh ${DATA_PATH}/seedtts_testset/en/meta.lst ${OUTPUT_DIR}/seed-tts/en/ en
307
+ echo "copypaste WER: ${DATA_PATH} en"
308
+
309
+ bash cal_sim.sh ${DATA_PATH}/seedtts_testset/zh/meta.lst ${OUTPUT_DIR}/seed-tts/zh/ ${DATA_PATH}/wavlm_large_finetune.pth
310
+ echo "copypaste SIM: ${DATA_PATH} zh"
311
+ bash cal_sim.sh ${DATA_PATH}/seedtts_testset/zh/hardcase.lst ${OUTPUT_DIR}/seed-tts/hardcase/ ${DATA_PATH}/wavlm_large_finetune.pth
312
+ echo "copypaste SIM: ${DATA_PATH} hardcase"
313
+ bash cal_sim.sh ${DATA_PATH}/seedtts_testset/en/meta.lst ${OUTPUT_DIR}/seed-tts/en/ ${DATA_PATH}/wavlm_large_finetune.pth
314
+ echo "copypaste SIM: ${DATA_PATH} en"
315
+
316
+ cd ${LOCAL_CODE_PATH}
317
+
318
+ fi
319
+
320
+
321
+ ######################################################################
322
+ if false
323
+ then
324
+ DATA_PATH=${ROOT_PATH}/data/BytedanceSpeech/seed-tts-eval/
325
+ torchrun ${DISTRIBUTED_ARGS} evaluation/evaluate_seedtts.py \
326
+ --data_path ${DATA_PATH} \
327
+ --model_name_or_path ${MODEL_NAME_OR_PATH} \
328
+ --audio_tokenizer_path ${AUDIO_TOKENIZER_PATH} \
329
+ --audio_tokenizer_type ${AUDIO_TOKENIZER_TYPE} \
330
+ --flow_path ${FLOW_PATH} \
331
+ --output_dir ${OUTPUT_DIR}/seed-tts/ \
332
+
333
+ export ARNOLD_WORKER_GPU=${NPROC_PER_NODE}
334
+ cd ${LOCAL_CODE_PATH}/third_party/seed-tts-eval
335
+
336
+ bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/meta.lst ${OUTPUT_DIR}/seed-tts/zh/ zh
337
+ echo "copypaste WER: ${DATA_PATH} zh"
338
+ bash cal_wer.sh ${DATA_PATH}/seedtts_testset/zh/hardcase.lst ${OUTPUT_DIR}/seed-tts/hardcase/ zh
339
+ echo "copypaste WER: ${DATA_PATH} hardcase"
340
+ bash cal_wer.sh ${DATA_PATH}/seedtts_testset/en/meta.lst ${OUTPUT_DIR}/seed-tts/en/ en
341
+ echo "copypaste WER: ${DATA_PATH} en"
342
+
343
+ cd ${LOCAL_CODE_PATH}
344
+
345
+ fi
346
+
347
+
348
+ set +x
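
Editor's note (not part of the commit): a quick way to cross-check the WER figures that compute-wer.py prints in the script above, assuming the *_ref.txt / *_hyp.txt files written by evaluate_asr.py use one "utt_id transcript" pair per line (adjust the parsing otherwise). jiwer is already pinned in requirements_ds_gpu.txt; the paths below are illustrative.

```python
# Minimal sketch: recompute WER over the ref/hyp text files produced above.
# Assumption: each line is "utt_id<space>transcript"; change load() if the
# actual layout differs.
import jiwer


def load(path):
    texts = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            utt_id, _, text = line.strip().partition(" ")
            if utt_id:
                texts[utt_id] = text
    return texts


refs = load("output/librispeech_asr/test.clean_ref.txt")
hyps = load("output/librispeech_asr/test.clean_hyp.txt")
keys = sorted(set(refs) & set(hyps))
print("utterances:", len(keys))
print("WER:", jiwer.wer([refs[k] for k in keys], [hyps[k] for k in keys]))
```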
scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage1.sh ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh/20250313_040353/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name vita_audio/models/qwen2_mtp_v4_48_3/config_7B_mtp10.json \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 8000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 1.00e-3 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.0 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 8 \
125
+ --language-model-freeze \
126
+ --text-audio-interval-ratio 1 10 4 10 \
127
+
128
+ #--language-model-freeze \
129
+ #--dataset_joint false \
130
+ #--variable_length true \
131
+ #--tokenizer_name_or_path Qwen2Tokenizer \
132
+
133
+ #--bf16 True \
134
+ #--fp16 True \
135
+ #--tf32 True \
136
+
137
+ set +x
scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp10_stage2.sh ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage2.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/s2s_qwen25/finetune_glm4voice_mtp10_stage1.sh/20250315_022047/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name ${MODEL_NAME_OR_PATH} \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 4000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 5.00e-5 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.1 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2_no_optimizer.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 8 \
125
+ --mtp_model_lr_mult 1.00e1 \
126
+ --text-audio-interval-ratio 1 10 4 10 \
127
+
128
+ #--language-model-freeze \
129
+ #--dataset_joint false \
130
+ #--variable_length true \
131
+ #--tokenizer_name_or_path Qwen2Tokenizer \
132
+
133
+ #--bf16 True \
134
+ #--fp16 True \
135
+ #--tf32 True \
136
+
137
+ set +x
scripts/deepspeed/sts_qwen25/finetune_glm4voice_mtp1_stage1.sh ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/s2s_qwen25/finetune_glm4voice_stage1.sh/20250222_043913/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name vita_audio/models/qwen2_mtp_v4_48_3/config_7B_mtp1.json \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 4000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 1.00e-3 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.0 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 8 \
125
+ --language-model-freeze \
126
+ --text-audio-interval-ratio 1 10 4 10 \
127
+
128
+ #--language-model-freeze \
129
+ #--dataset_joint false \
130
+ #--variable_length true \
131
+ #--tokenizer_name_or_path Qwen2Tokenizer \
132
+
133
+ #--bf16 True \
134
+ #--fp16 True \
135
+ #--tf32 True \
136
+
137
+ set +x
scripts/deepspeed/sts_qwen25/finetune_glm4voice_stage1.sh ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/models/Qwen/Qwen2.5-7B-Instruct/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name ${MODEL_NAME_OR_PATH} \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 8000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 6.00e-5 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.0 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 8 \
125
+ --text-audio-interval-ratio 1 10 4 10 \
126
+
127
+ #--language-model-freeze \
128
+ #--dataset_joint false \
129
+ #--variable_length true \
130
+ #--tokenizer_name_or_path Qwen2Tokenizer \
131
+
132
+ #--bf16 True \
133
+ #--fp16 True \
134
+ #--tf32 True \
135
+
136
+ set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage1.sh ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp1_stage1.sh/20250418_075843/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp10.json \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "sensevoice_glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 8000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 1.00e-3 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.0 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 8 \
125
+ --audio-model-freeze \
126
+ --language-model-freeze \
127
+ --text-audio-interval-ratio 1 10 4 10 \
128
+
129
+ #--language-model-freeze \
130
+ #--dataset_joint false \
131
+ #--variable_length true \
132
+ #--tokenizer_name_or_path Qwen2Tokenizer \
133
+
134
+ #--bf16 True \
135
+ #--fp16 True \
136
+ #--tf32 True \
137
+
138
+ set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage2.sh ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage2.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp10_stage1.sh/20250421_180624/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp10.json \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "sensevoice_glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 4000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 5.00e-5 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.1 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2_no_optimizer.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 2 \
125
+ --mtp_model_lr_mult 1.00e1 \
126
+ --audio-model-freeze \
127
+ --text-audio-interval-ratio 1 10 4 10 \
128
+
129
+ #--language-model-freeze \
130
+ #--dataset_joint false \
131
+ #--variable_length true \
132
+ #--tokenizer_name_or_path Qwen2Tokenizer \
133
+
134
+ #--bf16 True \
135
+ #--fp16 True \
136
+ #--tf32 True \
137
+
138
+ set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_mtp1_stage1.sh ADDED
@@ -0,0 +1,138 @@
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/output/LM/scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_stage1.sh/20250409_161438/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp1.json \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "sensevoice_glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 4000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 1.00e-3 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.0 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 8 \
125
+ --audio-model-freeze \
126
+ --language-model-freeze \
127
+ --text-audio-interval-ratio 1 10 4 10 \
128
+
129
+ #--language-model-freeze \
130
+ #--dataset_joint false \
131
+ #--variable_length true \
132
+ #--tokenizer_name_or_path Qwen2Tokenizer \
133
+
134
+ #--bf16 True \
135
+ #--fp16 True \
136
+ #--tf32 True \
137
+
138
+ set +x
scripts/deepspeed/sts_qwen25/finetune_sensevoice_glm4voice_stage1.sh ADDED
@@ -0,0 +1,137 @@
1
+ #!/bin/bash
2
+
3
+ set -e
4
+ set -x
5
+
6
+ SEQ_LENGTH="$1"
7
+ if [ -z "$SEQ_LENGTH" ]
8
+ then
9
+ SEQ_LENGTH=32768
10
+ fi
11
+
12
+ timestamp="$2"
13
+ if [ -z "$timestamp" ]
14
+ then
15
+ timestamp=`date +'%Y%m%d_%H'`0000
16
+ fi
17
+
18
+ ######################################################################
19
+ export ROOT_PATH=/data/
20
+ export CODE_PATH=${ROOT_PATH}/VITA-Audio/
21
+
22
+ export LOCAL_ROOT_PATH=/data_local/
23
+ export LOCAL_CODE_PATH=${LOCAL_ROOT_PATH}/VITA-Audio/
24
+ mkdir -p ${LOCAL_ROOT_PATH}
25
+ mkdir -p ${LOCAL_CODE_PATH}
26
+
27
+ apt update
28
+ apt install -y rsync
29
+ rsync -a --exclude ".git" --exclude ".gitee" ${CODE_PATH}/ ${LOCAL_CODE_PATH}/
30
+
31
+ cd ${LOCAL_CODE_PATH}
32
+ rm -fr datasets
33
+ ln -s ${ROOT_PATH}/data datasets
34
+
35
+ ######################################################################
36
+ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
37
+ source ${LOCAL_CODE_PATH}/scripts/set_env_ds_gpu.sh
38
+ pip3 install transformers==4.48.3
39
+ #pip3 install --no-index --find-links=/data/software/ transformers==4.48.3
40
+
41
+ ######################################################################
42
+ OUTPUT_DIR=${ROOT_PATH}/output/LM/"$0"/${timestamp}/
43
+
44
+ mkdir -p ${OUTPUT_DIR}
45
+ rsync -avh $0 ${OUTPUT_DIR}
46
+
47
+ export HF_HOME="${ROOT_PATH}/data/HF_HOME_node${INDEX}/"
48
+ mkdir -p ${HF_HOME}
49
+
50
+ export TRITON_CACHE_DIR=${LOCAL_CODE_PATH}
51
+
52
+ export PYTHONPATH=$PYTHONPATH:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice:${LOCAL_CODE_PATH}/third_party/GLM-4-Voice/third_party/Matcha-TTS/
53
+
54
+ ######################################################################
55
+ LOG=${OUTPUT_DIR}/log_node${INDEX}.txt
56
+ exec &> >(tee -a "$LOG")
57
+ echo Logging output to "$LOG"
58
+
59
+ echo ${@}
60
+
61
+ ######################################################################
62
+ DATA_PATH=${LOCAL_CODE_PATH}/configs/sts_finetune_stage1.yaml
63
+
64
+ MODEL_NAME_OR_PATH=${ROOT_PATH}/models/Qwen/Qwen2.5-7B-Instruct/
65
+
66
+ AUDIO_TOKENIZER_PATH=${ROOT_PATH}/models/THUDM/glm-4-voice-tokenizer
67
+
68
+ rsync -avh ${DATA_PATH} ${OUTPUT_DIR}
69
+
70
+ ######################################################################
71
+ DISTRIBUTED_ARGS="
72
+ --nproc_per_node $NPROC_PER_NODE \
73
+ --nnodes $NNODES \
74
+ --node_rank $NODE_RANK \
75
+ --master_addr $MASTER_ADDR \
76
+ --master_port $MASTER_PORT
77
+ "
78
+
79
+ torchrun $DISTRIBUTED_ARGS tools/finetune_sts_v4_48_3.py \
80
+ --log_level "info" \
81
+ --do_train \
82
+ --overwrite_output_dir \
83
+ --config_name ${LOCAL_CODE_PATH}/VITA-Audio/models/qwen2_mtp_sensevoice_v4_48_3/config_7B_mtp0.json \
84
+ --tokenizer_name $MODEL_NAME_OR_PATH \
85
+ --model_name_or_path $MODEL_NAME_OR_PATH \
86
+ --audio_tokenizer_path $AUDIO_TOKENIZER_PATH \
87
+ --audio_tokenizer_type "sensevoice_glm4voice" \
88
+ --dataset_name $DATA_PATH \
89
+ --bf16 True \
90
+ --tf32 True \
91
+ --torch_dtype bfloat16 \
92
+ --output_dir $OUTPUT_DIR \
93
+ --num_train_epochs 1 \
94
+ --max_steps 8000 \
95
+ --per_device_train_batch_size 1 \
96
+ --per_device_eval_batch_size 1 \
97
+ --gradient_accumulation_steps 16 \
98
+ --save_strategy "steps" \
99
+ --save_steps 0.1 \
100
+ --save_total_limit 2 \
101
+ --learning_rate 6.00e-5 \
102
+ --max_grad_norm 1.0 \
103
+ --weight_decay 0.0 \
104
+ --adam_beta1 0.9 \
105
+ --adam_beta2 0.95 \
106
+ --adam_epsilon 1e-8 \
107
+ --warmup_ratio 0.03 \
108
+ --lr_scheduler_type "cosine" \
109
+ --logging_steps 1 \
110
+ --report_to "tensorboard" \
111
+ --model_max_length ${SEQ_LENGTH} \
112
+ --gradient_checkpointing True \
113
+ --deepspeed ${LOCAL_CODE_PATH}/scripts/deepspeed/ds_config_zero2.json \
114
+ --trust_remote_code False \
115
+ --ddp_timeout 7200 \
116
+ --ddp_backend ${DISTRIBUTED_BACKEND} \
117
+ --attn_implementation flash_attention_2 \
118
+ --seed 42 \
119
+ --data_seed 42 \
120
+ --reset_attention_mask \
121
+ --reset_position_ids \
122
+ --create_attention_mask false \
123
+ --create_attention_mask_2d false \
124
+ --dataloader_num_workers 8 \
125
+ --audio-model-freeze \
126
+ --text-audio-interval-ratio 1 10 4 10 \
127
+
128
+ #--language-model-freeze \
129
+ #--dataset_joint false \
130
+ #--variable_length true \
131
+ #--tokenizer_name_or_path Qwen2Tokenizer \
132
+
133
+ #--bf16 True \
134
+ #--fp16 True \
135
+ #--tf32 True \
136
+
137
+ set +x
scripts/set_env_ds_gpu.sh ADDED
@@ -0,0 +1,53 @@
1
+ #set -e
2
+ #set -x
3
+
4
+ ######################################################################
5
+ #export NCCL_NET=IB
6
+
7
+ #export NCCL_SOCKET_IFNAME="bond1"
8
+ #export GLOO_SOCKET_IFNAME="bond1"
9
+ #export NCCL_DEBUG=INFO
10
+ #export NCCL_IB_QPS_PER_CONNECTION=2
11
+
12
+ #export GLOO_SOCKET_IFNAME=eth0
13
+ #export NCCL_DEBUG=INFO
14
+ #export NCCL_IB_QPS_PER_CONNECTION=2
15
+
16
+ #export NCCL_IB_DISABLE=1
17
+
18
+ export DISTRIBUTED_BACKEND="nccl"
19
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
20
+
21
+ ######################################################################
22
+ pip3 install -r requirements_ds_gpu.txt
23
+ #pip3 install --no-index --find-links=/data/software/ -r requirements_ds_gpu.txt
24
+
25
+ pip3 install deepspeed==0.15.4
26
+ #pip3 install --no-index --find-links=/data/software/ deepspeed==0.15.4
27
+ #pip3 install deepspeed==0.16.1
28
+ #pip3 install deepspeed==0.14.2
29
+
30
+ pip3 install -e `pwd`
31
+
32
+ ######################################################################
33
+ #export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
34
+
35
+ #apt update
36
+ #apt install -y openssh-server
37
+ #apt install -y rsync
38
+
39
+ ######################################################################
40
+
41
+ export NNODES=${WORLD_SIZE}
42
+ export NODE_RANK=${RANK}
43
+ export MASTER_PORT=34567
44
+
45
+ if [ -z "$NPROC_PER_NODE" ]
46
+ then
47
+ export NPROC_PER_NODE=8
48
+ export NNODES=1
49
+ export NODE_RANK=0
50
+ export MASTER_ADDR=127.0.0.1
51
+ fi
52
+
53
+ ######################################################################
setup.py ADDED
@@ -0,0 +1,12 @@
1
+ from setuptools import find_packages, setup
2
+
3
+ setup(
4
+ name='vita_audio',
5
+ version='0.0.1',
6
+ packages=[
7
+ "vita_audio",
8
+ ],
9
+ install_requires=[
10
+ ],
11
+ )
12
+
third_party/GLM-4-Voice/.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ *venv
2
+ *.DS_Store
3
+ *.idea/
4
+ test*
third_party/GLM-4-Voice/.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "third_party/Matcha-TTS"]
2
+ path = third_party/Matcha-TTS
3
+ url = https://github.com/shivammehta25/Matcha-TTS
third_party/GLM-4-Voice/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2024 GLM-4-Voice Model Team @ Zhipu AI
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
third_party/GLM-4-Voice/README.md ADDED
@@ -0,0 +1,159 @@
1
+ # GLM-4-Voice
2
+ <p align="center">
3
+ 📄<a href="https://arxiv.org/abs/2412.02612" target="_blank"> Report </a> • 🤗 <a href="https://huggingface.co/THUDM/glm-4-voice-9b" target="_blank">HF Repo</a> • 🤖 <a href="https://modelscope.cn/studios/ZhipuAI/GLM-4-Voice-Demo" target="_blank">Demo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a>
4
+ </p>
5
+
6
+ Read this in [English](./README_en.md)
7
+
8
+ GLM-4-Voice 是智谱 AI 推出的端到端语音模型。GLM-4-Voice 能够直接理解和生成中英文语音,进行实时语音对话,并且能够遵循用户的指令要求改变语音的情感、语调、语速、方言等属性。
9
+
10
+ ## Model Architecture
11
+ ![Model Architecture](./resources/architecture.jpeg)
12
+
13
+ GLM-4-Voice 由三个部分组成:
14
+ * GLM-4-Voice-Tokenizer: 通过在 [Whisper](https://github.com/openai/whisper) 的 Encoder 部分增加 Vector Quantization 并在 ASR 数据上有监督训练,将连续的语音输入转化为离散的 token。每秒音频平均只需要用 12.5 个离散 token 表示。
15
+ * GLM-4-Voice-Decoder: 基于 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) 的 Flow Matching 模型结构训练的支持流式推理的语音解码器,将离散化的语音 token 转化为连续的语音输出。最少只需要 10 个语音 token 即可开始生成,降低端到端对话延迟。
16
+ * GLM-4-Voice-9B: 在 [GLM-4-9B](https://github.com/THUDM/GLM-4) 的基础上进行语音模态的预训练和对齐,从而能够理解和生成离散化的语音 token。
17
+
18
+ 预训练方面,为了攻克模型在语音模态下的智商和合成表现力两个难关,我们将 Speech2Speech 任务解耦合为“根据用户音频做出文本回复”和“根据文本回复和用户语音合成回复语音”两个任务,并设计两种预训练目标,分别基于文本预训练数据和无监督音频数据合成语音-文本交错数据以适配这两种任务形式。GLM-4-Voice-9B 在 GLM-4-9B 的基座模型基础之上,经过了数百万小时音频和数千亿 token 的音频文本交错数据预训练,拥有很强的音频理解和建模能力。
19
+
20
+ 对齐方面,为了支持高质量的语音对话,我们设计了一套流式思考架构:根据用户语音,GLM-4-Voice 可以流式交替输出文本和语音两个模态的内容,其中语音模态以文本作为参照保证回复内容的高质量,并根据用户的语音指令要求做出相应的声音变化,在最大程度保留语言模型智商的情况下仍然具有端到端建模的能力,同时具备低延迟性,最低只需要输出 20 个 token 便可以合成语音。
21
+
22
+ ## Model List
23
+
24
+ | Model | Type | Download |
25
+ |:---------------------:|:----------------:|:------------------------------------------------------------------------------------------------------------------------------------------------:|
26
+ | GLM-4-Voice-Tokenizer | Speech Tokenizer | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-tokenizer) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-tokenizer) |
27
+ | GLM-4-Voice-9B | Chat Model | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-9b) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-9b) |
28
+ | GLM-4-Voice-Decoder | Speech Decoder | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-decoder) [🤖 ModelScope](https://modelscope.cn/models/ZhipuAI/glm-4-voice-decoder) |
29
+
30
+ ## Usage
31
+ 我们提供了可以直接启动的 Web Demo。用户可以输入语音或文本,模型会同时给出语音和文字回复。
32
+
33
+ ![](resources/web_demo.png)
34
+
35
+ ### Preparation
36
+
37
+ 首先下载仓库
38
+ ```shell
39
+ git clone --recurse-submodules https://github.com/THUDM/GLM-4-Voice
40
+ cd GLM-4-Voice
41
+ ```
42
+ 然后安装依赖。也可以使用我们提供的镜像 `zhipuai/glm-4-voice:0.1` 以跳过这一步。
43
+ ```shell
44
+ pip install -r requirements.txt
45
+ ```
46
+ 由于 Decoder 模型不支持通过 `transformers` 初始化,因此 checkpoint 需要单独下载。
47
+
48
+ ```shell
49
+ # git 模型下载,请确保已安装 git-lfs
50
+ git lfs install
51
+ git clone https://huggingface.co/THUDM/glm-4-voice-decoder
52
+ ```
53
+
54
+ ### Launch Web Demo
55
+
56
+ 1. 启动模型服务
57
+
58
+ ```shell
59
+ python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype bfloat16 --device cuda:0
60
+ ```
61
+
62
+ 如果你需要使用 Int4 精度启动,请运行
63
+
64
+ ```shell
65
+ python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype int4 --device cuda:0
66
+ ```
67
+
68
+ 此命令会自动下载 `glm-4-voice-9b`。如果网络条件不好,也手动下载之后通过 `--model-path` 指定本地的路径。
69
+
70
+ 2. 启动 web 服务
71
+
72
+ ```shell
73
+ python web_demo.py --tokenizer-path THUDM/glm-4-voice-tokenizer --model-path THUDM/glm-4-voice-9b --flow-path ./glm-4-voice-decoder
74
+ ```
75
+
76
+ 即可在 http://127.0.0.1:8888 访问 web demo。
77
+
78
+ 此命令会自动下载 `glm-4-voice-tokenizer` 和 `glm-4-voice-9b`。 请注意,`glm-4-voice-decoder` 需要手动下载。
79
+
80
+ 如果网络条件不好,可以手动下载这三个模型之后通过 `--tokenizer-path`, `--flow-path` 和 `--model-path` 指定本地的路径。
81
+
82
+ ### Known Issues
83
+
84
+ * Gradio 的流式音频播放效果不稳定。在生成完成后点击对话框中的音频质量会更高。
85
+
86
+ ## Cases
87
+
88
+ 我们提供了 GLM-4-Voice 的部分对话案例,包括控制情绪、改变语速、生成方言等。
89
+
90
+ * 用轻柔的声音引导我放松
91
+
92
+ https://github.com/user-attachments/assets/4e3d9200-076d-4c28-a641-99df3af38eb0
93
+
94
+ * 用激动的声音解说足球比赛
95
+
96
+ https://github.com/user-attachments/assets/0163de2d-e876-4999-b1bc-bbfa364b799b
97
+
98
+ * 用哀怨的声音讲一个鬼故事
99
+
100
+ https://github.com/user-attachments/assets/a75b2087-d7bc-49fa-a0c5-e8c99935b39a
101
+
102
+ * 用东北话介绍一下冬天有多冷
103
+
104
+ https://github.com/user-attachments/assets/91ba54a1-8f5c-4cfe-8e87-16ed1ecf4037
105
+
106
+ * 用重庆话念“吃葡萄不吐葡萄皮”
107
+
108
+ https://github.com/user-attachments/assets/7eb72461-9e84-4d8e-9c58-1809cf6a8a9b
109
+
110
+ * 用北京话念一句绕口令
111
+
112
+ https://github.com/user-attachments/assets/a9bb223e-9c0a-440d-8537-0a7f16e31651
113
+
114
+ * 加快语速
115
+
116
+ https://github.com/user-attachments/assets/c98a4604-366b-4304-917f-3c850a82fe9f
117
+
118
+ * 再快一点
119
+
120
+ https://github.com/user-attachments/assets/d5ff0815-74f8-4738-b0f1-477cfc8dcc2d
121
+
122
+ ## Acknowledgements
123
+
124
+ 本项目的部分代码来自:
125
+ * [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
126
+ * [transformers](https://github.com/huggingface/transformers)
127
+ * [GLM-4](https://github.com/THUDM/GLM-4)
128
+
129
+ ## 协议
130
+
131
+ + GLM-4 模型的权重的使用则需要遵循 [模型协议](https://huggingface.co/THUDM/glm-4-voice-9b/blob/main/LICENSE)。
132
+
133
+ + 本开源仓库的代码则遵循 [Apache 2.0](LICENSE) 协议。
134
+
135
+ ## 引用
136
+
137
+ ```
138
+ @misc{zeng2024glm4,
139
+ title={GLM-4-Voice: Towards Intelligent and Human-Like End-to-End Spoken Chatbot},
140
+ author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Kedong Wang and Shengmin Jiang and Lei Zhao and Yuxiao Dong and Jie Tang},
141
+ year={2024},
142
+ eprint={2412.02612},
143
+ archivePrefix={arXiv},
144
+ primaryClass={cs.CL},
145
+ url={https://arxiv.org/abs/2412.02612},
146
+ }
147
+ ```
148
+
149
+ ```
150
+ @misc{zeng2024scaling,
151
+ title={Scaling Speech-Text Pre-training with Synthetic Interleaved Data},
152
+ author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Lei Zhang and Shengmin Jiang and Yuxiao Dong and Jie Tang},
153
+ year={2024},
154
+ eprint={2411.17607},
155
+ archivePrefix={arXiv},
156
+ primaryClass={cs.CL},
157
+ url={https://arxiv.org/abs/2411.17607},
158
+ }
159
+ ```
third_party/GLM-4-Voice/README_en.md ADDED
@@ -0,0 +1,148 @@
1
+ # GLM-4-Voice
2
+ <p align="center">
3
+ 📄<a href="https://arxiv.org/abs/2412.02612" target="_blank"> Report </a> • 🤗 <a href="https://huggingface.co/THUDM/glm-4-voice-9b" target="_blank">HF Repo</a> • 🤖 <a href="https://modelscope.cn/studios/ZhipuAI/GLM-4-Voice-Demo" target="_blank">Demo</a> • 🐦 <a href="https://twitter.com/thukeg" target="_blank">Twitter</a>
4
+ </p>
5
+
6
+ GLM-4-Voice is an end-to-end voice model launched by Zhipu AI. GLM-4-Voice can directly understand and generate Chinese and English speech, engage in real-time voice conversations, and change attributes such as emotion, intonation, speech rate, and dialect based on user instructions.
7
+
8
+ ## Model Architecture
9
+
10
+ ![Model Architecture](./resources/architecture.jpeg)
11
+ We provide the three components of GLM-4-Voice:
12
+ * GLM-4-Voice-Tokenizer: Trained by adding vector quantization to the encoder part of [Whisper](https://github.com/openai/whisper), converting continuous speech input into discrete tokens. Each second of audio is converted into 12.5 discrete tokens.
13
+ * GLM-4-Voice-9B: Pre-trained and aligned on speech modality based on [GLM-4-9B](https://github.com/THUDM/GLM-4), enabling understanding and generation of discretized speech.
14
+ * GLM-4-Voice-Decoder: A speech decoder supporting streaming inference, retrained based on [CosyVoice](https://github.com/FunAudioLLM/CosyVoice), converting discrete speech tokens into continuous speech output. Generation can start with as few as 10 audio tokens, reducing conversation latency.
15
+
16
+ ## Model List
17
+
18
+ | Model | Type | Download |
19
+ |:---------------------:|:----------------:|:--------------------------------------------------------------------:|
20
+ | GLM-4-Voice-Tokenizer | Speech Tokenizer | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-tokenizer) |
21
+ | GLM-4-Voice-9B | Chat Model | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-9b) |
22
+ | GLM-4-Voice-Decoder | Speech Decoder | [🤗 Huggingface](https://huggingface.co/THUDM/glm-4-voice-decoder) |
23
+
24
+ ## Usage
25
+ We provide a Web Demo that can be launched directly. Users can input speech or text, and the model will respond with both speech and text.
26
+
27
+ ![](resources/web_demo.png)
28
+
29
+ ### Preparation
30
+
31
+ First, download the repository
32
+ ```shell
33
+ git clone --recurse-submodules https://github.com/THUDM/GLM-4-Voice
34
+ cd GLM-4-Voice
35
+ ```
36
+ Then, install the dependencies. You can also use our pre-built docker image `zhipuai/glm-4-voice:0.1` to skip this step.
37
+ ```shell
38
+ pip install -r requirements.txt
39
+ ```
40
+ Since the Decoder model does not support initialization via `transformers`, the checkpoint needs to be downloaded separately.
41
+
42
+ ```shell
43
+ # Git model download, please ensure git-lfs is installed
44
+ git clone https://huggingface.co/THUDM/glm-4-voice-decoder
45
+ ```
46
+
47
+ ### Launch Web Demo
48
+
49
+ 1. Start the model server
50
+
51
+ ```shell
52
+ python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype bfloat16 --device cuda:0
53
+ ```
54
+
55
+ If you need to launch with Int4 precision, run
56
+
57
+ ```shell
58
+ python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype int4 --device cuda:0
59
+ ```
60
+
61
+ This command will automatically download `glm-4-voice-9b`. If network conditions are poor, you can manually download it and specify the local path using `--model-path`.
62
+
63
+ 2. Start the web service
64
+
65
+ ```shell
66
+ python web_demo.py --tokenizer-path THUDM/glm-4-voice-tokenizer --model-path THUDM/glm-4-voice-9b --flow-path ./glm-4-voice-decoder
67
+ ```
68
+
69
+ You can access the web demo at [http://127.0.0.1:8888](http://127.0.0.1:8888).
70
+ This command will automatically download `glm-4-voice-tokenizer` and `glm-4-voice-9b`. Please note that `glm-4-voice-decoder` needs to be downloaded manually.
71
+ If the network connection is poor, you can manually download these three models and specify the local paths using `--tokenizer-path`, `--flow-path`, and `--model-path`.
72
+
73
+ ### Known Issues
74
+ * Gradio’s streaming audio playback can be unstable. The audio quality will be higher when clicking on the audio in the dialogue box after generation is complete.
75
+
76
+ ## Examples
77
+ We provide some dialogue cases for GLM-4-Voice, including emotion control, speech rate alteration, dialect generation, etc. (The examples are in Chinese.)
78
+
79
+ * Use a gentle voice to guide me to relax
80
+
81
+ https://github.com/user-attachments/assets/4e3d9200-076d-4c28-a641-99df3af38eb0
82
+
83
+ * Use an excited voice to commentate a football match
84
+
85
+ https://github.com/user-attachments/assets/0163de2d-e876-4999-b1bc-bbfa364b799b
86
+
87
+ * Tell a ghost story with a mournful voice
88
+
89
+ https://github.com/user-attachments/assets/a75b2087-d7bc-49fa-a0c5-e8c99935b39a
90
+
91
+ * Introduce how cold winter is with a Northeastern dialect
92
+
93
+ https://github.com/user-attachments/assets/91ba54a1-8f5c-4cfe-8e87-16ed1ecf4037
94
+
95
+ * Say "Eat grapes without spitting out the skins" in Chongqing dialect
96
+
97
+ https://github.com/user-attachments/assets/7eb72461-9e84-4d8e-9c58-1809cf6a8a9b
98
+
99
+ * Recite a tongue twister with a Beijing accent
100
+
101
+ https://github.com/user-attachments/assets/a9bb223e-9c0a-440d-8537-0a7f16e31651
102
+
103
+ * Increase the speech rate
104
+
105
+ https://github.com/user-attachments/assets/c98a4604-366b-4304-917f-3c850a82fe9f
106
+
107
+ * Even faster
108
+
109
+ https://github.com/user-attachments/assets/d5ff0815-74f8-4738-b0f1-477cfc8dcc2d
110
+
111
+ ## Acknowledgements
112
+
113
+ Some code in this project is from:
114
+ * [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
115
+ * [transformers](https://github.com/huggingface/transformers)
116
+ * [GLM-4](https://github.com/THUDM/GLM-4)
117
+
118
+ ## License Agreement
119
+
120
+ + The use of GLM-4 model weights must follow the [Model License Agreement](https://huggingface.co/THUDM/glm-4-voice-9b/blob/main/LICENSE).
121
+
122
+ + The code in this open-source repository is licensed under the [Apache 2.0](LICENSE) License.
123
+
124
+ ## Citation
125
+
126
+ ```
127
+ @misc{zeng2024glm4,
128
+ title={GLM-4-Voice: Towards Intelligent and Human-Like End-to-End Spoken Chatbot},
129
+ author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Kedong Wang and Shengmin Jiang and Lei Zhao and Yuxiao Dong and Jie Tang},
130
+ year={2024},
131
+ eprint={2412.02612},
132
+ archivePrefix={arXiv},
133
+ primaryClass={cs.CL},
134
+ url={https://arxiv.org/abs/2412.02612},
135
+ }
136
+ ```
137
+
138
+ ```
139
+ @misc{zeng2024scaling,
140
+ title={Scaling Speech-Text Pre-training with Synthetic Interleaved Data},
141
+ author={Aohan Zeng and Zhengxiao Du and Mingdao Liu and Lei Zhang and Shengmin Jiang and Yuxiao Dong and Jie Tang},
142
+ year={2024},
143
+ eprint={2411.17607},
144
+ archivePrefix={arXiv},
145
+ primaryClass={cs.CL},
146
+ url={https://arxiv.org/abs/2411.17607},
147
+ }
148
+ ```
third_party/GLM-4-Voice/audio_process.py ADDED
@@ -0,0 +1,93 @@
1
+ import os
2
+ import librosa
3
+ import soundfile as sf
4
+ import numpy as np
5
+ from pathlib import Path
6
+ import io
7
+
8
+ # Split audio stream at silence points to prevent playback stuttering issues
9
+ # caused by AAC encoder frame padding when streaming audio through Gradio audio components.
10
+ class AudioStreamProcessor:
11
+ def __init__(self, sr=22050, min_silence_duration=0.1, threshold_db=-40):
12
+ self.sr = sr
13
+ self.min_silence_duration = min_silence_duration
14
+ self.threshold_db = threshold_db
15
+ self.buffer = np.array([])
16
+
17
+
18
+ def process(self, audio_data, last=False):
19
+ """
20
+ Add audio data and process it
21
+ params:
22
+ audio_data: audio data in numpy array
23
+ last: whether this is the last chunk of data
24
+ returns:
25
+ Processed audio data, returns None if no split point is found
26
+ """
27
+
28
+ # Add new data to buffer
29
+ self.buffer = np.concatenate([self.buffer, audio_data]) if len(self.buffer) > 0 else audio_data
30
+
31
+ if last:
32
+ result = self.buffer
33
+ self.buffer = np.array([])
34
+ return self._to_wav_bytes(result)
35
+
36
+ # Find silence boundary
37
+ split_point = self._find_silence_boundary(self.buffer)
38
+
39
+ if split_point is not None:
40
+ # Modified: Extend split point to the end of silence
41
+ silence_end = self._find_silence_end(split_point)
42
+ result = self.buffer[:silence_end]
43
+ self.buffer = self.buffer[silence_end:]
44
+ return self._to_wav_bytes(result)
45
+
46
+ return None
47
+
48
+ def _find_silence_boundary(self, audio):
49
+ """
50
+ Find the starting point of silence boundary in audio
51
+ """
52
+ # Convert audio to decibels
53
+ db = librosa.amplitude_to_db(np.abs(audio), ref=np.max)
54
+
55
+ # Find points below threshold
56
+ silence_points = np.where(db < self.threshold_db)[0]
57
+
58
+ if len(silence_points) == 0:
59
+ return None
60
+
61
+ # Calculate minimum silence samples
62
+ min_silence_samples = int(self.min_silence_duration * self.sr)
63
+
64
+ # Search backwards for continuous silence segment starting point
65
+ for i in range(len(silence_points) - min_silence_samples, -1, -1):
66
+ if i < 0:
67
+ break
68
+ if np.all(np.diff(silence_points[i:i+min_silence_samples]) == 1):
69
+ return silence_points[i]
70
+
71
+ return None
72
+
73
+ def _find_silence_end(self, start_point):
74
+ """
75
+ Find the end point of silence segment
76
+ """
77
+ db = librosa.amplitude_to_db(np.abs(self.buffer[start_point:]), ref=np.max)
78
+ silence_points = np.where(db >= self.threshold_db)[0]
79
+
80
+ if len(silence_points) == 0:
81
+ return len(self.buffer)
82
+
83
+ return start_point + silence_points[0]
84
+
85
+ def _to_wav_bytes(self, audio_data):
86
+ """
87
+ trans_to_wav_bytes
88
+ """
89
+ wav_buffer = io.BytesIO()
90
+ sf.write(wav_buffer, audio_data, self.sr, format='WAV')
91
+ return wav_buffer.getvalue()
92
+
93
+
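
The `AudioStreamProcessor` defined in `audio_process.py` above buffers incoming audio and only emits a chunk once it finds a sufficiently long silent stretch (or when flushed with `last=True`). A minimal usage sketch follows; the synthetic sine input, chunk size, and output file names are illustrative assumptions, not part of the repository, and it assumes `audio_process.py` is importable from the working directory:

```python
import numpy as np

from audio_process import AudioStreamProcessor  # assumes third_party/GLM-4-Voice is on PYTHONPATH

processor = AudioStreamProcessor(sr=22050, min_silence_duration=0.1, threshold_db=-40)

# Synthetic signal: 1 s tone, 0.5 s silence, 1 s tone. Real code would feed decoder output instead.
stream = np.concatenate([
    0.5 * np.sin(2 * np.pi * 440 * np.arange(22050) / 22050),
    np.zeros(11025),
    0.5 * np.sin(2 * np.pi * 220 * np.arange(22050) / 22050),
]).astype(np.float32)

idx = 0
for start in range(0, len(stream), 11025):  # feed ~0.5 s chunks
    piece = stream[start:start + 11025]
    wav_bytes = processor.process(piece, last=(start + 11025 >= len(stream)))
    if wav_bytes is not None:  # a complete segment ending in silence, or the final flush
        with open(f"segment_{idx}.wav", "wb") as f:
            f.write(wav_bytes)
        idx += 1
```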
third_party/GLM-4-Voice/cosyvoice/__init__.py ADDED
File without changes
third_party/GLM-4-Voice/cosyvoice/bin/inference.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+
17
+ import argparse
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ import os
21
+
22
+ import torch
23
+ from torch.utils.data import DataLoader
24
+ import torchaudio
25
+ from hyperpyyaml import load_hyperpyyaml
26
+ from tqdm import tqdm
27
+ from cosyvoice.cli.model import CosyVoiceModel
28
+
29
+ from cosyvoice.dataset.dataset import Dataset
30
+
31
+ def get_args():
32
+ parser = argparse.ArgumentParser(description='inference with your model')
33
+ parser.add_argument('--config', required=True, help='config file')
34
+ parser.add_argument('--prompt_data', required=True, help='prompt data file')
35
+ parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
36
+ parser.add_argument('--tts_text', required=True, help='tts input file')
37
+ parser.add_argument('--llm_model', required=True, help='llm model file')
38
+ parser.add_argument('--flow_model', required=True, help='flow model file')
39
+ parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
40
+ parser.add_argument('--gpu',
41
+ type=int,
42
+ default=-1,
43
+ help='gpu id for this rank, -1 for cpu')
44
+ parser.add_argument('--mode',
45
+ default='sft',
46
+ choices=['sft', 'zero_shot'],
47
+ help='inference mode')
48
+ parser.add_argument('--result_dir', required=True, help='asr result file')
49
+ args = parser.parse_args()
50
+ print(args)
51
+ return args
52
+
53
+
54
+ def main():
55
+ args = get_args()
56
+ logging.basicConfig(level=logging.DEBUG,
57
+ format='%(asctime)s %(levelname)s %(message)s')
58
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
59
+
60
+ # Init cosyvoice models from configs
61
+ use_cuda = args.gpu >= 0 and torch.cuda.is_available()
62
+ device = torch.device('cuda' if use_cuda else 'cpu')
63
+ with open(args.config, 'r') as f:
64
+ configs = load_hyperpyyaml(f)
65
+
66
+ model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
67
+ model.load(args.llm_model, args.flow_model, args.hifigan_model)
68
+
69
+ test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70
+ test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71
+
72
+ del configs
73
+ os.makedirs(args.result_dir, exist_ok=True)
74
+ fn = os.path.join(args.result_dir, 'wav.scp')
75
+ f = open(fn, 'w')
76
+ with torch.no_grad():
77
+ for batch_idx, batch in tqdm(enumerate(test_data_loader)):
78
+ utts = batch["utts"]
79
+ assert len(utts) == 1, "inference mode only supports batch size 1"
80
+ text = batch["text"]
81
+ text_token = batch["text_token"].to(device)
82
+ text_token_len = batch["text_token_len"].to(device)
83
+ tts_text = batch["tts_text"]
84
+ tts_index = batch["tts_index"]
85
+ tts_text_token = batch["tts_text_token"].to(device)
86
+ tts_text_token_len = batch["tts_text_token_len"].to(device)
87
+ speech_token = batch["speech_token"].to(device)
88
+ speech_token_len = batch["speech_token_len"].to(device)
89
+ speech_feat = batch["speech_feat"].to(device)
90
+ speech_feat_len = batch["speech_feat_len"].to(device)
91
+ utt_embedding = batch["utt_embedding"].to(device)
92
+ spk_embedding = batch["spk_embedding"].to(device)
93
+ if args.mode == 'sft':
94
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
95
+ 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
96
+ else:
97
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
98
+ 'prompt_text': text_token, 'prompt_text_len': text_token_len,
99
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
100
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
101
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
102
+ 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
103
+ model_output = model.inference(**model_input)
104
+ tts_key = '{}_{}'.format(utts[0], tts_index[0])
105
+ tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
106
+ torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
107
+ f.write('{} {}\n'.format(tts_key, tts_fn))
108
+ f.flush()
109
+ f.close()
110
+ logging.info('Result wav.scp saved in {}'.format(fn))
111
+
112
+
113
+ if __name__ == '__main__':
114
+ main()
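
A note on the output of `cosyvoice/bin/inference.py` above: it writes a Kaldi-style `wav.scp` into `--result_dir`, one `<utt_key> <wav_path>` pair per line. A minimal sketch for loading that index back (for example before scoring the synthesized wavs) could look like the following; the `exp/tts_results` path is only an illustrative placeholder for whatever `--result_dir` was set to:

```python
import os


def read_wav_scp(result_dir):
    """Parse the wav.scp written by cosyvoice/bin/inference.py into {utt_key: wav_path}."""
    entries = {}
    with open(os.path.join(result_dir, "wav.scp")) as f:
        for line in f:
            if not line.strip():
                continue
            key, path = line.strip().split(maxsplit=1)
            entries[key] = path
    return entries


# Example: wavs = read_wav_scp("exp/tts_results")
```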
third_party/GLM-4-Voice/cosyvoice/bin/train.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import print_function
16
+ import argparse
17
+ import datetime
18
+ import logging
19
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
20
+ from copy import deepcopy
21
+ import torch
22
+ import torch.distributed as dist
23
+ # import deepspeed
24
+ import pdb
25
+ from hyperpyyaml import load_hyperpyyaml
26
+
27
+ from torch.distributed.elastic.multiprocessing.errors import record
28
+
29
+ from cosyvoice.utils.executor import Executor
30
+ from cosyvoice.utils.train_utils import (
31
+ init_distributed,
32
+ init_dataset_and_dataloader,
33
+ init_optimizer_and_scheduler,
34
+ init_summarywriter, save_model,
35
+ wrap_cuda_model, check_modify_and_save_config)
36
+
37
+
38
+ def get_args():
39
+ parser = argparse.ArgumentParser(description='training your network')
40
+ parser.add_argument('--train_engine',
41
+ default='torch_ddp',
42
+ choices=['torch_ddp', 'deepspeed'],
43
+ help='Engine for paralleled training')
44
+ parser.add_argument('--model', required=True, help='model which will be trained')
45
+ parser.add_argument('--config', required=True, help='config file')
46
+ parser.add_argument('--train_data', required=True, help='train data file')
47
+ parser.add_argument('--cv_data', required=True, help='cv data file')
48
+ parser.add_argument('--checkpoint', help='checkpoint model')
49
+ parser.add_argument('--model_dir', required=True, help='save model dir')
50
+ parser.add_argument('--tensorboard_dir',
51
+ default='tensorboard',
52
+ help='tensorboard log dir')
53
+ parser.add_argument('--ddp.dist_backend',
54
+ dest='dist_backend',
55
+ default='nccl',
56
+ choices=['nccl', 'gloo'],
57
+ help='distributed backend')
58
+ parser.add_argument('--num_workers',
59
+ default=0,
60
+ type=int,
61
+ help='num of subprocess workers for reading')
62
+ parser.add_argument('--prefetch',
63
+ default=100,
64
+ type=int,
65
+ help='prefetch number')
66
+ parser.add_argument('--pin_memory',
67
+ action='store_true',
68
+ default=False,
69
+ help='Use pinned memory buffers used for reading')
70
+ parser.add_argument('--deepspeed.save_states',
71
+ dest='save_states',
72
+ default='model_only',
73
+ choices=['model_only', 'model+optimizer'],
74
+ help='save model/optimizer states')
75
+ parser.add_argument('--timeout',
76
+ default=30,
77
+ type=int,
78
+ help='timeout (in seconds) of cosyvoice_join.')
79
+ # parser = deepspeed.add_config_arguments(parser)
80
+ args = parser.parse_args()
81
+ return args
82
+
83
+
84
+ @record
85
+ def main():
86
+ args = get_args()
87
+ logging.basicConfig(level=logging.DEBUG,
88
+ format='%(asctime)s %(levelname)s %(message)s')
89
+
90
+ override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
91
+ with open(args.config, 'r') as f:
92
+ configs = load_hyperpyyaml(f, overrides=override_dict)
93
+ configs['train_conf'].update(vars(args))
94
+
95
+ # Init env for ddp
96
+ init_distributed(args)
97
+
98
+ # Get dataset & dataloader
99
+ train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
100
+ init_dataset_and_dataloader(args, configs)
101
+
102
+ # Do some sanity checks and save config to args.model_dir
103
+ configs = check_modify_and_save_config(args, configs)
104
+
105
+ # Tensorboard summary
106
+ writer = init_summarywriter(args)
107
+
108
+ # load checkpoint
109
+ model = configs[args.model]
110
+ if args.checkpoint is not None:
111
+ model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
112
+
113
+ # Dispatch model from cpu to gpu
114
+ model = wrap_cuda_model(args, model)
115
+
116
+ # Get optimizer & scheduler
117
+ model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
118
+ # pdb.set_trace()
119
+ # Save init checkpoints
120
+ info_dict = deepcopy(configs['train_conf'])
121
+ save_model(model, 'init', info_dict)
122
+
123
+ # Get executor
124
+ executor = Executor()
125
+
126
+ # Start training loop
127
+ for epoch in range(info_dict['max_epoch']):
128
+ executor.epoch = epoch
129
+ train_dataset.set_epoch(epoch)
130
+ dist.barrier()
131
+ # try:
132
+ # dist.barrier()
133
+ # except RuntimeError as e:
134
+ # logging.info('except RuntimeError as e: {}'.format(e))
135
+ group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
136
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
137
+ dist.destroy_process_group(group_join)
138
+
139
+ if __name__ == '__main__':
140
+ main()
third_party/GLM-4-Voice/cosyvoice/cli/__init__.py ADDED
File without changes
third_party/GLM-4-Voice/cosyvoice/cli/cosyvoice.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import torch
16
+ from hyperpyyaml import load_hyperpyyaml
17
+ from modelscope import snapshot_download
18
+ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
19
+ from cosyvoice.cli.model import CosyVoiceModel
20
+
21
+ class CosyVoice:
22
+
23
+ def __init__(self, model_dir):
24
+ instruct = True if '-Instruct' in model_dir else False
25
+ self.model_dir = model_dir
26
+ if not os.path.exists(model_dir):
27
+ model_dir = snapshot_download(model_dir)
28
+ with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
29
+ configs = load_hyperpyyaml(f)
30
+ self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
31
+ configs['feat_extractor'],
32
+ '{}/campplus.onnx'.format(model_dir),
33
+ '{}/speech_tokenizer_v1.onnx'.format(model_dir),
34
+ '{}/spk2info.pt'.format(model_dir),
35
+ instruct,
36
+ configs['allowed_special'])
37
+ self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
38
+ self.model.load('{}/llm.pt'.format(model_dir),
39
+ '{}/flow.pt'.format(model_dir),
40
+ '{}/hift.pt'.format(model_dir))
41
+ del configs
42
+
43
+ def list_avaliable_spks(self):
44
+ spks = list(self.frontend.spk2info.keys())
45
+ return spks
46
+
47
+ def inference_sft(self, tts_text, spk_id):
48
+ tts_speeches = []
49
+ for i in self.frontend.text_normalize(tts_text, split=True):
50
+ model_input = self.frontend.frontend_sft(i, spk_id)
51
+ model_output = self.model.inference(**model_input)
52
+ tts_speeches.append(model_output['tts_speech'])
53
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
54
+
55
+ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
56
+ prompt_text = self.frontend.text_normalize(prompt_text, split=False)
57
+ tts_speeches = []
58
+ for i in self.frontend.text_normalize(tts_text, split=True):
59
+ model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
60
+ model_output = self.model.inference(**model_input)
61
+ tts_speeches.append(model_output['tts_speech'])
62
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
63
+
64
+ def inference_cross_lingual(self, tts_text, prompt_speech_16k):
65
+ if self.frontend.instruct is True:
66
+ raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
67
+ tts_speeches = []
68
+ for i in self.frontend.text_normalize(tts_text, split=True):
69
+ model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
70
+ model_output = self.model.inference(**model_input)
71
+ tts_speeches.append(model_output['tts_speech'])
72
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
73
+
74
+ def inference_instruct(self, tts_text, spk_id, instruct_text):
75
+ if self.frontend.instruct is False:
76
+ raise ValueError('{} does not support instruct inference'.format(self.model_dir))
77
+ instruct_text = self.frontend.text_normalize(instruct_text, split=False)
78
+ tts_speeches = []
79
+ for i in self.frontend.text_normalize(tts_text, split=True):
80
+ model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
81
+ model_output = self.model.inference(**model_input)
82
+ tts_speeches.append(model_output['tts_speech'])
83
+ return {'tts_speech': torch.concat(tts_speeches, dim=1)}
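A minimal usage sketch for the CosyVoice wrapper above, assuming a locally available CosyVoice-300M-SFT checkpoint directory; the model directory, speaker choice, and output path are placeholders, and the 22050 Hz rate mirrors the frontend's resample rate.

# Hypothetical usage of the CosyVoice class above (paths and speaker id are placeholders).
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')   # falls back to a modelscope download if not local
spk_id = cosyvoice.list_avaliable_spks()[0]                     # speaker ids come from spk2info.pt
output = cosyvoice.inference_sft('Hello, this is a short test.', spk_id)
torchaudio.save('sft_output.wav', output['tts_speech'], 22050)  # tts_speech is a (1, num_samples) tensor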
third_party/GLM-4-Voice/cosyvoice/cli/frontend.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ import onnxruntime
16
+ import torch
17
+ import numpy as np
18
+ import whisper
19
+ from typing import Callable
20
+ import torchaudio.compliance.kaldi as kaldi
21
+ import torchaudio
22
+ import os
23
+ import re
24
+ import inflect
25
+ try:
26
+ import ttsfrd
27
+ use_ttsfrd = True
28
+ except ImportError:
29
+ print("failed to import ttsfrd, use WeTextProcessing instead")
30
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
31
+ from tn.english.normalizer import Normalizer as EnNormalizer
32
+ use_ttsfrd = False
33
+ from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph
34
+
35
+
36
+ class CosyVoiceFrontEnd:
37
+
38
+ def __init__(self,
39
+ get_tokenizer: Callable,
40
+ feat_extractor: Callable,
41
+ campplus_model: str,
42
+ speech_tokenizer_model: str,
43
+ spk2info: str = '',
44
+ instruct: bool = False,
45
+ allowed_special: str = 'all'):
46
+ self.tokenizer = get_tokenizer()
47
+ self.feat_extractor = feat_extractor
48
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
49
+ option = onnxruntime.SessionOptions()
50
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
51
+ option.intra_op_num_threads = 1
52
+ self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
53
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"])
54
+ if os.path.exists(spk2info):
55
+ self.spk2info = torch.load(spk2info, map_location=self.device)
56
+ self.instruct = instruct
57
+ self.allowed_special = allowed_special
58
+ self.inflect_parser = inflect.engine()
59
+ self.use_ttsfrd = use_ttsfrd
60
+ if self.use_ttsfrd:
61
+ self.frd = ttsfrd.TtsFrontendEngine()
62
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
63
+ assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
64
+ self.frd.set_lang_type('pinyin')
65
+ self.frd.enable_pinyin_mix(True)
66
+ self.frd.set_breakmodel_index(1)
67
+ else:
68
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
69
+ self.en_tn_model = EnNormalizer()
70
+
71
+ def _extract_text_token(self, text):
72
+ text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
73
+ text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
74
+ text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
75
+ return text_token, text_token_len
76
+
77
+ def _extract_speech_token(self, speech):
78
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
79
+ speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
80
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
81
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
82
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
83
+ return speech_token, speech_token_len
84
+
85
+ def _extract_spk_embedding(self, speech):
86
+ feat = kaldi.fbank(speech,
87
+ num_mel_bins=80,
88
+ dither=0,
89
+ sample_frequency=16000)
90
+ feat = feat - feat.mean(dim=0, keepdim=True)
91
+ embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
92
+ embedding = torch.tensor([embedding]).to(self.device)
93
+ return embedding
94
+
95
+ def _extract_speech_feat(self, speech):
96
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
97
+ speech_feat = speech_feat.unsqueeze(dim=0)
98
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
99
+ return speech_feat, speech_feat_len
100
+
101
+ def text_normalize(self, text, split=True):
102
+ text = text.strip()
103
+ if contains_chinese(text):
104
+ if self.use_ttsfrd:
105
+ text = self.frd.get_frd_extra_info(text, 'input')
106
+ else:
107
+ text = self.zh_tn_model.normalize(text)
108
+ text = text.replace("\n", "")
109
+ text = replace_blank(text)
110
+ text = replace_corner_mark(text)
111
+ text = text.replace(".", "、")
112
+ text = text.replace(" - ", ",")
113
+ text = remove_bracket(text)
114
+ text = re.sub(r'[,,]+$', '。', text)
115
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
116
+ token_min_n=60, merge_len=20,
117
+ comma_split=False)]
118
+ else:
119
+ if self.use_ttsfrd:
120
+ text = self.frd.get_frd_extra_info(text, 'input')
121
+ else:
122
+ text = self.en_tn_model.normalize(text)
123
+ text = spell_out_number(text, self.inflect_parser)
124
+ texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
125
+ token_min_n=60, merge_len=20,
126
+ comma_split=False)]
127
+ if split is False:
128
+ return text
129
+ return texts
130
+
131
+ def frontend_sft(self, tts_text, spk_id):
132
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
133
+ embedding = self.spk2info[spk_id]['embedding']
134
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
135
+ return model_input
136
+
137
+ def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
138
+ tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
139
+ prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
140
+ prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
141
+ speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
142
+ speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
143
+ embedding = self._extract_spk_embedding(prompt_speech_16k)
144
+ model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
145
+ 'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
146
+ 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
147
+ 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
148
+ 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
149
+ 'llm_embedding': embedding, 'flow_embedding': embedding}
150
+ return model_input
151
+
152
+ def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
153
+ model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
154
+ # in cross lingual mode, we remove prompt in llm
155
+ del model_input['prompt_text']
156
+ del model_input['prompt_text_len']
157
+ del model_input['llm_prompt_speech_token']
158
+ del model_input['llm_prompt_speech_token_len']
159
+ return model_input
160
+
161
+ def frontend_instruct(self, tts_text, spk_id, instruct_text):
162
+ model_input = self.frontend_sft(tts_text, spk_id)
163
+ # in instruct mode, we remove spk_embedding in llm due to information leakage
164
+ del model_input['llm_embedding']
165
+ instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
166
+ model_input['prompt_text'] = instruct_text_token
167
+ model_input['prompt_text_len'] = instruct_text_token_len
168
+ return model_input
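A sketch of how these frontend methods are typically driven, mirroring CosyVoice.inference_zero_shot above; frontend and model are assumed to be constructed exactly as in CosyVoice.__init__, and prompt.wav plus both transcripts are placeholders.

# Hypothetical zero-shot call path using the frontend/model pair above.
import torch
import torchaudio

prompt_speech, sr = torchaudio.load('prompt.wav')                                   # placeholder prompt audio
prompt_speech_16k = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(prompt_speech)
prompt_text = frontend.text_normalize('Transcript of the prompt audio.', split=False)
pieces = []
for seg in frontend.text_normalize('Target sentence to synthesize.', split=True):  # sentence-level splitting
    model_input = frontend.frontend_zero_shot(seg, prompt_text, prompt_speech_16k)
    pieces.append(model.inference(**model_input)['tts_speech'])
speech = torch.concat(pieces, dim=1)                                                # 22.05 kHz waveform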
third_party/GLM-4-Voice/cosyvoice/cli/model.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+
16
+ class CosyVoiceModel:
17
+
18
+ def __init__(self,
19
+ llm: torch.nn.Module,
20
+ flow: torch.nn.Module,
21
+ hift: torch.nn.Module):
22
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+ self.llm = llm
24
+ self.flow = flow
25
+ self.hift = hift
26
+
27
+ def load(self, llm_model, flow_model, hift_model):
28
+ self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
29
+ self.llm.to(self.device).eval()
30
+ self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
31
+ self.flow.to(self.device).eval()
32
+ self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
33
+ self.hift.to(self.device).eval()
34
+
35
+ def inference(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
36
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
37
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
38
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
39
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
40
+ tts_speech_token = self.llm.inference(text=text.to(self.device),
41
+ text_len=text_len.to(self.device),
42
+ prompt_text=prompt_text.to(self.device),
43
+ prompt_text_len=prompt_text_len.to(self.device),
44
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
45
+ prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
46
+ embedding=llm_embedding.to(self.device),
47
+ beam_size=1,
48
+ sampling=25,
49
+ max_token_text_ratio=30,
50
+ min_token_text_ratio=3)
51
+ tts_mel = self.flow.inference(token=tts_speech_token,
52
+ token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
53
+ prompt_token=flow_prompt_speech_token.to(self.device),
54
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
55
+ prompt_feat=prompt_speech_feat.to(self.device),
56
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
57
+ embedding=flow_embedding.to(self.device))
58
+ tts_speech = self.hift.inference(mel=tts_mel).cpu()
59
+ torch.cuda.empty_cache()
60
+ return {'tts_speech': tts_speech}
61
+
62
+ def text_to_token(self, text, text_len, flow_embedding, llm_embedding=torch.zeros(0, 192),
63
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
64
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
65
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
66
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
67
+ tts_speech_token = self.llm.inference(text=text.to(self.device),
68
+ text_len=text_len.to(self.device),
69
+ prompt_text=prompt_text.to(self.device),
70
+ prompt_text_len=prompt_text_len.to(self.device),
71
+ prompt_speech_token=llm_prompt_speech_token.to(self.device),
72
+ prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
73
+ embedding=llm_embedding.to(self.device),
74
+ beam_size=1,
75
+ sampling=25,
76
+ max_token_text_ratio=30,
77
+ min_token_text_ratio=3)
78
+ return tts_speech_token
79
+
80
+ def token_to_speech(self, tts_speech_token, flow_embedding, llm_embedding=torch.zeros(0, 192),
81
+ prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
82
+ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
83
+ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
84
+ prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
85
+
86
+ tts_mel = self.flow.inference(token=tts_speech_token,
87
+ token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
88
+ prompt_token=flow_prompt_speech_token.to(self.device),
89
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
90
+ prompt_feat=prompt_speech_feat.to(self.device),
91
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
92
+ embedding=flow_embedding.to(self.device))
93
+ tts_speech = self.hift.inference(mel=tts_mel).cpu()
94
+ torch.cuda.empty_cache()
95
+ return {'tts_speech': tts_speech}
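The text_to_token / token_to_speech pair above splits inference into its LLM half and its flow + HiFT vocoder half; a rough sketch of that split, assuming model_input was produced by one of the frontend_* methods and model is a loaded CosyVoiceModel.

# Hypothetical two-stage decode with the split API above (model_input assumed from the frontend).
speech_token = model.text_to_token(**model_input)                  # discrete speech tokens from the LLM
result = model.token_to_speech(speech_token,
                               flow_embedding=model_input['flow_embedding'])
wav = result['tts_speech']                                         # same output format as model.inference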
third_party/GLM-4-Voice/cosyvoice/dataset/__init__.py ADDED
File without changes
third_party/GLM-4-Voice/cosyvoice/dataset/dataset.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import random
17
+ import json
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ from torch.utils.data import IterableDataset
24
+ from cosyvoice.utils.file_utils import read_lists, read_json_lists
25
+
26
+
27
+ class Processor(IterableDataset):
28
+
29
+ def __init__(self, source, f, *args, **kw):
30
+ assert callable(f)
31
+ self.source = source
32
+ self.f = f
33
+ self.args = args
34
+ self.kw = kw
35
+
36
+ def set_epoch(self, epoch):
37
+ self.source.set_epoch(epoch)
38
+
39
+ def __iter__(self):
40
+ """ Return an iterator over the source dataset processed by the
41
+ given processor.
42
+ """
43
+ assert self.source is not None
44
+ assert callable(self.f)
45
+ return self.f(iter(self.source), *self.args, **self.kw)
46
+
47
+ def apply(self, f):
48
+ assert callable(f)
49
+ return Processor(self, f, *self.args, **self.kw)
50
+
51
+
52
+ class DistributedSampler:
53
+
54
+ def __init__(self, shuffle=True, partition=True):
55
+ self.epoch = -1
56
+ self.update()
57
+ self.shuffle = shuffle
58
+ self.partition = partition
59
+
60
+ def update(self):
61
+ assert dist.is_available()
62
+ if dist.is_initialized():
63
+ self.rank = dist.get_rank()
64
+ self.world_size = dist.get_world_size()
65
+ else:
66
+ self.rank = 0
67
+ self.world_size = 1
68
+ worker_info = torch.utils.data.get_worker_info()
69
+ if worker_info is None:
70
+ self.worker_id = 0
71
+ self.num_workers = 1
72
+ else:
73
+ self.worker_id = worker_info.id
74
+ self.num_workers = worker_info.num_workers
75
+ return dict(rank=self.rank,
76
+ world_size=self.world_size,
77
+ worker_id=self.worker_id,
78
+ num_workers=self.num_workers)
79
+
80
+ def set_epoch(self, epoch):
81
+ self.epoch = epoch
82
+
83
+ def sample(self, data):
84
+ """ Sample data according to rank/world_size/num_workers
85
+
86
+ Args:
87
+ data(List): input data list
88
+
89
+ Returns:
90
+ List: data list after sample
91
+ """
92
+ data = list(range(len(data)))
93
+ # force datalist even
94
+ if self.partition:
95
+ if self.shuffle:
96
+ random.Random(self.epoch).shuffle(data)
97
+ if len(data) < self.world_size:
98
+ data = data * math.ceil(self.world_size / len(data))
99
+ data = data[:self.world_size]
100
+ data = data[self.rank::self.world_size]
101
+ if len(data) < self.num_workers:
102
+ data = data * math.ceil(self.num_workers / len(data))
103
+ data = data[:self.num_workers]
104
+ data = data[self.worker_id::self.num_workers]
105
+ return data
106
+
107
+
108
+ class DataList(IterableDataset):
109
+
110
+ def __init__(self, lists, shuffle=True, partition=True):
111
+ self.lists = lists
112
+ self.sampler = DistributedSampler(shuffle, partition)
113
+
114
+ def set_epoch(self, epoch):
115
+ self.sampler.set_epoch(epoch)
116
+
117
+ def __iter__(self):
118
+ sampler_info = self.sampler.update()
119
+ indexes = self.sampler.sample(self.lists)
120
+ for index in indexes:
121
+ data = dict(src=self.lists[index])
122
+ data.update(sampler_info)
123
+ yield data
124
+
125
+
126
+ def Dataset(data_list_file,
127
+ data_pipeline,
128
+ mode='train',
129
+ shuffle=True,
130
+ partition=True,
131
+ tts_file='',
132
+ prompt_utt2data=''):
133
+ """ Construct dataset from arguments
134
+
135
+ We have two shuffle stages in the Dataset. The first is global
136
+ shuffle at shards tar/raw file level. The second is global shuffle
137
+ at training samples level.
138
+
139
+ Args:
140
+ data_type(str): raw/shard
141
+ tokenizer (BaseTokenizer): tokenizer to tokenize
142
+ partition(bool): whether to do data partition in terms of rank
143
+ """
144
+ assert mode in ['train', 'inference']
145
+ lists = read_lists(data_list_file)
146
+ # import pdb
147
+ # pdb.set_trace()
148
+ if mode == 'inference':
149
+ with open(tts_file) as f:
150
+ tts_data = json.load(f)
151
+ utt2lists = read_json_lists(prompt_utt2data)
152
+ # filter unnecessary file in inference mode
153
+ lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
154
+ dataset = DataList(lists, shuffle=shuffle, partition=partition)
155
+ if mode == 'inference':
156
+ # map partial arg tts_data in inference mode
157
+ data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
158
+ for func in data_pipeline:
159
+ dataset = Processor(dataset, func, mode=mode)
160
+ return dataset
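A sketch of how Dataset() composes a pipeline out of the processor functions defined in cosyvoice/dataset/processor.py below; in the real recipes the stage list, its arguments, and the feature extractor come from cosyvoice.yaml via hyperpyyaml, so this pipeline is deliberately abbreviated (tokenize, compute_fbank, and padding would sit between these stages) and train.list is a placeholder shard list.

# Hypothetical pipeline assembly around the Dataset() factory above.
from functools import partial
from torch.utils.data import DataLoader
from cosyvoice.dataset.dataset import Dataset
import cosyvoice.dataset.processor as processor

data_pipeline = [
    processor.parquet_opener,                                      # read samples out of parquet shards
    partial(processor.resample, resample_rate=22050),              # unify the sample rate
    partial(processor.shuffle, shuffle_size=1000),                 # buffered local shuffle
    partial(processor.batch, batch_type='static', batch_size=16),  # group samples into lists
]
train_dataset = Dataset('train.list', data_pipeline=data_pipeline,
                        mode='train', shuffle=True, partition=True)
train_dataset.set_epoch(0)
train_loader = DataLoader(train_dataset, batch_size=None, num_workers=2)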
third_party/GLM-4-Voice/cosyvoice/dataset/processor.py ADDED
@@ -0,0 +1,965 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ import json
17
+ import tarfile
18
+ import json
19
+ import io
20
+ import pyarrow.parquet as pq
21
+ from io import BytesIO
22
+ import torch
23
+ import torchaudio
24
+ from torch.nn.utils.rnn import pad_sequence
25
+ import torch.nn.functional as F
26
+ import tarfile
27
+ import json
28
+ import io
29
+ import wave
30
+ import numpy as np
31
+ import torchaudio
32
+ import os
33
+ import sys
34
+ import json
35
+ import random
36
+ import pickle
37
+ import argparse
38
+ import itertools
39
+ import mmap
40
+ import struct
41
+ import collections
42
+
43
+
44
+
45
+ import shutil
46
+ import multiprocessing as mp
47
+ from pathlib import Path
48
+
49
+ from tqdm import tqdm
50
+ from collections import defaultdict
51
+ from copy import deepcopy
52
+ from datetime import datetime
53
+ import pickle
54
+
55
+ from wids import wids
56
+ import math
57
+
58
+ torchaudio.set_audio_backend('soundfile')
59
+
60
+ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
61
+
62
+ try:
63
+ MAIN_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
64
+ GPT_SPK_EMBEDDING=torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
65
+ except:
66
+ MAIN_SPK_EMBEDDING=torch.zeros(1,192)
67
+ GPT_SPK_EMBEDDING=torch.zeros(1,192)
68
+
69
+ def parquet_opener(data, mode='train', tts_data={}):
70
+ """ Give url or local file, return file descriptor
71
+ Inplace operation.
72
+
73
+ Args:
74
+ data(Iterable[str]): url or local file list
75
+
76
+ Returns:
77
+ Iterable[{src, stream}]
78
+ """
79
+ for sample in data:
80
+ assert 'src' in sample
81
+ url = sample['src']
82
+ try:
83
+ df = pq.read_table(url).to_pandas()
84
+ for i in range(len(df)):
85
+ if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
86
+ continue
87
+ sample.update(dict(df.loc[i]))
88
+ if mode == 'train':
89
+ # NOTE do not return sample directly, must initialize a new dict
90
+ yield {**sample}
91
+ else:
92
+ for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
93
+ yield {**sample, 'tts_index': index, 'tts_text': text}
94
+ except Exception as ex:
95
+ logging.warning('Failed to open {}, ex info {}'.format(url, ex))
96
+
97
+
98
+
99
+
100
+ def parse_tar_header(header_bytes):
101
+ header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
102
+ return TarHeader(*header)
103
+
104
+ TarHeader = collections.namedtuple(
105
+ "TarHeader",
106
+ [
107
+ "name",
108
+ "mode",
109
+ "uid",
110
+ "gid",
111
+ "size",
112
+ "mtime",
113
+ "chksum",
114
+ "typeflag",
115
+ "linkname",
116
+ "magic",
117
+ "version",
118
+ "uname",
119
+ "gname",
120
+ "devmajor",
121
+ "devminor",
122
+ "prefix",
123
+ ],
124
+ )
125
+
126
+ class MMTar:
127
+ def __init__(self, file_path: Path | str):
128
+ self.stream = open(file_path, "rb")
129
+ self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
130
+
131
+ def __del__(self):
132
+ try:
133
+ self.mmap.close()
134
+ self.stream.close()
135
+ except: # noqa
136
+ pass
137
+
138
+ def get_at_offset(self, offset) -> tuple[str, bytes]:
139
+ header = parse_tar_header(self.mmap[offset : offset + 500])
140
+ name = header.name.decode("utf-8").strip("\x00")
141
+ start = offset + 512
142
+ end = start + int(header.size.decode("utf-8")[:-1], 8)
143
+ return name, self.mmap[start:end]
144
+
145
+
146
+ class Tar:
147
+ def __init__(self, path: Path):
148
+ self.tar = MMTar(path)
149
+ indices_path = path.with_suffix(".index")
150
+ self.index = pickle.loads(indices_path.read_bytes())
151
+ self.name_mapping = {}
152
+ for name, offset, _ in self.index:
153
+ self.name_mapping[name] = offset
154
+
155
+ def read(self, name: str) -> bytes:
156
+ return self.tar.get_at_offset(self.name_mapping[name])[1]
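A small usage sketch for the MMTar/Tar helpers above, assuming a shard part-00000.tar with a pickled part-00000.index sidecar (a list of (name, offset, ...) tuples, as the loop above expects); both paths and the member name are placeholders.

# Hypothetical read of a single member from an indexed tar shard (paths are placeholders).
from pathlib import Path

shard = Tar(Path('/data/shards/part-00000.tar'))     # also requires /data/shards/part-00000.index next to it
wav_bytes = shard.read('utt_0001.wav')               # raw bytes located via the mmap'ed header offset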
157
+
158
+ def cosy_jsonl_opener(data, mode='train', tts_data={}):
159
+ """ Give url or local file, return file descriptor
160
+ Inplace operation.
161
+
162
+ Args:
163
+ data(Iterable[str]): url or local file list
164
+
165
+ Returns:
166
+ Iterable[{src, stream}]
167
+ """
168
+ for sample in data:
169
+ assert 'src' in sample
170
+ cosy_jsonl_path = sample['src']
171
+ tar_file_path=cosy_jsonl_path.replace(".vq0907.jsonl",".tar")
172
+ try:
173
+ tar_data=Tar(Path(tar_file_path))
174
+ with open(cosy_jsonl_path, 'r') as f:
175
+ for line in f:
176
+ item=json.loads(line)
177
+ cosy_token = item['cosy_token']
178
+ sample['speech_token']=torch.tensor(cosy_token)
179
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
180
+ # print(item['filename'])
181
+ yield {**sample}
182
+
183
+ except Exception as ex:
184
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
185
+
186
+
187
+ def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}):
188
+ """ Give url or local file, return file descriptor
189
+ Inplace operation.
190
+
191
+ Args:
192
+ data(Iterable[str]): url or local file list
193
+
194
+ Returns:
195
+ Iterable[{src, stream}]
196
+ """
197
+ for sample in data:
198
+ assert 'src' in sample
199
+ cosy_jsonl_path = sample['src']
200
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-nopool.jsonl",".tar")
201
+
202
+
203
+ try:
204
+ tar_data=Tar(Path(tar_file_path))
205
+ with open(cosy_jsonl_path, 'r') as f:
206
+ # cosy_data = [json.loads(line) for line in f]
207
+ for line in f:
208
+ item=json.loads(line)
209
+ cosy_token = item['cosy_token']
210
+ sample['speech_token']=torch.tensor(cosy_token)
211
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
212
+ # print(item['filename'])
213
+ yield {**sample}
214
+
215
+ except Exception as ex:
216
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
217
+
218
+
219
+
220
+ def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}):
221
+ """ Give url or local file, return file descriptor
222
+ Inplace operation.
223
+
224
+ Args:
225
+ data(Iterable[str]): url or local file list
226
+
227
+ Returns:
228
+ Iterable[{src, stream}]
229
+ """
230
+ for sample in data:
231
+ assert 'src' in sample
232
+ cosy_jsonl_path = sample['src']
233
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-pool2.jsonl",".tar")
234
+
235
+ try:
236
+ tar_data=Tar(Path(tar_file_path))
237
+ with open(cosy_jsonl_path, 'r') as f:
238
+ for line in f:
239
+ item=json.loads(line)
240
+ cosy_token = item['cosy_token']
241
+ sample['speech_token']=torch.tensor(cosy_token)
242
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
243
+
244
+ yield {**sample}
245
+
246
+ except Exception as ex:
247
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
248
+
249
+
250
+ def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}):
251
+ """ Give url or local file, return file descriptor
252
+ Inplace operation.
253
+
254
+ Args:
255
+ data(Iterable[str]): url or local file list
256
+
257
+ Returns:
258
+ Iterable[{src, stream}]
259
+ """
260
+ for sample in data:
261
+ assert 'src' in sample
262
+ cosy_jsonl_path = sample['src']
263
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-pool4.jsonl",".tar")
264
+ try:
265
+ tar_data=Tar(Path(tar_file_path))
266
+ with open(cosy_jsonl_path, 'r') as f:
267
+ # cosy_data = [json.loads(line) for line in f]
268
+ for line in f:
269
+ item=json.loads(line)
270
+ cosy_token = item['cosy_token']
271
+ sample['speech_token']=torch.tensor(cosy_token)
272
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
273
+ # print(item['filename'])
274
+ yield {**sample}
275
+
276
+ except Exception as ex:
277
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
278
+
279
+
280
+ def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}):
281
+ """ Give url or local file, return file descriptor
282
+ Inplace operation.
283
+
284
+ Args:
285
+ data(Iterable[str]): url or local file list
286
+
287
+ Returns:
288
+ Iterable[{src, stream}]
289
+ """
290
+ for sample in data:
291
+ assert 'src' in sample
292
+ cosy_jsonl_path = sample['src']
293
+ tar_file_path=cosy_jsonl_path.replace(".vq0918-pool8.jsonl",".tar")
294
+
295
+ try:
296
+ tar_data=Tar(Path(tar_file_path))
297
+ with open(cosy_jsonl_path, 'r') as f:
298
+ # cosy_data = [json.loads(line) for line in f]
299
+ for line in f:
300
+ item=json.loads(line)
301
+ cosy_token = item['cosy_token']
302
+ sample['speech_token']=torch.tensor(cosy_token)
303
+ sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
304
+ # print(item['filename'])
305
+ yield {**sample}
306
+
307
+ except Exception as ex:
308
+ logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
309
+
310
+
311
+
312
+ def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
313
+ for sample in data:
314
+ assert 'src' in sample
315
+
316
+ token_npy_path = sample['src']
317
+ wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
318
+
319
+ # wav_path,token_npy_path=sample['src'].split(' ')
320
+ try:
321
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
322
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
323
+ if sample['speech'].shape[0] > 1:
324
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
325
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
326
+ yield {**sample}
327
+ except Exception as ex:
328
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
329
+ logging.warning('Failed to open {}'.format(wav_path))
330
+
331
+
332
+ def process_sft_vq0918_pool4_split(data, mode='train',split_token=25, tts_data={}):
333
+ for sample in data:
334
+ assert 'src' in sample
335
+
336
+ token_npy_path = sample['src']
337
+ wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
338
+
339
+ # wav_path,token_npy_path=sample['src'].split(' ')
340
+ try:
341
+ # sample['speech_token']=torch.tensor(np.load(token_npy_path))
342
+ # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
343
+ # if sample['speech'].shape[0] > 1:
344
+ # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
345
+
346
+
347
+ # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
348
+
349
+
350
+ speech_token=torch.tensor(np.load(token_npy_path))
351
+ speech,sample_rate= torchaudio.load(wav_path)
352
+ # split_speech=int(split_token / 12.5 * sample_rate)
353
+ if speech.shape[0] > 1:
354
+ speech = speech.mean(dim=0, keepdim=True)
355
+
356
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
357
+ sample['sample_rate']=sample_rate
358
+
359
+ num_splits = (speech_token.size(0) + split_token - 1) // split_token
360
+
361
+ for split_id in range(num_splits):
362
+ end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
363
+ end_speech_idx=int(np.ceil(end_token_idx / 12.5 * sample_rate))
364
+ sample['speech_token']=speech_token[:end_token_idx]
365
+ sample['speech']=speech[:,:end_speech_idx]
366
+ print(sample['speech_token'].size(),sample['speech'].size())
367
+ yield {**sample}
368
+ except Exception as ex:
369
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
370
+ logging.warning('Failed to open {}'.format(wav_path))
371
+
372
+
373
+ def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
374
+ for sample in data:
375
+ assert 'src' in sample
376
+
377
+ token_npy_path = sample['src'].replace(".vq0918-pool4.npy",".vq0918-pool2.npy")
378
+ wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
379
+
380
+ # wav_path,token_npy_path=sample['src'].split(' ')
381
+ try:
382
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
383
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
384
+ if sample['speech'].shape[0] > 1:
385
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
386
+
387
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
388
+ yield {**sample}
389
+ except Exception as ex:
390
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
391
+ logging.warning('Failed to open {}'.format(wav_path))
392
+
393
+
394
+ def process_sft_vq0918_pool2_split(data, mode='train',split_token=50, tts_data={}):
395
+ for sample in data:
396
+ assert 'src' in sample
397
+
398
+ token_npy_path = sample['src']
399
+ wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
400
+
401
+ # wav_path,token_npy_path=sample['src'].split(' ')
402
+ try:
403
+ # sample['speech_token']=torch.tensor(np.load(token_npy_path))
404
+ # sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
405
+ # if sample['speech'].shape[0] > 1:
406
+ # sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
407
+
408
+
409
+ # sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
410
+
411
+
412
+ speech_token=torch.tensor(np.load(token_npy_path))
413
+ speech,sample_rate= torchaudio.load(wav_path)
414
+ # split_speech=int(split_token / 12.5 * sample_rate)
415
+ if speech.shape[0] > 1:
416
+ speech = speech.mean(dim=0, keepdim=True)
417
+
418
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
419
+ sample['sample_rate']=sample_rate
420
+
421
+ num_splits = (speech_token.size(0) + split_token - 1) // split_token
422
+
423
+ for split_id in range(num_splits):
424
+ end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
425
+ end_speech_idx=int(np.ceil(end_token_idx / 25 * sample_rate))
426
+ sample['speech_token']=speech_token[:end_token_idx]
427
+ sample['speech']=speech[:,:end_speech_idx]
428
+ print(sample['speech_token'].size(),sample['speech'].size())
429
+ yield {**sample}
430
+ except Exception as ex:
431
+ logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
432
+ logging.warning('Failed to open {}'.format(wav_path))
433
+
434
+ def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
435
+ for sample in data:
436
+ assert 'src' in sample
437
+ try:
438
+ entry=json.loads(sample['src'])
439
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
440
+
441
+ for conv in entry["conversations"]:
442
+ if "response_wav" in conv:
443
+ wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
444
+ token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
445
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
446
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
447
+ if sample['speech'].shape[0] > 1:
448
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
449
+ sample['spk_embedding']=GPT_SPK_EMBEDDING  # assumed fix: original line referenced an undefined 'spk_embedding'
450
+ yield {**sample}
451
+ except Exception as ex:
452
+ # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
453
+ logging.warning('Failed to open {}'.format(wav_path))
454
+
455
+
456
+ def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
457
+ for sample in data:
458
+ assert 'src' in sample
459
+ try:
460
+ entry=json.loads(sample['src'])
461
+ sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
462
+
463
+ for conv in entry["conversations"]:
464
+ if "response_wav" in conv:
465
+ wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
466
+ token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
467
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
468
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
469
+ if sample['speech'].shape[0] > 1:
470
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
471
+ sample['spk_embedding']=GPT_SPK_EMBEDDING  # assumed fix: original line referenced an undefined 'spk_embedding'
472
+ yield {**sample}
473
+ if "prompt_wav" in conv:
474
+ wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
475
+ token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
476
+ sample['speech_token']=torch.tensor(np.load(token_npy_path))
477
+ sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
478
+ if sample['speech'].shape[0] > 1:
479
+ sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
480
+ sample['spk_embedding']=GPT_SPK_EMBEDDING  # assumed fix: original line referenced an undefined 'spk_embedding'
481
+ yield {**sample}
482
+ except Exception as ex:
483
+ # logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
484
+ logging.warning('Failed to open {}'.format(wav_path))
485
+
486
+
487
+ def filter(data,
488
+ max_length=10240,
489
+ min_length=10,
490
+ token_max_length=200,
491
+ token_min_length=1,
492
+ min_output_input_ratio=0.0005,
493
+ max_output_input_ratio=1,
494
+ mode='train'):
495
+ """ Filter sample according to feature and label length
496
+ Inplace operation.
497
+
498
+ Args:
499
+ data: Iterable[{key, wav, label, sample_rate}]
500
+ max_length: drop utterance which is greater than max_length(10ms)
501
+ min_length: drop utterance which is less than min_length(10ms)
502
+ token_max_length: drop utterance which is greater than
503
+ token_max_length, especially when use char unit for
504
+ english modeling
505
+ token_min_length: drop utterance which is
506
+ less than token_min_length
507
+ min_output_input_ratio: minimal ratio of
508
+ token_length / feats_length(10ms)
509
+ max_output_input_ratio: maximum ratio of
510
+ token_length / feats_length(10ms)
511
+
512
+ Returns:
513
+ Iterable[{key, wav, label, sample_rate}]
514
+ """
515
+ for sample in data:
516
+ # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
517
+ # del sample['audio_data']
518
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
519
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
520
+ if num_frames < min_length:
521
+ continue
522
+ if num_frames > max_length:
523
+ continue
524
+ if len(sample['text_token']) < token_min_length:
525
+ continue
526
+ if len(sample['text_token']) > token_max_length:
527
+ continue
528
+ if len(sample['speech_token']) == 0:
529
+ continue
530
+ if num_frames != 0:
531
+ if len(sample['text_token']) / num_frames < min_output_input_ratio:
532
+ continue
533
+ if len(sample['text_token']) / num_frames > max_output_input_ratio:
534
+ continue
535
+ yield sample
536
+
537
+
538
+ def filter_speech_token(data,
539
+ max_length=10240,
540
+ min_length=10,
541
+ token_max_length=5000,
542
+ token_min_length=1,
543
+ min_output_input_ratio=0.0005,
544
+ max_output_input_ratio=30,
545
+ mode='train'):
546
+ """ Filter sample according to feature and label length
547
+ Inplace operation.
548
+
549
+ Args:
550
+ data: Iterable[{key, wav, label, sample_rate}]
551
+ max_length: drop utterance which is greater than max_length(10ms)
552
+ min_length: drop utterance which is less than min_length(10ms)
553
+ token_max_length: drop utterance which is greater than
554
+ token_max_length, especially when use char unit for
555
+ english modeling
556
+ token_min_length: drop utterance which is
557
+ less than token_min_length
558
+ min_output_input_ratio: minimal ratio of
559
+ token_length / feats_length(10ms)
560
+ max_output_input_ratio: maximum ratio of
561
+ token_length / feats_length(10ms)
562
+
563
+ Returns:
564
+ Iterable[{key, wav, label, sample_rate}]
565
+ """
566
+ for sample in data:
567
+ # sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
568
+ # del sample['audio_data']
569
+ # sample['wav'] is torch.Tensor, we have 100 frames every second
570
+ num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
571
+ if num_frames < min_length:
572
+ continue
573
+ if num_frames > max_length:
574
+ continue
575
+ if len(sample['speech_token']) < token_min_length:
576
+ continue
577
+ if len(sample['speech_token']) > token_max_length:
578
+ continue
579
+ if len(sample['speech_token']) == 0:
580
+ continue
581
+ if num_frames != 0:
582
+ if len(sample['speech_token']) / num_frames < min_output_input_ratio:
583
+ continue
584
+ if len(sample['speech_token']) / num_frames > max_output_input_ratio:
585
+ continue
586
+ yield sample
587
+
588
+
589
+ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
590
+ """ Resample data.
591
+ Inplace operation.
592
+
593
+ Args:
594
+ data: Iterable[{key, wav, label, sample_rate}]
595
+ resample_rate: target resample rate
596
+
597
+ Returns:
598
+ Iterable[{key, wav, label, sample_rate}]
599
+ """
600
+ for sample in data:
601
+ assert 'sample_rate' in sample
602
+ assert 'speech' in sample
603
+ sample_rate = sample['sample_rate']
604
+ waveform = sample['speech']
605
+ if sample_rate != resample_rate:
606
+ if sample_rate < min_sample_rate:
607
+ continue
608
+ sample['sample_rate'] = resample_rate
609
+ sample['speech'] = torchaudio.transforms.Resample(
610
+ orig_freq=sample_rate, new_freq=resample_rate)(waveform)
611
+ max_val = sample['speech'].abs().max()
612
+ if max_val > 1:
613
+ sample['speech'] /= max_val
614
+ yield sample
615
+
616
+
617
+ def compute_fbank(data,
618
+ feat_extractor,
619
+ mode='train'):
620
+ """ Extract fbank
621
+
622
+ Args:
623
+ data: Iterable[{key, wav, label, sample_rate}]
624
+
625
+ Returns:
626
+ Iterable[{key, feat, label}]
627
+ """
628
+ for sample in data:
629
+ assert 'sample_rate' in sample
630
+ assert 'speech' in sample
631
+ # assert 'utt' in sample
632
+ # assert 'text_token' in sample
633
+ waveform = sample['speech']
634
+ mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
635
+ sample['speech_feat'] = mat
636
+ del sample['speech']
637
+ yield sample
638
+
639
+
640
+ def parse_embedding(data, normalize, mode='train'):
641
+ """ Parse utt_embedding/spk_embedding
642
+
643
+ Args:
644
+ data: Iterable[{key, wav, label, sample_rate}]
645
+
646
+ Returns:
647
+ Iterable[{key, feat, label}]
648
+ """
649
+ for sample in data:
650
+ sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
651
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
652
+ if normalize:
653
+ sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
654
+ sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
655
+ yield sample
656
+
657
+
658
+ def tokenize(data, get_tokenizer, allowed_special, mode='train'):
659
+ """ Decode text to chars or BPE
660
+ Inplace operation
661
+
662
+ Args:
663
+ data: Iterable[{key, wav, txt, sample_rate}]
664
+
665
+ Returns:
666
+ Iterable[{key, wav, txt, tokens, label, sample_rate}]
667
+ """
668
+ tokenizer = get_tokenizer()
669
+ for sample in data:
670
+ assert 'text' in sample
671
+ sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
672
+ if mode == 'inference':
673
+ sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
674
+ yield sample
675
+
676
+
677
+ def shuffle(data, shuffle_size=10000, mode='train'):
678
+ """ Local shuffle the data
679
+
680
+ Args:
681
+ data: Iterable[{key, feat, label}]
682
+ shuffle_size: buffer size for shuffle
683
+
684
+ Returns:
685
+ Iterable[{key, feat, label}]
686
+ """
687
+ buf = []
688
+ for sample in data:
689
+ buf.append(sample)
690
+ if len(buf) >= shuffle_size:
691
+ random.shuffle(buf)
692
+ for x in buf:
693
+ yield x
694
+ buf = []
695
+ # The sample left over
696
+ random.shuffle(buf)
697
+ for x in buf:
698
+ yield x
699
+
700
+
701
+ def sort(data, sort_size=500, mode='train'):
702
+ """ Sort the data by feature length.
703
+ Sort is used after shuffle and before batch, so we can group
704
+ utts with similar lengths into a batch, and `sort_size` should
705
+ be less than `shuffle_size`
706
+
707
+ Args:
708
+ data: Iterable[{key, feat, label}]
709
+ sort_size: buffer size for sort
710
+
711
+ Returns:
712
+ Iterable[{key, feat, label}]
713
+ """
714
+
715
+ buf = []
716
+ for sample in data:
717
+ buf.append(sample)
718
+ if len(buf) >= sort_size:
719
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
720
+ for x in buf:
721
+ yield x
722
+ buf = []
723
+ # The sample left over
724
+ buf.sort(key=lambda x: x['speech_feat'].size(0))
725
+ for x in buf:
726
+ yield x
727
+
728
+
729
+ def static_batch(data, batch_size=16):
730
+ """ Static batch the data by `batch_size`
731
+
732
+ Args:
733
+ data: Iterable[{key, feat, label}]
734
+ batch_size: batch size
735
+
736
+ Returns:
737
+ Iterable[List[{key, feat, label}]]
738
+ """
739
+ buf = []
740
+ for sample in data:
741
+ buf.append(sample)
742
+ if len(buf) >= batch_size:
743
+ yield buf
744
+ buf = []
745
+ if len(buf) > 0:
746
+ yield buf
747
+
748
+
749
+ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
750
+ """ Dynamic batch the data until the total frames in batch
751
+ reach `max_frames_in_batch`
752
+
753
+ Args:
754
+ data: Iterable[{key, feat, label}]
755
+ max_frames_in_batch: max_frames in one batch
756
+
757
+ Returns:
758
+ Iterable[List[{key, feat, label}]]
759
+ """
760
+ buf = []
761
+ longest_frames = 0
762
+ for sample in data:
763
+ assert 'speech_feat' in sample
764
+ assert isinstance(sample['speech_feat'], torch.Tensor)
765
+ new_sample_frames = sample['speech_feat'].size(0)
766
+ longest_frames = max(longest_frames, new_sample_frames)
767
+ frames_after_padding = longest_frames * (len(buf) + 1)
768
+ if frames_after_padding > max_frames_in_batch:
769
+ yield buf
770
+ buf = [sample]
771
+ longest_frames = new_sample_frames
772
+ else:
773
+ buf.append(sample)
774
+ if len(buf) > 0:
775
+ yield buf
776
+
777
+
778
+ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
779
+ """ Wrapper for static/dynamic batch
780
+ """
781
+ if mode == 'inference':
782
+ return static_batch(data, 1)
783
+ else:
784
+ if batch_type == 'static':
785
+ return static_batch(data, batch_size)
786
+ elif batch_type == 'dynamic':
787
+ return dynamic_batch(data, max_frames_in_batch)
788
+ else:
789
+ logging.fatal('Unsupported batch type {}'.format(batch_type))
790
+
791
+
792
+ def padding(data, use_spk_embedding, mode='train'):
793
+ """ Padding the data into training data
794
+
795
+ Args:
796
+ data: Iterable[List[{key, feat, label}]]
797
+
798
+ Returns:
799
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
800
+ """
801
+ for sample in data:
802
+ assert isinstance(sample, list)
803
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
804
+ dtype=torch.int32)
805
+ order = torch.argsort(speech_feat_len, descending=True)
806
+
807
+ utts = [sample[i]['utt'] for i in order]
808
+ speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
809
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
810
+ speech_token = pad_sequence(speech_token,
811
+ batch_first=True,
812
+ padding_value=0)
813
+ speech_feat = [sample[i]['speech_feat'] for i in order]
814
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
815
+ speech_feat = pad_sequence(speech_feat,
816
+ batch_first=True,
817
+ padding_value=0)
818
+ text = [sample[i]['text'] for i in order]
819
+ text_token = [torch.tensor(sample[i]['text_token']) for i in order]
820
+ text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
821
+ text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
822
+ utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
823
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
824
+ batch = {
825
+ "utts": utts,
826
+ "speech_token": speech_token,
827
+ "speech_token_len": speech_token_len,
828
+ "speech_feat": speech_feat,
829
+ "speech_feat_len": speech_feat_len,
830
+ "text": text,
831
+ "text_token": text_token,
832
+ "text_token_len": text_token_len,
833
+ "utt_embedding": utt_embedding,
834
+ "spk_embedding": spk_embedding,
835
+ }
836
+ if mode == 'inference':
837
+ tts_text = [sample[i]['tts_text'] for i in order]
838
+ tts_index = [sample[i]['tts_index'] for i in order]
839
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
840
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
841
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
842
+ batch.update({'tts_text': tts_text,
843
+ 'tts_index': tts_index,
844
+ 'tts_text_token': tts_text_token,
845
+ 'tts_text_token_len': tts_text_token_len})
846
+ if use_spk_embedding is True:
847
+ batch["embedding"] = batch["spk_embedding"]
848
+ else:
849
+ batch["embedding"] = batch["utt_embedding"]
850
+ yield batch
851
+
852
+
853
+
854
+ def padding_speech_token(data, use_spk_embedding, mode='train'):
855
+ """ Padding the data into training data
856
+
857
+ Args:
858
+ data: Iterable[List[{key, feat, label}]]
859
+
860
+ Returns:
861
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
862
+ """
863
+ for sample in data:
864
+ assert isinstance(sample, list)
865
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
866
+ dtype=torch.int32)
867
+ order = torch.argsort(speech_feat_len, descending=True)
868
+
869
+ # utts = [sample[i]['utt'] for i in order]
870
+ # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
871
+ try:
872
+ speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
873
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
874
+ speech_token = pad_sequence(speech_token,
875
+ batch_first=True,
876
+ padding_value=0)
877
+ speech_feat = [sample[i]['speech_feat'] for i in order]
878
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
879
+ speech_feat = pad_sequence(speech_feat,
880
+ batch_first=True,
881
+ padding_value=0)
882
+ batch = {
883
+ "speech_token": speech_token,
884
+ "speech_token_len": speech_token_len,
885
+ "speech_feat": speech_feat,
886
+ "speech_feat_len": speech_feat_len,
887
+ }
888
+ if mode == 'inference':
889
+ tts_text = [sample[i]['tts_text'] for i in order]
890
+ tts_index = [sample[i]['tts_index'] for i in order]
891
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
892
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
893
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
894
+ batch.update({'tts_text': tts_text,
895
+ 'tts_index': tts_index,
896
+ 'tts_text_token': tts_text_token,
897
+ 'tts_text_token_len': tts_text_token_len})
898
+ # if use_spk_embedding is True:
899
+ # batch["embedding"] = batch["spk_embedding"]
900
+ # else:
901
+ # batch["embedding"] = batch["utt_embedding"]
902
+ batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
903
+ yield batch
904
+ except Exception as ex:
905
+ logging.warning(' ex info {}'.format(ex))
906
+ # assert False
907
+
908
+
909
+
910
+ def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
911
+ """ Padding the data into training data
912
+
913
+ Args:
914
+ data: Iterable[List[{key, feat, label}]]
915
+
916
+ Returns:
917
+ Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
918
+ """
919
+ for sample in data:
920
+ assert isinstance(sample, list)
921
+ speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
922
+ dtype=torch.int32)
923
+ order = torch.argsort(speech_feat_len, descending=True)
924
+
925
+ # utts = [sample[i]['utt'] for i in order]
926
+ # speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
927
+ try:
928
+ speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
929
+ speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
930
+ speech_token = pad_sequence(speech_token,
931
+ batch_first=True,
932
+ padding_value=0)
933
+ speech_feat = [sample[i]['speech_feat'] for i in order]
934
+ speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
935
+ speech_feat = pad_sequence(speech_feat,
936
+ batch_first=True,
937
+ padding_value=0)
938
+ spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
939
+ batch = {
940
+ "speech_token": speech_token,
941
+ "speech_token_len": speech_token_len,
942
+ "speech_feat": speech_feat,
943
+ "speech_feat_len": speech_feat_len,
944
+ "spk_embedding": spk_embedding,
945
+ }
946
+ if mode == 'inference':
947
+ tts_text = [sample[i]['tts_text'] for i in order]
948
+ tts_index = [sample[i]['tts_index'] for i in order]
949
+ tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
950
+ tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
951
+ tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
952
+ batch.update({'tts_text': tts_text,
953
+ 'tts_index': tts_index,
954
+ 'tts_text_token': tts_text_token,
955
+ 'tts_text_token_len': tts_text_token_len})
956
+ # if use_spk_embedding is True:
957
+ # batch["embedding"] = batch["spk_embedding"]
958
+ # else:
959
+ # batch["embedding"] = batch["utt_embedding"]
960
+ # batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
961
+ batch["embedding"] = batch["spk_embedding"]
962
+ yield batch
963
+ except Exception as ex:
964
+ logging.warning(' ex info {}'.format(ex))
965
+ # assert False
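dynamic_batch above caps the padded frame count of a batch (longest utterance times batch size) rather than the number of utterances; a toy check of that rule, with hand-made samples standing in for real pipeline output.

# Toy illustration of the dynamic_batch rule above: frames_after_padding = longest * (len(buf) + 1).
import torch
from cosyvoice.dataset.processor import dynamic_batch

def fake_samples(lengths):
    for n in lengths:
        yield {'speech_feat': torch.zeros(n, 80)}   # only the frame count matters here

batches = list(dynamic_batch(fake_samples([100, 200, 900, 50]), max_frames_in_batch=1000))
print([len(b) for b in batches])                    # [2, 1, 1]: adding the 900-frame sample would pad the batch to 3 x 900 frames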
third_party/GLM-4-Voice/cosyvoice/flow/decoder.py ADDED
@@ -0,0 +1,222 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ import torch.nn as nn
+ from einops import pack, rearrange, repeat
+ from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
+ from matcha.models.components.transformer import BasicTransformerBlock
+
+
+ class ConditionalDecoder(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         channels=(256, 256),
+         dropout=0.05,
+         attention_head_dim=64,
+         n_blocks=1,
+         num_mid_blocks=2,
+         num_heads=4,
+         act_fn="snake",
+     ):
+         """
+         This decoder requires an input with the same shape as the target, so if your text content
+         is shorter or longer than the output, please resample it before feeding it to the decoder.
+         """
+         super().__init__()
+         channels = tuple(channels)
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+
+         self.time_embeddings = SinusoidalPosEmb(in_channels)
+         time_embed_dim = channels[0] * 4
+         self.time_mlp = TimestepEmbedding(
+             in_channels=in_channels,
+             time_embed_dim=time_embed_dim,
+             act_fn="silu",
+         )
+         self.down_blocks = nn.ModuleList([])
+         self.mid_blocks = nn.ModuleList([])
+         self.up_blocks = nn.ModuleList([])
+
+         output_channel = in_channels
+         for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+             input_channel = output_channel
+             output_channel = channels[i]
+             is_last = i == len(channels) - 1
+             resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+             downsample = (
+                 Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+             )
+             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+         for i in range(num_mid_blocks):
+             input_channel = channels[-1]
+             out_channels = channels[-1]
+             resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+
+             self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+         channels = channels[::-1] + (channels[0],)
+         for i in range(len(channels) - 1):
+             input_channel = channels[i] * 2
+             output_channel = channels[i + 1]
+             is_last = i == len(channels) - 2
+             resnet = ResnetBlock1D(
+                 dim=input_channel,
+                 dim_out=output_channel,
+                 time_emb_dim=time_embed_dim,
+             )
+             transformer_blocks = nn.ModuleList(
+                 [
+                     BasicTransformerBlock(
+                         dim=output_channel,
+                         num_attention_heads=num_heads,
+                         attention_head_dim=attention_head_dim,
+                         dropout=dropout,
+                         activation_fn=act_fn,
+                     )
+                     for _ in range(n_blocks)
+                 ]
+             )
+             upsample = (
+                 Upsample1D(output_channel, use_conv_transpose=True)
+                 if not is_last
+                 else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+             )
+             self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+         self.final_block = Block1D(channels[-1], channels[-1])
+         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+         self.initialize_weights()
+
+
+     def initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d):
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.GroupNorm):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, mask, mu, t, spks=None, cond=None):
+         """Forward pass of the UNet1DConditional model.
+
+         Args:
+             x (torch.Tensor): shape (batch_size, in_channels, time)
+             mask (_type_): shape (batch_size, 1, time)
+             t (_type_): shape (batch_size)
+             spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+             cond (_type_, optional): placeholder for future use. Defaults to None.
+
+         Raises:
+             ValueError: _description_
+             ValueError: _description_
+
+         Returns:
+             _type_: _description_
+         """
+
+         t = self.time_embeddings(t)
+         t = self.time_mlp(t)
+
+         x = pack([x, mu], "b * t")[0]
+
+         if spks is not None:
+             spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+             x = pack([x, spks], "b * t")[0]
+         if cond is not None:
+             x = pack([x, cond], "b * t")[0]
+
+         hiddens = []
+         masks = [mask]
+         for resnet, transformer_blocks, downsample in self.down_blocks:
+             mask_down = masks[-1]
+             x = resnet(x, mask_down, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+             hiddens.append(x)  # Save hidden states for skip connections
+             x = downsample(x * mask_down)
+             masks.append(mask_down[:, :, ::2])
+         masks = masks[:-1]
+         mask_mid = masks[-1]
+
+         for resnet, transformer_blocks in self.mid_blocks:
+             x = resnet(x, mask_mid, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+
+         for resnet, transformer_blocks, upsample in self.up_blocks:
+             mask_up = masks.pop()
+             skip = hiddens.pop()
+             x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
+             x = resnet(x, mask_up, t)
+             x = rearrange(x, "b c t -> b t c").contiguous()
+             attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
+             for transformer_block in transformer_blocks:
+                 x = transformer_block(
+                     hidden_states=x,
+                     attention_mask=attn_mask,
+                     timestep=t,
+                 )
+             x = rearrange(x, "b t c -> b c t").contiguous()
+             x = upsample(x * mask_up)
+         x = self.final_block(x, mask_up)
+         output = self.final_proj(x * mask_up)
+         return output * mask
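For orientation, `ConditionalDecoder` is a 1-D U-Net over mel frames: each down block is a ResnetBlock1D plus a transformer stack followed by a stride-2 downsample, the mid blocks keep resolution, and the up blocks consume the saved skip connections before `final_proj` maps back to `out_channels`. The sketch below is a shape-level smoke test only; the channel sizes, 80 mel bins, and 100 frames are illustrative assumptions (the real values come from the flow config), it presumes `matcha-tts` and this repo's `cosyvoice` package are importable, and it passes `act_fn="gelu"` rather than relying on the `"snake"` default.

```python
# Illustrative sketch only: a shape-level smoke test of ConditionalDecoder, not the repo's flow config.
import torch
from cosyvoice.flow.decoder import ConditionalDecoder

batch, feat_dim, frames = 2, 80, 100             # 80 mel bins, 100 frames (assumed toy shapes)
decoder = ConditionalDecoder(
    in_channels=feat_dim * 2,                    # x and mu are packed along the channel axis in forward()
    out_channels=feat_dim,
    channels=(256, 256),
    act_fn="gelu",                               # assumed; configs typically set this explicitly
)

x = torch.randn(batch, feat_dim, frames)         # noisy sample at flow-matching step t
mu = torch.randn(batch, feat_dim, frames)        # conditioning features aligned to the target length
mask = torch.ones(batch, 1, frames)              # all frames valid
t = torch.rand(batch)                            # flow-matching timestep per batch element

out = decoder(x, mask, mu, t)                    # spks/cond omitted; when used they add channels,
print(out.shape)                                 # so in_channels must be enlarged accordingly
                                                 # -> torch.Size([2, 80, 100])
```

Because the frame axis is halved once on the way down and doubled once on the way up, an even number of frames keeps the skip connections aligned without relying on the `x[:, :, :skip.shape[-1]]` crop.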