hainazhu committed on
Commit
258fd02
·
1 Parent(s): 51fab49

Add application file

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +10 -33
  2. .gitignore +3 -0
  3. Dockerfile +13 -0
  4. LICENSE +211 -0
  5. README.md +63 -6
  6. app.py +140 -0
  7. codeclm/models/__init__.py +11 -0
  8. codeclm/models/builders.py +139 -0
  9. codeclm/models/codeclm.py +303 -0
  10. codeclm/models/levo.py +224 -0
  11. codeclm/models/llama/__init__.py +90 -0
  12. codeclm/models/llama/configuration_llama.py +182 -0
  13. codeclm/models/llama/convert_llama_weights_to_hf.py +318 -0
  14. codeclm/models/llama/modeling_llama.py +1243 -0
  15. codeclm/models/llama/tokenization_llama.py +426 -0
  16. codeclm/models/llama/tokenization_llama_fast.py +264 -0
  17. codeclm/models/lm_levo.py +546 -0
  18. codeclm/modules/conditioners.py +883 -0
  19. codeclm/modules/pattern.py +351 -0
  20. codeclm/modules/streaming.py +112 -0
  21. codeclm/tokenizer/Flow1dVAE/audio.py +304 -0
  22. codeclm/tokenizer/Flow1dVAE/cal_token_stat.py +19 -0
  23. codeclm/tokenizer/Flow1dVAE/compare_model_weight.py +13 -0
  24. codeclm/tokenizer/Flow1dVAE/configs/models/transformer2D_wocross_inch112_1x4_multi_large.json +26 -0
  25. codeclm/tokenizer/Flow1dVAE/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json +14 -0
  26. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py +121 -0
  27. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py +94 -0
  28. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py +70 -0
  29. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py +46 -0
  30. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py +86 -0
  31. codeclm/tokenizer/Flow1dVAE/generate_1rvq.py +283 -0
  32. codeclm/tokenizer/Flow1dVAE/generate_2rvq.py +294 -0
  33. codeclm/tokenizer/Flow1dVAE/generate_4rvq.py +293 -0
  34. codeclm/tokenizer/Flow1dVAE/generate_septoken.py +302 -0
  35. codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py +1278 -0
  36. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py +372 -0
  37. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py +830 -0
  38. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py +994 -0
  39. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py +313 -0
  40. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py +313 -0
  41. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py +313 -0
  42. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py +461 -0
  43. codeclm/tokenizer/Flow1dVAE/libs/fsq/fsq.py +236 -0
  44. codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py +366 -0
  45. codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize.py +268 -0
  46. codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize2.py +290 -0
  47. codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3.py +299 -0
  48. codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer.py +303 -0
  49. codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_freezelayer1.py +301 -0
  50. codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_return_layer.py +305 -0
.gitattributes CHANGED
@@ -1,35 +1,12 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ third_party/demucs/ckpt/htdemucs.pth filter=lfs diff=lfs merge=lfs -text
+ ckpt/100000_dpo.pt filter=lfs diff=lfs merge=lfs -text
+ ckpt/vae/autoencoder_music_1320k.ckpt filter=lfs diff=lfs merge=lfs -text
+ ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e filter=lfs diff=lfs merge=lfs -text
+ codeclm/tokenizer/Flow1dVAE/third_party/wespeaker/voxceleb_resnet34_LM/voxceleb_resnet34_LM.onnx filter=lfs diff=lfs merge=lfs -text
+ codeclm/tokenizer/Flow1dVAE/third_party/wespeaker/voxceleb_resnet34_LM/voxceleb_resnet34_LM.pt filter=lfs diff=lfs merge=lfs -text
+ third_party/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
+ ckpt/60000_alnew.pt filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
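For context, Git LFS rules like the added lines above are normally generated rather than written by hand; a hedged illustration (the tracked patterns are taken from this diff, the command is standard Git LFS usage):

```bash
# Appends matching "filter=lfs diff=lfs merge=lfs -text" rules to .gitattributes
git lfs track "*.wav" "*.mp3" "ckpt/60000_alnew.pt"
```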
.gitignore ADDED
@@ -0,0 +1,3 @@
+ launchs/
+ **/__pycache__
+ sample/generated/
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM juhayna/song-generation-levo:v0.1
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
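As a rough local smoke test of this Dockerfile (the image tag below is an illustrative placeholder; on the Space itself the platform builds and runs the image):

```bash
docker build -t songgeneration-space .
docker run --gpus all -p 7860:7860 songgeneration-space
```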
LICENSE ADDED
@@ -0,0 +1,211 @@
1
+ Tencent is pleased to support the open source community by making SongGeneration available.
2
+
3
+ Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
4
+
5
+ SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
6
+
7
+
8
+ License Terms of SongGeneration:
9
+ --------------------------------------------------------------------
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
12
+
13
+ - You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
14
+
15
+ - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
16
+
17
+ For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components.
18
+
19
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
20
+
21
+
22
+ Other dependencies and licenses:
23
+
24
+
25
+ Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
26
+ --------------------------------------------------------------------
27
+ 1. stable_audio_tools
28
+ Copyright (c) 2023 Stability AI
29
+
30
+
31
+ Terms of the MIT:
32
+ --------------------------------------------------------------------
33
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
34
+
35
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
36
+
37
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
38
+
39
+ For the license of other third party components, please refer to the following URL:
40
+ https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES
41
+
42
+
43
+ Open Source Software Licensed under the MIT License:
44
+ --------------------------------------------------------------------
45
+ 1. demucs
46
+ Copyright (c) Meta Platforms, Inc. and affiliates.
47
+
48
+
49
+ A copy of the MIT is included in this file.
50
+
51
+
52
+
53
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
54
+ --------------------------------------------------------------------
55
+ 1. torch
56
+ From PyTorch:
57
+
58
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
59
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
60
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
61
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
62
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
63
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
64
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
65
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
66
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
67
+
68
+ From Caffe2:
69
+
70
+ Copyright (c) 2016-present, Facebook Inc. All rights reserved.
71
+
72
+ All contributions by Facebook:
73
+ Copyright (c) 2016 Facebook Inc.
74
+
75
+ All contributions by Google:
76
+ Copyright (c) 2015 Google Inc.
77
+ All rights reserved.
78
+
79
+ All contributions by Yangqing Jia:
80
+ Copyright (c) 2015 Yangqing Jia
81
+ All rights reserved.
82
+
83
+ All contributions by Kakao Brain:
84
+ Copyright 2019-2020 Kakao Brain
85
+
86
+ All contributions by Cruise LLC:
87
+ Copyright (c) 2022 Cruise LLC.
88
+ All rights reserved.
89
+
90
+ All contributions from Caffe:
91
+ Copyright(c) 2013, 2014, 2015, the respective contributors
92
+ All rights reserved.
93
+
94
+ All other contributions:
95
+ Copyright(c) 2015, 2016 the respective contributors
96
+ All rights reserved.
97
+
98
+ Caffe2 uses a copyright model similar to Caffe: each contributor holds
99
+ copyright over their contributions to Caffe2. The project versioning records
100
+ all such contribution and copyright details. If a contributor wants to further
101
+ mark their specific copyright on a particular contribution, they should
102
+ indicate their copyright solely in the commit message of the change when it is
103
+ committed.
104
+
105
+ All rights reserved.
106
+
107
+
108
+ Terms of the BSD 3-Clause:
109
+ --------------------------------------------------------------------
110
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
111
+
112
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
113
+
114
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
115
+
116
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
117
+
118
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
119
+
120
+ For the license of other third party components, please refer to the following URL:
121
+ https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
122
+
123
+
124
+ Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein:
125
+ --------------------------------------------------------------------
126
+ 1. torchaudio
127
+ Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
128
+ All rights reserved.
129
+
130
+
131
+ Terms of the BSD 2-Clause:
132
+ --------------------------------------------------------------------
133
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
134
+
135
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
136
+
137
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
138
+
139
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
140
+
141
+ For the license of other third party components, please refer to the following URL:
142
+ https://github.com/pytorch/audio/blob/v2.0.2/LICENSE
143
+
144
+
145
+ Open Source Software License under the Apache License Version 2.0:
146
+ --------------------------------------------------------------------
147
+ 1. huggingface-hub
148
+ Copyright (c) huggingface-hub original author and authors
149
+
150
+ 2. transformers
151
+ Copyright 2018- The Hugging Face team. All rights reserved.
152
+
153
+
154
+ Terms of the Apache License Version 2.0:
155
+ --------------------------------------------------------------------
156
+ Apache License
157
+
158
+ Version 2.0, January 2004
159
+
160
+ http://www.apache.org/licenses/
161
+
162
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
163
+ 1. Definitions.
164
+
165
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
166
+
167
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
168
+
169
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
170
+
171
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
172
+
173
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
174
+
175
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
176
+
177
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
178
+
179
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
180
+
181
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
182
+
183
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
184
+
185
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
186
+
187
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
188
+
189
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
190
+
191
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
192
+
193
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
194
+
195
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
196
+
197
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
198
+
199
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
200
+
201
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
202
+
203
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
204
+
205
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
206
+
207
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
208
+
209
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
210
+
211
+ END OF TERMS AND CONDITIONS
README.md CHANGED
@@ -1,11 +1,68 @@
  ---
- title: SongGeneration LeVo
- emoji: 🏃
+ title: LeVo Song Generation
+ emoji: 🎵
  colorFrom: purple
- colorTo: blue
+ colorTo: gray
  sdk: docker
- pinned: false
- short_description: Demo interface for the LeVo song generation model.
+ app_port: 7860
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # SongGeneration
+
+ This repository is the official code repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. You can find our paper [here](https://arxiv.org/). The demo page is available [here](https://levo-demo.github.io/).
+
+ In this repository, we provide the SongGeneration model, inference scripts, and a checkpoint trained on the Million Song Dataset. Specifically, we have released the model and inference code corresponding to the SFT + auto-DPO version.
+
+ ## Installation
+
+ ### Start from scratch
+ You can install the necessary dependencies using the `requirements.txt` file with Python 3.8.12:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Then install flash attention from the prebuilt wheel:
+
+ ```bash
+ wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl -P /home/
+ pip install /home/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+ ```
+
+ ### Start with docker
+ ```bash
+ docker pull juhayna/song-generation-levo:v0.1
+ docker run -it --gpus all --network=host juhayna/song-generation-levo:v0.1 /bin/bash
+ ```
+
+ ## Inference
+
+ Please note that both folders below must be downloaded completely for the model to load correctly; they are sourced from [here](https://huggingface.co/waytan22/SongGeneration):
+
+ - Save `ckpt` to the root directory
+ - Save `third_party` to the root directory
+
+ Then run inference with the following command:
+
+ ```bash
+ sh generate.sh sample/lyric.jsonl sample/generate
+ ```
+ - Input keys in `sample/lyric.jsonl` (an illustrative example line is shown after this list):
+   - `idx`: name of the generated song file
+   - `descriptions`: text description; can be None or a specification of gender, timbre, genre, mood, instrument and BPM
+   - `prompt_audio_path`: reference audio path; can be None or the path to a 10s song excerpt
+   - `gt_lyric`: lyrics; they must follow the format '[Structure] Text', and the supported structures can be found in `conf/vocab.yaml`
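For orientation, one line of `sample/lyric.jsonl` might look roughly like the following. The key names come from the list above; the values are made-up placeholders, and details such as how empty fields and the separators between `[Structure]` segments are encoded should be checked against the sample file shipped in the repository:

```json
{"idx": "demo_song_01", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "prompt_audio_path": null, "gt_lyric": "[intro-short] ; [verse] Night lights flicker on the street ; [chorus] The warmth of memory remains ; [outro-short]"}
```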
+
+ - Outputs in the `sample/generate` folder:
+   - `audio`: generated audio files
+   - `jsonl`: output jsonl files
+   - `token`: tokens corresponding to the generated audio files
+
+ ## Note
+
+ Since the model is trained on data longer than 1 minute, if the given lyrics are too short, the model will automatically fill in lyrics to extend the duration.
+
+ ## License
+
+ The code and weights in this repository are released under the MIT license as found in the [LICENSE](LICENSE) file.
app.py ADDED
@@ -0,0 +1,140 @@
1
+ import os
2
+ import gradio as gr
3
+ import json
4
+ import numpy as np
5
+ from datetime import datetime
6
+ import os
7
+ import sys
8
+ import librosa
9
+
10
+ EXAMPLE_DESC = """female, dark, pop, sad, piano and drums, the bpm is 125."""
11
+ EXAMPLE_LYRICS = """
12
+ [intro-short]
13
+
14
+ [verse]
15
+ 夜晚的街灯闪烁.
16
+ 我漫步在熟悉的角落.
17
+ 回忆像潮水般涌来.
18
+ 你的笑容如此清晰.
19
+ 在心头无法抹去.
20
+ 那些曾经的甜蜜.
21
+ 如今只剩我独自回忆.
22
+
23
+ [bridge]
24
+ 手机屏幕亮起.
25
+ 是你发来的消息.
26
+ 简单的几个字.
27
+ 却让我泪流满面.
28
+ 曾经的拥抱温暖.
29
+ 如今却变得遥远.
30
+ 我多想回到从前.
31
+ 重新拥有你的陪伴.
32
+
33
+ [chorus]
34
+ 回忆的温度还在.
35
+ 你却已不在.
36
+ 我的心被爱填满.
37
+ 却又被思念刺痛.
38
+ R&B的节奏奏响.
39
+ 我的心却在流浪.
40
+ 没有你的日子.
41
+ 我该如何继续向前.
42
+
43
+ [outro-short]
44
+ """.strip()
45
+
46
+
47
+ # Mock song-generation function
48
+ def generate_song(description, lyric, prompt_audio=None):
49
+ # The generation process is mocked here - replace with your model call in a real deployment
50
+ print(f"Generating song with description: {description}")
51
+ print(f"Lyrics provided: {lyric}")
52
+ if prompt_audio is not None:
53
+ print("Using prompt audio for generation")
54
+
55
+ # Load the example audio from file
56
+ audio_path = "./sample/example.mp3"
57
+ audio_data, sample_rate = librosa.load(audio_path, sr=None) # keep the original sample rate
58
+
59
+
60
+ # Build the input-configuration JSON
61
+ input_config = {
62
+ "description": description,
63
+ "lyric": lyric,
64
+ "has_prompt_audio": prompt_audio is not None,
65
+ "timestamp": datetime.now().isoformat(),
66
+ }
67
+
68
+ return (sample_rate, audio_data), json.dumps(input_config, indent=2)
69
+
70
+ # Build the Gradio interface
71
+ with gr.Blocks(title="LeVo Demo Space") as demo:
72
+ gr.Markdown("# 🎵 LeVo Demo Space")
73
+ gr.Markdown("Demo interface for the LeVo song generation model. Provide a description, lyrics, and optionally an audio prompt to generate a custom song.")
74
+
75
+ with gr.Row():
76
+ with gr.Column():
77
+ description = gr.Textbox(
78
+ label="Song Description",
79
+ placeholder="Describe the style, mood, and characteristics of the song...",
80
+ lines=1,
81
+ max_lines=2,
82
+ value=EXAMPLE_DESC,
83
+ )
84
+ lyric = gr.Textbox(
85
+ label="Lyrics",
86
+ placeholder="Enter the lyrics for the song...",
87
+ lines=5,
88
+ max_lines=8,
89
+ value=EXAMPLE_LYRICS,
90
+ )
91
+
92
+ with gr.Tabs(elem_id="extra-tabs"):
93
+ with gr.Tab("Audio Prompt"):
94
+ prompt_audio = gr.Audio(
95
+ label="Prompt Audio (Optional)",
96
+ type="filepath",
97
+ elem_id="audio-prompt"
98
+ )
99
+ with gr.Tab("Advanced Config"):
100
+ text_prompt = gr.Textbox(
101
+ label="Text Prompt",
102
+ placeholder="Enter the Text Prompt, eg: emotional piano pop",
103
+ )
104
+
105
+ generate_btn = gr.Button("Generate Song", variant="primary")
106
+
107
+ with gr.Column():
108
+ output_audio = gr.Audio(label="Generated Song", type="numpy")
109
+ output_json = gr.JSON(label="Input Configuration")
110
+
111
+ # Example inputs
112
+ examples = gr.Examples(
113
+ examples=[
114
+ ["An uplifting pop song with catchy melodies"],
115
+ ["Melancholic piano ballad"],
116
+ ],
117
+ inputs=[description],
118
+ label="Description examples"
119
+ )
120
+
121
+ examples = gr.Examples(
122
+ examples=[
123
+ ["Shine bright like the stars above\nYou're the one that I'm dreaming of"],
124
+ ["The rain keeps falling on my window pane\nReminding me of love that's gone away"],
125
+ ],
126
+ inputs=[lyric],
127
+ label="Lyrics examples"
128
+ )
129
+
130
+ # Generate-button click event
131
+ generate_btn.click(
132
+ fn=generate_song,
133
+ inputs=[description, lyric, prompt_audio],
134
+ outputs=[output_audio, output_json]
135
+ )
136
+
137
+
138
+ # Launch the app
139
+ if __name__ == "__main__":
140
+ demo.launch()
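Since the Gradio callback above is an ordinary Python function, the mocked pipeline can be exercised without launching the UI. A minimal sketch, assuming `gradio` and `librosa` are installed and `./sample/example.mp3` exists in the working directory:

```python
# Hedged smoke test of the mocked generate_song() defined in app.py above.
from app import generate_song

(sample_rate, audio_data), config_json = generate_song(
    description="female, dark, pop, sad, piano and drums, the bpm is 125.",
    lyric="[verse]\nNight lights flicker on the street",
)
print(sample_rate, audio_data.shape)  # sample rate and mono waveform from librosa.load
print(config_json)                    # JSON echo of the inputs
```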
codeclm/models/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
8
+ """
9
+ # flake8: noqa
10
+ from . import builders
11
+ from .codeclm import CodecLM
codeclm/models/builders.py ADDED
@@ -0,0 +1,139 @@
1
+ """
2
+ All the functions to build the relevant models and modules
3
+ from the Hydra config.
4
+ """
5
+
6
+ import typing as tp
7
+
8
+ import omegaconf
9
+ import torch
10
+ from codeclm.utils.utils import dict_from_config
11
+ from codeclm.modules.pattern import (
12
+ CodebooksPatternProvider,
13
+ DelayedPatternProvider,
14
+ )
15
+ from codeclm.modules.conditioners import (
16
+ BaseConditioner,
17
+ QwTokenizerConditioner,
18
+ QwTextConditioner,
19
+ PhonemeTokenizerConditioner,
20
+ QuantizedEmbeddingConditioner,
21
+ ConditionerProvider,
22
+ ConditionFuser,
23
+ )
24
+
25
+
26
+ def get_audio_tokenizer_model(checkpoint_path: str, cfg: omegaconf.DictConfig):
27
+ from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
28
+ """Instantiate a compression model."""
29
+ if checkpoint_path is None:
30
+ return None
31
+ if checkpoint_path.startswith('//pretrained/'):
32
+ name = checkpoint_path.split('/', 3)[-1]
33
+ return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
34
+ elif checkpoint_path == "":
35
+ return None
36
+ else:
37
+ name = checkpoint_path
38
+ return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
39
+
40
+ def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
41
+ """Instantiate a LM."""
42
+ lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
43
+
44
+ # n_q: number of RVQ
45
+ code_depth = lm_kwargs['code_depth']
46
+ q_modeling = lm_kwargs.pop('q_modeling', None)
47
+
48
+ # conditioner
49
+ condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg)
50
+
51
+ # codebook pattern: delay
52
+ codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
53
+ if codebooks_pattern_cfg.modeling is None:
54
+ assert q_modeling is not None, \
55
+ "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
56
+ codebooks_pattern_cfg = omegaconf.OmegaConf.create(
57
+ {'modeling': q_modeling, 'delay': {'delays': list(range(code_depth))}}
58
+ )
59
+ pattern_provider = get_codebooks_pattern_provider(code_depth, codebooks_pattern_cfg)
60
+
61
+ # condition dropout
62
+ attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
63
+ cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
64
+ cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
65
+
66
+ # condition fuser
67
+ fuser = get_condition_fuser(cfg)
68
+ lm_type = lm_kwargs['lm_type'] # YCY: For consistency, choose different lm.py based on lm_type
69
+ if lm_type == 'Llama':
70
+ from .lm_levo import LmModel
71
+ return LmModel(
72
+ pattern_provider=pattern_provider,
73
+ condition_provider=condition_provider,
74
+ fuser=fuser,
75
+ cfg_dropout=cfg_prob,
76
+ cfg_coef=cfg_coef,
77
+ attribute_dropout=attribute_dropout,
78
+ cfg=cfg,
79
+ **lm_kwargs
80
+ ).to('cpu')
81
+ else:
82
+ raise KeyError(f"Unexpected LM model {lm_type}")
83
+
84
+
85
+ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
86
+ """Instantiate a conditioning model."""
87
+ cfg = getattr(cfg, 'conditioners')
88
+ dict_cfg = {} if cfg is None else dict_from_config(cfg)
89
+ conditioners: tp.Dict[str, BaseConditioner] = {}
90
+ condition_provider_args = dict_cfg.pop('args', {})
91
+
92
+ for cond, cond_cfg in dict_cfg.items():
93
+ model_type = cond_cfg['model']
94
+ model_args = cond_cfg[model_type]
95
+ if model_type == 'QwTokenizer':
96
+ conditioners[str(cond)] = QwTokenizerConditioner(
97
+ output_dim=output_dim,
98
+ **model_args
99
+ )
100
+ elif model_type == "QwTextTokenizer":
101
+ conditioners[str(cond)] = QwTextConditioner(
102
+ output_dim=output_dim,
103
+ **model_args
104
+ )
105
+ elif model_type == 'PhonemeTokenizer':
106
+ conditioners[str(cond)] = PhonemeTokenizerConditioner(
107
+ output_dim=output_dim,
108
+ **model_args
109
+ )
110
+ elif model_type == "qt_embedding":
111
+ conditioners[str(cond)] = QuantizedEmbeddingConditioner(
112
+ dim=output_dim,
113
+ **model_args
114
+ )
115
+ else:
116
+ raise ValueError(f"Unrecognized conditioning model: {model_type}")
117
+ conditioner = ConditionerProvider(conditioners, **condition_provider_args)
118
+ return conditioner
119
+
120
+
121
+ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
122
+ """Instantiate a condition fuser object."""
123
+ fuser_cfg = getattr(cfg, 'fuser')
124
+ fuser_methods = ['sum', 'prepend']
125
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
126
+ kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
127
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
128
+ return fuser
129
+
130
+
131
+ def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
132
+ """Instantiate a codebooks pattern provider object."""
133
+ pattern_providers = {
134
+ 'delay': DelayedPatternProvider,
135
+ }
136
+ name = cfg.modeling
137
+ kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
138
+ klass = pattern_providers[name]
139
+ return klass(code_depth, **kwargs)
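The builders above are driven entirely by an omegaconf/Hydra config. As a rough sketch of the shape `get_lm_model` expects, using only the key names that appear in the lookups above (the values and the single conditioner entry are illustrative placeholders, and the real config must also carry whatever extra fields `LmModel` and each conditioner require):

```python
from omegaconf import OmegaConf

# Hypothetical minimal config skeleton for get_lm_model(); all values are placeholders.
cfg = OmegaConf.create({
    "lm": {
        "lm_type": "Llama",       # routes to lm_levo.LmModel
        "dim": 1024,              # also used as the conditioners' output_dim
        "code_depth": 3,          # number of RVQ token streams
    },
    "codebooks_pattern": {"modeling": "delay", "delay": {"delays": [0, 1, 2]}},
    "attribute_dropout": {},
    "classifier_free_guidance": {"training_dropout": 0.1, "inference_coef": 3.0},
    "conditioners": {
        "description": {"model": "QwTextTokenizer", "QwTextTokenizer": {}},
    },
    "fuser": {"sum": [], "prepend": ["description"]},
})
# lm = get_lm_model(cfg)  # would additionally need the model-specific fields omitted here
```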
codeclm/models/codeclm.py ADDED
@@ -0,0 +1,303 @@
1
+ """
2
+ Main model for using CodecLM. This will combine all the required components
3
+ and provide easy access to the generation API.
4
+ """
5
+
6
+ import typing as tp
7
+ import warnings
8
+
9
+ import torch
10
+
11
+ from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
12
+ from .lm_levo import LmModel
13
+ from ..modules.conditioners import ConditioningAttributes, AudioCondition
14
+ from ..utils.autocast import TorchAutocast
15
+ import torch
16
+ from torch.nn import functional as F
17
+ import torchaudio
18
+ # from optim.ema import EMA
19
+
20
+
21
+ MelodyList = tp.List[tp.Optional[torch.Tensor]]
22
+ MelodyType = tp.Union[torch.Tensor, MelodyList]
23
+
24
+ class CodecLM:
25
+ """CodecLM main model with convenient generation API.
26
+
27
+ Args:
28
+ name (str): name of the model.
29
+ compression_model (CompressionModel): Compression model
30
+ used to map audio to invertible discrete representations.
31
+ lm (LMModel): Language model over discrete representations.
32
+ max_duration (float, optional): maximum duration the model can produce,
33
+ otherwise, inferred from the training params.
34
+ """
35
+ def __init__(self, name: str, audiotokenizer: AudioTokenizer, lm: LmModel,
36
+ max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None):
37
+ self.name = name
38
+ self.audiotokenizer = audiotokenizer
39
+ self.lm = lm
40
+ self.seperate_tokenizer = seperate_tokenizer
41
+ # import pdb; pdb.set_trace()
42
+ if max_duration is None:
43
+ if hasattr(lm, 'cfg'):
44
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
45
+ else:
46
+ raise ValueError("You must provide max_duration when building directly CodecLM")
47
+ assert max_duration is not None
48
+
49
+ self.max_duration: float = max_duration
50
+ self.device = next(iter(lm.parameters())).device
51
+ self.generation_params: dict = {}
52
+ # self.set_generation_params(duration=15) # 15 seconds by default
53
+ self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
54
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
55
+ if self.device.type == 'cpu':
56
+ self.autocast = TorchAutocast(enabled=False)
57
+ else:
58
+ self.autocast = TorchAutocast(enabled=False)
59
+
60
+
61
+
62
+ @property
63
+ def frame_rate(self) -> float:
64
+ """Roughly the number of AR steps per seconds."""
65
+ return self.audiotokenizer.frame_rate
66
+
67
+ @property
68
+ def sample_rate(self) -> int:
69
+ """Sample rate of the generated audio."""
70
+ return self.audiotokenizer.sample_rate
71
+
72
+ @property
73
+ def audio_channels(self) -> int:
74
+ """Audio channels of the generated audio."""
75
+ return self.audiotokenizer.channels
76
+
77
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
78
+ top_p: float = 0.0, temperature: float = 1.0,
79
+ duration: float = 30.0, cfg_coef: float = 3.0,
80
+ extend_stride: float = 18, record_tokens: bool = False,
81
+ record_window: int = 50):
82
+ """Set the generation parameters for CodecLM.
83
+
84
+ Args:
85
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
86
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
87
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
88
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
89
+ duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
90
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
91
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
92
+ instead of batching together the two. This has some impact on how things
93
+ are padded but seems to have little impact in practice.
94
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
95
+ should we extend the audio each time. Larger values will mean less context is
96
+ preserved, and shorter value will require extra computations.
97
+ """
98
+ assert extend_stride <= self.max_duration, "Cannot stride by more than max generation duration."
99
+ self.extend_stride = extend_stride
100
+ self.duration = duration
101
+ self.generation_params = {
102
+ 'use_sampling': use_sampling,
103
+ 'temp': temperature,
104
+ 'top_k': top_k,
105
+ 'top_p': top_p,
106
+ 'cfg_coef': cfg_coef,
107
+ 'record_tokens': record_tokens,
108
+ 'record_window': record_window,
109
+ }
110
+
111
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
112
+ """Override the default progress callback."""
113
+ self._progress_callback = progress_callback
114
+
115
+ # Inference
116
+ def generate(self, lyrics: tp.List[str],
117
+ descriptions: tp.List[str],
118
+ melody_wavs: torch.Tensor = None,
119
+ melody_is_wav: bool = True,
120
+ vocal_wavs: torch.Tensor = None,
121
+ bgm_wavs: torch.Tensor = None,
122
+ return_tokens: bool = False,
123
+ ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
124
+ """Generate samples conditioned on text and melody.
125
+
126
+ Args:
127
+ descriptions (list of str): A list of strings used as text conditioning.
128
+ melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
129
+ melody conditioning. Should have shape [B, C, T] with B matching the description length,
130
+ C=1 or 2. It can be [C, T] if there is a single description. It can also be
131
+ a list of [C, T] tensors.
132
+ melody_sample_rate: (int): Sample rate of the melody waveforms.
133
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
134
+ """
135
+ if melody_wavs is not None:
136
+ if melody_wavs.dim() == 2:
137
+ melody_wavs = melody_wavs[None]
138
+ if melody_wavs.dim() != 3:
139
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
140
+ melody_wavs = list(melody_wavs)
141
+ if vocal_wavs is not None:
142
+ if vocal_wavs.dim() == 2:
143
+ vocal_wavs = vocal_wavs[None]
144
+ if vocal_wavs.dim() != 3:
145
+ raise ValueError("Vocal wavs should have a shape [B, C, T].")
146
+ vocal_wavs = list(vocal_wavs)
147
+ if bgm_wavs is not None:
148
+ if bgm_wavs.dim() == 2:
149
+ bgm_wavs = bgm_wavs[None]
150
+ if bgm_wavs.dim() != 3:
151
+ raise ValueError("BGM wavs should have a shape [B, C, T].")
152
+ bgm_wavs = list(bgm_wavs)
153
+
154
+ texts, audio_qt_embs = self._prepare_tokens_and_attributes(lyrics=lyrics, melody_wavs=melody_wavs, vocal_wavs=vocal_wavs, bgm_wavs=bgm_wavs, melody_is_wav=melody_is_wav)
155
+ tokens = self._generate_tokens(texts, descriptions, audio_qt_embs)
156
+
157
+ if (tokens == self.lm.eos_token_id).any():
158
+ length = torch.nonzero(torch.eq(tokens, self.lm.eos_token_id))[:,-1].min()
159
+ tokens = tokens[...,:length]
160
+
161
+ if return_tokens:
162
+ return tokens
163
+ else:
164
+ out = self.generate_audio(tokens)
165
+ return out
166
+
167
+
168
+ @torch.no_grad()
169
+ def _prepare_tokens_and_attributes(
170
+ self,
171
+ lyrics: tp.Sequence[tp.Optional[str]],
172
+ melody_wavs: tp.Optional[MelodyList] = None,
173
+ vocal_wavs: tp.Optional[MelodyList] = None,
174
+ bgm_wavs: tp.Optional[MelodyList] = None,
175
+ melody_is_wav = True
176
+ ) -> tp.Tuple[tp.List[str], tp.List[torch.Tensor]]:
177
+ """Prepare model inputs.
178
+
179
+ Args:
180
+ descriptions (list of str): A list of strings used as text conditioning.
181
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
182
+ melody_wavs (torch.Tensor, optional): A batch of waveforms
183
+ used as melody conditioning. Defaults to None.
184
+ """
185
+ assert len(lyrics) == 1
186
+ texts = [lyric for lyric in lyrics]
187
+ audio_qt_embs = []
188
+ target_melody_token_len = self.lm.cfg.prompt_len * self.audiotokenizer.frame_rate
189
+ # import pdb; pdb.set_trace()
190
+ if melody_wavs is None:
191
+ melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
192
+ elif melody_wavs is not None:
193
+ if 'prompt_audio' not in self.lm.condition_provider.conditioners:
194
+ raise RuntimeError("This model doesn't support melody conditioning. "
195
+ "Use the `melody` model.")
196
+ assert len(melody_wavs) == len(texts), \
197
+ f"number of melody wavs must match number of descriptions! " \
198
+ f"got melody len={len(melody_wavs)}, and descriptions len={len(texts)}"
199
+ if type(melody_wavs) == list:
200
+ melody_wavs = torch.stack(melody_wavs, dim=0)
201
+ melody_wavs = melody_wavs.to(self.device)
202
+ if melody_is_wav:
203
+ melody_tokens, scale = self.audiotokenizer.encode(melody_wavs)
204
+ else:
205
+ melody_tokens = melody_wavs
206
+ if melody_tokens.shape[-1] > target_melody_token_len:
207
+ melody_tokens = melody_tokens[...,:target_melody_token_len]
208
+ elif melody_tokens.shape[-1] < target_melody_token_len:
209
+ melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
210
+ if self.seperate_tokenizer is not None:
211
+ if vocal_wavs is not None:
212
+ if type(vocal_wavs) == list:
213
+ vocal_wavs = torch.stack(vocal_wavs, dim=0)
214
+ if bgm_wavs is None:
215
+ use_bgm = False
216
+ bgm_wavs = torch.zeros_like(vocal_wavs)
217
+ bgm_wavs[:, 0] = 1.0
218
+ bgm_wavs[:, 1:] = torch.randn_like(bgm_wavs[:, 1:])* 0.0003
219
+ else:
220
+ use_bgm = True
221
+ if type(bgm_wavs) == list:
222
+ bgm_wavs = torch.stack(bgm_wavs, dim=0)
223
+ vocal_wavs = vocal_wavs.to(self.device)
224
+ bgm_wavs = bgm_wavs.to(self.device)
225
+ vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
226
+ assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
227
+ f"vocal and bgm tokens should have a shape [B, C, T]! " \
228
+ f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
229
+ assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
230
+ f"vocal and bgm tokens should have the same length! " \
231
+ f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
232
+ if not use_bgm:
233
+ bgm_tokens = torch.full_like(bgm_tokens, 16385)
234
+ if bgm_tokens.shape[-1] > target_melody_token_len:
235
+ bgm_tokens = bgm_tokens[...,:target_melody_token_len]
236
+ elif bgm_tokens.shape[-1] < target_melody_token_len:
237
+ bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
238
+ if vocal_tokens.shape[-1] > target_melody_token_len:
239
+ vocal_tokens = vocal_tokens[...,:target_melody_token_len]
240
+ elif vocal_tokens.shape[-1] < target_melody_token_len:
241
+ vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
242
+ else:
243
+ bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
244
+ vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
245
+
246
+ melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
247
+ assert melody_tokens.shape[-1] == target_melody_token_len
248
+ audio_qt_embs = melody_tokens.long()
249
+ return texts, audio_qt_embs
250
+
251
+
252
+
253
+ def _generate_tokens(self,
254
+ texts: tp.Optional[tp.List[str]] = None,
255
+ descriptions: tp.Optional[tp.List[str]] = None,
256
+ audio_qt_embs: tp.Optional[tp.List[torch.Tensor]] = None) -> torch.Tensor:
257
+ """Generate discrete audio tokens given audio prompt and/or conditions.
258
+
259
+ Args:
260
+ attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
261
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
262
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
263
+ Returns:
264
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
265
+ """
266
+ total_gen_len = int(self.duration * self.frame_rate)
267
+ current_gen_offset: int = 0
268
+
269
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
270
+ generated_tokens += current_gen_offset
271
+ if self._progress_callback is not None:
272
+ # Note that total_gen_len might be quite wrong depending on the
273
+ # codebook pattern used, but with delay it is almost accurate.
274
+ self._progress_callback(generated_tokens, total_gen_len)
275
+ else:
276
+ print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
277
+
278
+ if self.duration <= self.max_duration:
279
+ # generate by sampling from LM, simple case.
280
+ with self.autocast:
281
+ gen_tokens = self.lm.generate(texts=texts,
282
+ descriptions=descriptions,
283
+ audio_qt_embs=audio_qt_embs,
284
+ max_gen_len=total_gen_len,
285
+ **self.generation_params)
286
+ else:
287
+ raise NotImplementedError(f"duration {self.duration} exceeds max duration {self.max_duration}")
288
+ return gen_tokens
289
+
290
+ @torch.no_grad()
291
+ def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None):
292
+ """Generate Audio from tokens"""
293
+ assert gen_tokens.dim() == 3
294
+ if self.seperate_tokenizer is not None:
295
+ gen_tokens_song = gen_tokens[:, [0], :]
296
+ gen_tokens_vocal = gen_tokens[:, [1], :]
297
+ gen_tokens_bgm = gen_tokens[:, [2], :]
298
+ # gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt)
299
+ gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt)
300
+ return gen_audio_seperate
301
+ else:
302
+ gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
303
+ return gen_audio
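Putting the pieces of this class together, a hedged end-to-end sketch of the generation API (the config path, checkpoint paths, and `max_duration` value are placeholders; the real entry point is `generate.sh`, which also loads trained weights):

```python
import torch
from omegaconf import OmegaConf
from codeclm.models.builders import get_lm_model, get_audio_tokenizer_model
from codeclm.models.codeclm import CodecLM

cfg = OmegaConf.load("conf/levo.yaml")  # placeholder config path
lm = get_lm_model(cfg)
tokenizer = get_audio_tokenizer_model("path/to/audio_tokenizer.ckpt", cfg)          # placeholder
sep_tokenizer = get_audio_tokenizer_model("path/to/separate_tokenizer.ckpt", cfg)   # placeholder

model = CodecLM("levo", audiotokenizer=tokenizer, lm=lm,
                max_duration=120, seperate_tokenizer=sep_tokenizer)
model.set_generation_params(duration=60, extend_stride=30, top_k=50,
                            temperature=0.9, cfg_coef=1.5)

with torch.no_grad():
    # generate() expects a single lyric string per call (see the assert above)
    tokens = model.generate(lyrics=["[verse] ..."], descriptions=["emotional piano pop"],
                            return_tokens=True)
    audio = model.generate_audio(tokens)  # waveform from the separate (vocal/bgm) tokenizer
```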
codeclm/models/levo.py ADDED
@@ -0,0 +1,224 @@
1
+
2
+ from .llama.modeling_llama import LlamaConfig, CausalLMOutputWithPast, BaseModelOutputWithPast, LlamaDecoderLayer, LlamaRMSNorm
3
+ from .llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLM_base
4
+ from .llama.modeling_llama import LlamaModel as LlamaModel_base
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Union, Optional, Tuple, List
9
+ from packaging import version
10
+ import transformers
11
+ """
12
+ Wrap the original Llama model for potential customized changes.
13
+ """
14
+
15
+ """main class"""
16
+ class CausalLM(LlamaForCausalLM_base):
17
+ def __init__(self, config):
18
+ super().__init__(config)
19
+ self.model = LmModel(config)
20
+ self.vocab_size = config.vocab_size
21
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
22
+
23
+ def forward(
24
+ self,
25
+ input_ids: torch.LongTensor = None,
26
+ attention_mask: Optional[torch.Tensor] = None,
27
+ position_ids: Optional[torch.LongTensor] = None,
28
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
29
+ inputs_embeds: Optional[torch.FloatTensor] = None,
30
+ labels: Optional[torch.LongTensor] = None,
31
+ use_cache: Optional[bool] = None,
32
+ output_attentions: Optional[bool] = None,
33
+ output_hidden_states: Optional[bool] = None,
34
+ return_dict: Optional[bool] = None,
35
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
36
+
37
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
38
+ output_hidden_states = (
39
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
40
+ )
41
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
42
+
43
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
44
+ outputs = self.model(
45
+ input_ids=input_ids,
46
+ attention_mask=attention_mask,
47
+ position_ids=position_ids,
48
+ past_key_values=past_key_values,
49
+ inputs_embeds=inputs_embeds,
50
+ use_cache=use_cache,
51
+ output_attentions=output_attentions,
52
+ output_hidden_states=output_hidden_states,
53
+ return_dict=return_dict,
54
+ )
55
+
56
+ hidden_states = outputs[0]
57
+ if self.config.pretraining_tp > 1:
58
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
59
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
60
+ logits = torch.cat(logits, dim=-1)
61
+ else:
62
+ logits = self.lm_head(hidden_states)
63
+ logits = logits.float()
64
+
65
+ loss = None
66
+ if labels is not None:
67
+ # Shift so that tokens < n predict n
68
+ shift_logits = logits[..., :-1, :].contiguous()
69
+ shift_labels = labels[..., 1:].contiguous()
70
+ # Flatten the tokens
71
+ loss_fct = nn.CrossEntropyLoss()
72
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
73
+ shift_labels = shift_labels.view(-1)
74
+ # Enable model parallelism
75
+ shift_labels = shift_labels.to(shift_logits.device)
76
+ loss = loss_fct(shift_logits, shift_labels)
77
+
78
+ if not return_dict:
79
+ output = (logits,) + outputs[1:]
80
+ return (loss,) + output if loss is not None else output
81
+
82
+ return CausalLMOutputWithPast(
83
+ loss=loss,
84
+ logits=logits,
85
+ past_key_values=outputs.past_key_values,
86
+ hidden_states=hidden_states,
87
+ attentions=outputs.attentions,
88
+ )
89
+
90
+
91
+ """Submodel class"""
92
+ class LmModel(LlamaModel_base):
93
+ def __init__(self, config: LlamaConfig):
94
+ super().__init__(config)
95
+ self.padding_idx = config.pad_token_id
96
+ self.vocab_size = config.vocab_size
97
+ layer_cls = LlamaDecoderLayer # cross attention decoder layer can be overwritten here
98
+
99
+ assert version.parse(transformers.__version__) < version.parse("4.40")
100
+
101
+ self.layers = nn.ModuleList([layer_cls(config) for _ in range(config.num_hidden_layers)])
102
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
103
+
104
+ self.gradient_checkpointing = False
105
+ # Initialize weights and apply final processing
106
+ self.post_init()
107
+ self.gradient_checkpointing_disable()
108
+
109
+ def forward(
110
+ self,
111
+ input_ids: torch.LongTensor = None,
112
+ attention_mask: Optional[torch.Tensor] = None,
113
+ position_ids: Optional[torch.LongTensor] = None,
114
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
115
+ inputs_embeds: Optional[torch.FloatTensor] = None,
116
+ use_cache: Optional[bool] = None,
117
+ output_attentions: Optional[bool] = None,
118
+ output_hidden_states: Optional[bool] = None,
119
+ return_dict: Optional[bool] = None,
120
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
121
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
122
+ output_hidden_states = (
123
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
124
+ )
125
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
126
+
127
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
128
+
129
+ # retrieve input_ids and inputs_embeds
130
+ if input_ids is not None and inputs_embeds is not None:
131
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
132
+ elif input_ids is not None:
133
+ batch_size, seq_length = input_ids.shape
134
+ elif inputs_embeds is not None:
135
+ batch_size, seq_length, _ = inputs_embeds.shape
136
+ else:
137
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
138
+
139
+ seq_length_with_past = seq_length
140
+ past_key_values_length = 0
141
+
142
+ if past_key_values is not None:
143
+ past_key_values_length = past_key_values[0][0].shape[2]
144
+ seq_length_with_past = seq_length_with_past + past_key_values_length
145
+
146
+ if position_ids is None:
147
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
148
+ position_ids = torch.arange(
149
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
150
+ )
151
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
152
+ else:
153
+ position_ids = position_ids.view(-1, seq_length).long()
154
+
155
+ # embed positions
156
+ if attention_mask is None:
157
+ attention_mask = torch.ones(
158
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
159
+ )
160
+ attention_mask = self._prepare_decoder_attention_mask(
161
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
162
+ )
163
+
164
+ hidden_states = inputs_embeds
165
+
166
+ if self.gradient_checkpointing and self.training:
167
+ if use_cache:
168
+ use_cache = False
169
+
170
+ # decoder layers
171
+ all_hidden_states = () if output_hidden_states else None
172
+ all_self_attns = () if output_attentions else None
173
+ next_decoder_cache = () if use_cache else None
174
+
175
+ for idx, decoder_layer in enumerate(self.layers):
176
+ if output_hidden_states:
177
+ all_hidden_states += (hidden_states,)
178
+
179
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
180
+
181
+ layer_args = (hidden_states, attention_mask, position_ids,)
182
+
183
+ if self.gradient_checkpointing and self.training:
184
+
185
+ def create_custom_forward(module):
186
+ def custom_forward(*inputs):
187
+ # pass past_key_value and output_attentions through to the layer
188
+ return module(*inputs, past_key_value, output_attentions)
189
+
190
+ return custom_forward
191
+ layer_outputs = torch.utils.checkpoint.checkpoint(
192
+ create_custom_forward(decoder_layer), *layer_args
193
+ )
194
+ else:
195
+
196
+ layer_outputs = decoder_layer(*layer_args,
197
+ past_key_value=past_key_value,
198
+ output_attentions=output_attentions,
199
+ use_cache=use_cache)
200
+
201
+ hidden_states = layer_outputs[0]
202
+
203
+ if use_cache:
204
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
205
+
206
+ if output_attentions:
207
+ all_self_attns += (layer_outputs[1],)
208
+
209
+ hidden_states = self.norm(hidden_states)
210
+
211
+ # add hidden states from the last decoder layer
212
+ if output_hidden_states:
213
+ all_hidden_states += (hidden_states,)
214
+
215
+ next_cache = next_decoder_cache if use_cache else None
216
+ if not return_dict:
217
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
218
+ return BaseModelOutputWithPast(
219
+ last_hidden_state=hidden_states,
220
+ past_key_values=next_cache,
221
+ hidden_states=all_hidden_states,
222
+ attentions=all_self_attns,
223
+ )
224
+
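Editorial sketch (not part of the commit): `CausalLM`/`LmModel` wrap the vendored Llama classes, and `LmModel.forward` uses `inputs_embeds` directly as the hidden states (there is no internal token-embedding lookup), so callers supply pre-computed embeddings. A toy forward pass under that assumption, with made-up sizes and a `transformers` version satisfying the `< 4.40` assertion above:

```python
import torch
from codeclm.models.levo import CausalLM
from codeclm.models.llama.configuration_llama import LlamaConfig

# Tiny, made-up hyperparameters just to exercise the wrapper.
config = LlamaConfig(
    vocab_size=256, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, max_position_embeddings=128,
)
model = CausalLM(config).eval()

B, T = 2, 16
inputs_embeds = torch.randn(B, T, config.hidden_size)  # pre-computed token/condition embeddings
with torch.no_grad():
    out = model(inputs_embeds=inputs_embeds, use_cache=False)
print(out.logits.shape)  # torch.Size([2, 16, 256])
```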
codeclm/models/llama/__init__.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import (
17
+ OptionalDependencyNotAvailable,
18
+ _LazyModule,
19
+ is_sentencepiece_available,
20
+ is_tokenizers_available,
21
+ is_torch_available,
22
+ )
23
+
24
+
25
+ _import_structure = {
26
+ "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"],
27
+ }
28
+
29
+ try:
30
+ if not is_sentencepiece_available():
31
+ raise OptionalDependencyNotAvailable()
32
+ except OptionalDependencyNotAvailable:
33
+ pass
34
+ else:
35
+ _import_structure["tokenization_llama"] = ["LlamaTokenizer"]
36
+
37
+ try:
38
+ if not is_tokenizers_available():
39
+ raise OptionalDependencyNotAvailable()
40
+ except OptionalDependencyNotAvailable:
41
+ pass
42
+ else:
43
+ _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
44
+
45
+ try:
46
+ if not is_torch_available():
47
+ raise OptionalDependencyNotAvailable()
48
+ except OptionalDependencyNotAvailable:
49
+ pass
50
+ else:
51
+ _import_structure["modeling_llama"] = [
52
+ "LlamaForCausalLM",
53
+ "LlamaModel",
54
+ "LlamaPreTrainedModel",
55
+ "LlamaForSequenceClassification",
56
+ ]
57
+
58
+
59
+ if TYPE_CHECKING:
60
+ from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
61
+
62
+ try:
63
+ if not is_sentencepiece_available():
64
+ raise OptionalDependencyNotAvailable()
65
+ except OptionalDependencyNotAvailable:
66
+ pass
67
+ else:
68
+ from .tokenization_llama import LlamaTokenizer
69
+
70
+ try:
71
+ if not is_tokenizers_available():
72
+ raise OptionalDependencyNotAvailable()
73
+ except OptionalDependencyNotAvailable:
74
+ pass
75
+ else:
76
+ from .tokenization_llama_fast import LlamaTokenizerFast
77
+
78
+ try:
79
+ if not is_torch_available():
80
+ raise OptionalDependencyNotAvailable()
81
+ except OptionalDependencyNotAvailable:
82
+ pass
83
+ else:
84
+ from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
85
+
86
+
87
+ else:
88
+ import sys
89
+
90
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
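Editorial note (not part of the commit): the `_LazyModule` registration above mirrors upstream `transformers`, so the vendored package only materialises a submodule when one of its attributes is first accessed, e.g.:

```python
from codeclm.models.llama import LlamaConfig        # loads configuration_llama on access
from codeclm.models.llama import LlamaForCausalLM   # loads modeling_llama (requires torch)
```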
codeclm/models/llama/configuration_llama.py ADDED
@@ -0,0 +1,182 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
29
+
30
+
31
+ class LlamaConfig(PretrainedConfig):
32
+ r"""
33
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
34
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
35
+ defaults will yield a similar configuration to that of the LLaMA-7B.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 32000):
43
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling [`LlamaModel`]
45
+ hidden_size (`int`, *optional*, defaults to 4096):
46
+ Dimension of the hidden representations.
47
+ intermediate_size (`int`, *optional*, defaults to 11008):
48
+ Dimension of the MLP representations.
49
+ num_hidden_layers (`int`, *optional*, defaults to 32):
50
+ Number of hidden layers in the Transformer encoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 32):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ num_key_value_heads (`int`, *optional*):
54
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
55
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
56
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
57
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
58
+ by meanpooling all the original heads within that group. For more details checkout [this
59
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
60
+ `num_attention_heads`.
61
+ pretraining_tp (`int`, *optional*, defaults to `1`):
62
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
63
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
64
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
65
+ issue](https://github.com/pytorch/pytorch/issues/76232).
66
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
67
+ The non-linear activation function (function or string) in the decoder.
68
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
69
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
70
+ Llama 2 up to 4096, CodeLlama up to 16384.
71
+ initializer_range (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
74
+ The epsilon used by the rms normalization layers.
75
+ use_cache (`bool`, *optional*, defaults to `True`):
76
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
77
+ relevant if `config.is_decoder=True`.
78
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
79
+ Whether to tie weight embeddings
80
+ rope_theta (`float`, *optional*, defaults to 10000.0):
81
+ The base period of the RoPE embeddings.
82
+ rope_scaling (`Dict`, *optional*):
83
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
84
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
85
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
86
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
87
+ these scaling strategies behave:
88
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
89
+ experimental feature, subject to breaking API changes in future versions.
90
+ attention_bias (`bool`, defaults to `False`):
91
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
92
+
93
+ Example:
94
+
95
+ ```python
96
+ >>> from transformers import LlamaModel, LlamaConfig
97
+
98
+ >>> # Initializing a LLaMA llama-7b style configuration
99
+ >>> configuration = LlamaConfig()
100
+
101
+ >>> # Initializing a model from the llama-7b style configuration
102
+ >>> model = LlamaModel(configuration)
103
+
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+ model_type = "llama"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+
110
+ def __init__(
111
+ self,
112
+ vocab_size=32000,
113
+ hidden_size=4096,
114
+ intermediate_size=11008,
115
+ num_hidden_layers=32,
116
+ num_attention_heads=32,
117
+ num_key_value_heads=None,
118
+ hidden_act="silu",
119
+ max_position_embeddings=2048,
120
+ initializer_range=0.02,
121
+ rms_norm_eps=1e-6,
122
+ use_cache=True,
123
+ pad_token_id=None,
124
+ bos_token_id=1,
125
+ eos_token_id=2,
126
+ pretraining_tp=1,
127
+ tie_word_embeddings=False,
128
+ rope_theta=10000.0,
129
+ rope_scaling=None,
130
+ attention_bias=False,
131
+ **kwargs,
132
+ ):
133
+ self.vocab_size = vocab_size
134
+ self.max_position_embeddings = max_position_embeddings
135
+ self.hidden_size = hidden_size
136
+ self.intermediate_size = intermediate_size
137
+ self.num_hidden_layers = num_hidden_layers
138
+ self.num_attention_heads = num_attention_heads
139
+
140
+ # for backward compatibility
141
+ if num_key_value_heads is None:
142
+ num_key_value_heads = num_attention_heads
143
+
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.hidden_act = hidden_act
146
+ self.initializer_range = initializer_range
147
+ self.rms_norm_eps = rms_norm_eps
148
+ self.pretraining_tp = pretraining_tp
149
+ self.use_cache = use_cache
150
+ self.rope_theta = rope_theta
151
+ self.rope_scaling = rope_scaling
152
+ self._rope_scaling_validation()
153
+ self.attention_bias = attention_bias
154
+
155
+ super().__init__(
156
+ pad_token_id=pad_token_id,
157
+ bos_token_id=bos_token_id,
158
+ eos_token_id=eos_token_id,
159
+ tie_word_embeddings=tie_word_embeddings,
160
+ **kwargs,
161
+ )
162
+
163
+ def _rope_scaling_validation(self):
164
+ """
165
+ Validate the `rope_scaling` configuration.
166
+ """
167
+ if self.rope_scaling is None:
168
+ return
169
+
170
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
171
+ raise ValueError(
172
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
173
+ f"got {self.rope_scaling}"
174
+ )
175
+ rope_scaling_type = self.rope_scaling.get("type", None)
176
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
177
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
178
+ raise ValueError(
179
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
180
+ )
181
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
182
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
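Editorial sketch (not part of the commit): a `rope_scaling` value that passes `_rope_scaling_validation` above. The dict must contain exactly the `type` and `factor` keys, `type` must be `"linear"` or `"dynamic"`, and `factor` must be a float strictly greater than 1; per the docstring, `max_position_embeddings` stays at its trained value.

```python
from codeclm.models.llama.configuration_llama import LlamaConfig

cfg = LlamaConfig(
    max_position_embeddings=2048,                     # keep the trained context length
    rope_scaling={"type": "dynamic", "factor": 2.0},  # NTK-style scaling applied at runtime
)

# Anything else raises, e.g. an integer factor:
# LlamaConfig(rope_scaling={"type": "linear", "factor": 2})  -> ValueError
```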
codeclm/models/llama/convert_llama_weights_to_hf.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import gc
16
+ import json
17
+ import os
18
+ import shutil
19
+ import warnings
20
+
21
+ import torch
22
+
23
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
24
+
25
+
26
+ try:
27
+ from transformers import LlamaTokenizerFast
28
+ except ImportError as e:
29
+ warnings.warn(e)
30
+ warnings.warn(
31
+ "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
32
+ )
33
+ LlamaTokenizerFast = None
34
+
35
+ """
36
+ Sample usage:
37
+
38
+ ```
39
+ python src/transformers/models/llama/convert_llama_weights_to_hf.py \
40
+ --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
41
+ ```
42
+
43
+ Thereafter, models can be loaded via:
44
+
45
+ ```py
46
+ from transformers import LlamaForCausalLM, LlamaTokenizer
47
+
48
+ model = LlamaForCausalLM.from_pretrained("/output/path")
49
+ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
50
+ ```
51
+
52
+ Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
53
+ come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
54
+ """
55
+
56
+ NUM_SHARDS = {
57
+ "7B": 1,
58
+ "7Bf": 1,
59
+ "13B": 2,
60
+ "13Bf": 2,
61
+ "34B": 4,
62
+ "30B": 4,
63
+ "65B": 8,
64
+ "70B": 8,
65
+ "70Bf": 8,
66
+ }
67
+
68
+
69
+ def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
70
+ return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
71
+
72
+
73
+ def read_json(path):
74
+ with open(path, "r") as f:
75
+ return json.load(f)
76
+
77
+
78
+ def write_json(text, path):
79
+ with open(path, "w") as f:
80
+ json.dump(text, f)
81
+
82
+
83
+ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
84
+ # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
85
+ if not os.path.isfile(os.path.join(input_base_path, "params.json")):
86
+ input_base_path = os.path.join(input_base_path, model_size)
87
+
88
+ os.makedirs(model_path, exist_ok=True)
89
+ tmp_model_path = os.path.join(model_path, "tmp")
90
+ os.makedirs(tmp_model_path, exist_ok=True)
91
+
92
+ params = read_json(os.path.join(input_base_path, "params.json"))
93
+ num_shards = NUM_SHARDS[model_size]
94
+ n_layers = params["n_layers"]
95
+ n_heads = params["n_heads"]
96
+ n_heads_per_shard = n_heads // num_shards
97
+ dim = params["dim"]
98
+ dims_per_head = dim // n_heads
99
+ base = params.get("rope_theta", 10000.0)
100
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
101
+ if base > 10000.0:
102
+ max_position_embeddings = 16384
103
+ else:
104
+ max_position_embeddings = 2048
105
+
106
+ tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
107
+ if tokenizer_path is not None:
108
+ tokenizer = tokenizer_class(tokenizer_path)
109
+ tokenizer.save_pretrained(model_path)
110
+ vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
111
+
112
+ if "n_kv_heads" in params:
113
+ num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
114
+ num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
115
+ key_value_dim = dim // num_key_value_heads
116
+ else: # compatibility with other checkpoints
117
+ num_key_value_heads = n_heads
118
+ num_local_key_value_heads = n_heads_per_shard
119
+ key_value_dim = dim
120
+
121
+ # permute for sliced rotary
122
+ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
123
+ return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
124
+
125
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
126
+ # Load weights
127
+ if model_size == "7B":
128
+ # Not sharded
129
+ # (The sharded implementation would also work, but this is simpler.)
130
+ loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
131
+ else:
132
+ # Sharded
133
+ loaded = [
134
+ torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
135
+ for i in range(num_shards)
136
+ ]
137
+ param_count = 0
138
+ index_dict = {"weight_map": {}}
139
+ for layer_i in range(n_layers):
140
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
141
+ if model_size == "7B":
142
+ # Unsharded
143
+ state_dict = {
144
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
145
+ loaded[f"layers.{layer_i}.attention.wq.weight"]
146
+ ),
147
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
148
+ loaded[f"layers.{layer_i}.attention.wk.weight"]
149
+ ),
150
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
151
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
152
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
153
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
154
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
155
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
156
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
157
+ }
158
+ else:
159
+ # Sharded
160
+ # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
161
+ # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
162
+ # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
163
+
164
+ state_dict = {
165
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
166
+ f"layers.{layer_i}.attention_norm.weight"
167
+ ].clone(),
168
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
169
+ f"layers.{layer_i}.ffn_norm.weight"
170
+ ].clone(),
171
+ }
172
+ state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
173
+ torch.cat(
174
+ [
175
+ loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
176
+ for i in range(num_shards)
177
+ ],
178
+ dim=0,
179
+ ).reshape(dim, dim)
180
+ )
181
+ state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
182
+ torch.cat(
183
+ [
184
+ loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
185
+ num_local_key_value_heads, dims_per_head, dim
186
+ )
187
+ for i in range(num_shards)
188
+ ],
189
+ dim=0,
190
+ ).reshape(key_value_dim, dim),
191
+ num_key_value_heads,
192
+ key_value_dim,
193
+ dim,
194
+ )
195
+ state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
196
+ [
197
+ loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
198
+ num_local_key_value_heads, dims_per_head, dim
199
+ )
200
+ for i in range(num_shards)
201
+ ],
202
+ dim=0,
203
+ ).reshape(key_value_dim, dim)
204
+
205
+ state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
206
+ [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
207
+ )
208
+ state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
209
+ [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
210
+ )
211
+ state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
212
+ [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
213
+ )
214
+ state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
215
+ [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
216
+ )
217
+
218
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
219
+ for k, v in state_dict.items():
220
+ index_dict["weight_map"][k] = filename
221
+ param_count += v.numel()
222
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
223
+
224
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
225
+ if model_size == "7B":
226
+ # Unsharded
227
+ state_dict = {
228
+ "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
229
+ "model.norm.weight": loaded["norm.weight"],
230
+ "lm_head.weight": loaded["output.weight"],
231
+ }
232
+ else:
233
+ state_dict = {
234
+ "model.norm.weight": loaded[0]["norm.weight"],
235
+ "model.embed_tokens.weight": torch.cat(
236
+ [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
237
+ ),
238
+ "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
239
+ }
240
+
241
+ for k, v in state_dict.items():
242
+ index_dict["weight_map"][k] = filename
243
+ param_count += v.numel()
244
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
245
+
246
+ # Write configs
247
+ index_dict["metadata"] = {"total_size": param_count * 2}
248
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
249
+ ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
250
+ multiple_of = params["multiple_of"] if "multiple_of" in params else 256
251
+ config = LlamaConfig(
252
+ hidden_size=dim,
253
+ intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
254
+ num_attention_heads=params["n_heads"],
255
+ num_hidden_layers=params["n_layers"],
256
+ rms_norm_eps=params["norm_eps"],
257
+ num_key_value_heads=num_key_value_heads,
258
+ vocab_size=vocab_size,
259
+ rope_theta=base,
260
+ max_position_embeddings=max_position_embeddings,
261
+ )
262
+ config.save_pretrained(tmp_model_path)
263
+
264
+ # Make space so we can load the model properly now.
265
+ del state_dict
266
+ del loaded
267
+ gc.collect()
268
+
269
+ print("Loading the checkpoint in a Llama model.")
270
+ model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
271
+ # Avoid saving this as part of the config.
272
+ del model.config._name_or_path
273
+ model.config.torch_dtype = torch.float16
274
+ print("Saving in the Transformers format.")
275
+ model.save_pretrained(model_path, safe_serialization=safe_serialization)
276
+ shutil.rmtree(tmp_model_path)
277
+
278
+
279
+ def write_tokenizer(tokenizer_path, input_tokenizer_path):
280
+ # Initialize the tokenizer based on the `spm` model
281
+ tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
282
+ print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
283
+ tokenizer = tokenizer_class(input_tokenizer_path)
284
+ tokenizer.save_pretrained(tokenizer_path)
285
+
286
+
287
+ def main():
288
+ parser = argparse.ArgumentParser()
289
+ parser.add_argument(
290
+ "--input_dir",
291
+ help="Location of LLaMA weights, which contains tokenizer.model and model folders",
292
+ )
293
+ parser.add_argument(
294
+ "--model_size",
295
+ choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
296
+ help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
297
+ )
298
+ parser.add_argument(
299
+ "--output_dir",
300
+ help="Location to write HF model and tokenizer",
301
+ )
302
+ parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
303
+ args = parser.parse_args()
304
+ spm_path = os.path.join(args.input_dir, "tokenizer.model")
305
+ if args.model_size != "tokenizer_only":
306
+ write_model(
307
+ model_path=args.output_dir,
308
+ input_base_path=args.input_dir,
309
+ model_size=args.model_size,
310
+ safe_serialization=args.safe_serialization,
311
+ tokenizer_path=spm_path,
312
+ )
313
+ else:
314
+ write_tokenizer(args.output_dir, spm_path)
315
+
316
+
317
+ if __name__ == "__main__":
318
+ main()
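Editorial check (not part of the commit): `compute_intermediate_size` rounds `ffn_dim_multiplier * int(8n/3)` up to the next multiple of `multiple_of`, which reproduces the familiar LLaMA MLP widths. Assuming the vendored module is importable from this path:

```python
from codeclm.models.llama.convert_llama_weights_to_hf import compute_intermediate_size

# 7B: int(8 * 4096 / 3) = 10922, rounded up to a multiple of 256 -> 11008
assert compute_intermediate_size(4096) == 11008
# 70B-style params (dim=8192, ffn_dim_multiplier=1.3, multiple_of=4096) -> 28672
assert compute_intermediate_size(8192, ffn_dim_multiplier=1.3, multiple_of=4096) == 28672
```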
codeclm/models/llama/modeling_llama.py ADDED
@@ -0,0 +1,1243 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
34
+ from transformers.utils import (
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ is_flash_attn_available,
38
+ logging,
39
+ replace_return_docstrings,
40
+ )
41
+ from .configuration_llama import LlamaConfig
42
+
43
+
44
+ if is_flash_attn_available():
45
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
46
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CONFIG_FOR_DOC = "LlamaConfig"
52
+
53
+
54
+ def _get_unpad_data(padding_mask):
55
+ seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
56
+ indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
57
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
58
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
59
+ return (
60
+ indices,
61
+ cu_seqlens,
62
+ max_seqlen_in_batch,
63
+ )
64
+
65
+
66
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
67
+ def _make_causal_mask(
68
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
69
+ ):
70
+ """
71
+ Make causal mask used for bi-directional self-attention.
72
+ """
73
+ bsz, tgt_len = input_ids_shape
74
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
75
+ mask_cond = torch.arange(mask.size(-1), device=device)
76
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
77
+ mask = mask.to(dtype)
78
+
79
+ if past_key_values_length > 0:
80
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
81
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
82
+
83
+
84
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
85
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
86
+ """
87
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
88
+ """
89
+ bsz, src_len = mask.size()
90
+ tgt_len = tgt_len if tgt_len is not None else src_len
91
+
92
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
93
+
94
+ inverted_mask = 1.0 - expanded_mask
95
+
96
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
97
+
98
+
99
+ class LlamaRMSNorm(nn.Module):
100
+ def __init__(self, hidden_size, eps=1e-6):
101
+ """
102
+ LlamaRMSNorm is equivalent to T5LayerNorm
103
+ """
104
+ super().__init__()
105
+ self.weight = nn.Parameter(torch.ones(hidden_size))
106
+ self.variance_epsilon = eps
107
+
108
+ def forward(self, hidden_states):
109
+ input_dtype = hidden_states.dtype
110
+ hidden_states = hidden_states.to(torch.float32)
111
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
112
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
113
+ return self.weight * hidden_states.to(input_dtype)
114
+
115
+
116
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
117
+
118
+
119
+ class LlamaRotaryEmbedding(nn.Module):
120
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
121
+ super().__init__()
122
+
123
+ self.dim = dim
124
+ self.max_position_embeddings = max_position_embeddings
125
+ self.base = base
126
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
127
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
128
+
129
+ # Build here to make `torch.jit.trace` work.
130
+ self._set_cos_sin_cache(
131
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
132
+ )
133
+
134
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
135
+ self.max_seq_len_cached = seq_len
136
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
137
+
138
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
139
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
140
+ emb = torch.cat((freqs, freqs), dim=-1)
141
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
142
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
143
+
144
+ def forward(self, x, seq_len=None):
145
+ # x: [bs, num_attention_heads, seq_len, head_size]
146
+ if seq_len > self.max_seq_len_cached:
147
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
148
+
149
+ return (
150
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
151
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
152
+ )
153
+
154
+
155
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
156
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
157
+
158
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
159
+ self.scaling_factor = scaling_factor
160
+ super().__init__(dim, max_position_embeddings, base, device)
161
+
162
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
163
+ self.max_seq_len_cached = seq_len
164
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
165
+ t = t / self.scaling_factor
166
+
167
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
168
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
169
+ emb = torch.cat((freqs, freqs), dim=-1)
170
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
171
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
172
+
173
+
174
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
175
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
176
+
177
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
178
+ self.scaling_factor = scaling_factor
179
+ super().__init__(dim, max_position_embeddings, base, device)
180
+
181
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
182
+ self.max_seq_len_cached = seq_len
183
+
184
+ if seq_len > self.max_position_embeddings:
185
+ base = self.base * (
186
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
187
+ ) ** (self.dim / (self.dim - 2))
188
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
189
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
190
+
191
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
192
+
193
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
194
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
195
+ emb = torch.cat((freqs, freqs), dim=-1)
196
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
197
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
198
+
199
+
200
+ def rotate_half(x):
201
+ """Rotates half the hidden dims of the input."""
202
+ x1 = x[..., : x.shape[-1] // 2]
203
+ x2 = x[..., x.shape[-1] // 2 :]
204
+ return torch.cat((-x2, x1), dim=-1)
205
+
206
+
207
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
208
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
209
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
210
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
211
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
212
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
213
+ q_embed = (q * cos) + (rotate_half(q) * sin)
214
+ k_embed = (k * cos) + (rotate_half(k) * sin)
215
+ return q_embed, k_embed
216
+
217
+
218
+ class LlamaMLP(nn.Module):
219
+ def __init__(self, config):
220
+ super().__init__()
221
+ self.config = config
222
+ self.hidden_size = config.hidden_size
223
+ self.intermediate_size = config.intermediate_size
224
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
225
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
226
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
227
+ self.act_fn = ACT2FN[config.hidden_act]
228
+
229
+ def forward(self, x):
230
+ if self.config.pretraining_tp > 1:
231
+ slice = self.intermediate_size // self.config.pretraining_tp
232
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
233
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
234
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
235
+
236
+ gate_proj = torch.cat(
237
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
238
+ )
239
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
240
+
241
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
242
+ down_proj = [
243
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
244
+ ]
245
+ down_proj = sum(down_proj)
246
+ else:
247
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
248
+
249
+ return down_proj
250
+
251
+
252
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
253
+ """
254
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
255
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
256
+ """
257
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
258
+ if n_rep == 1:
259
+ return hidden_states
260
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
261
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
262
+
263
+
264
+ class LlamaAttention(nn.Module):
265
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
266
+
267
+ def __init__(self, config: LlamaConfig):
268
+ super().__init__()
269
+ self.config = config
270
+ self.hidden_size = config.hidden_size
271
+ self.num_heads = config.num_attention_heads
272
+ self.head_dim = self.hidden_size // self.num_heads
273
+ self.num_key_value_heads = config.num_key_value_heads
274
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
275
+ self.max_position_embeddings = config.max_position_embeddings
276
+ self.rope_theta = config.rope_theta
277
+
278
+ if (self.head_dim * self.num_heads) != self.hidden_size:
279
+ raise ValueError(
280
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
281
+ f" and `num_heads`: {self.num_heads})."
282
+ )
283
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
284
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
285
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
286
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
287
+ self._init_rope()
288
+
289
+ def _init_rope(self):
290
+ if self.config.rope_scaling is None:
291
+ self.rotary_emb = LlamaRotaryEmbedding(
292
+ self.head_dim,
293
+ max_position_embeddings=self.max_position_embeddings,
294
+ base=self.rope_theta,
295
+ )
296
+ else:
297
+ scaling_type = self.config.rope_scaling["type"]
298
+ scaling_factor = self.config.rope_scaling["factor"]
299
+ if scaling_type == "linear":
300
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
301
+ self.head_dim,
302
+ max_position_embeddings=self.max_position_embeddings,
303
+ scaling_factor=scaling_factor,
304
+ base=self.rope_theta,
305
+ )
306
+ elif scaling_type == "dynamic":
307
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
308
+ self.head_dim,
309
+ max_position_embeddings=self.max_position_embeddings,
310
+ scaling_factor=scaling_factor,
311
+ base=self.rope_theta,
312
+ )
313
+ else:
314
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
315
+
316
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
317
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
318
+
319
+ def forward(
320
+ self,
321
+ hidden_states: torch.Tensor,
322
+ attention_mask: Optional[torch.Tensor] = None,
323
+ position_ids: Optional[torch.LongTensor] = None,
324
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
325
+ output_attentions: bool = False,
326
+ use_cache: bool = False,
327
+ padding_mask: Optional[torch.LongTensor] = None,
328
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
329
+ bsz, q_len, _ = hidden_states.size()
330
+
331
+ if self.config.pretraining_tp > 1:
332
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
333
+ query_slices = self.q_proj.weight.split(
334
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
335
+ )
336
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
337
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
338
+
339
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
340
+ query_states = torch.cat(query_states, dim=-1)
341
+
342
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
343
+ key_states = torch.cat(key_states, dim=-1)
344
+
345
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
346
+ value_states = torch.cat(value_states, dim=-1)
347
+
348
+ else:
349
+ query_states = self.q_proj(hidden_states)
350
+ key_states = self.k_proj(hidden_states)
351
+ value_states = self.v_proj(hidden_states)
352
+
353
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
354
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
355
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
356
+
357
+ kv_seq_len = key_states.shape[-2]
358
+ if past_key_value is not None:
359
+ kv_seq_len += past_key_value[0].shape[-2]
360
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
361
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
362
+
363
+ if past_key_value is not None:
364
+ # reuse k, v, self_attention
365
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
366
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
367
+
368
+ past_key_value = (key_states, value_states) if use_cache else None
369
+
370
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
371
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
372
+
373
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
374
+
375
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
376
+ raise ValueError(
377
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
378
+ f" {attn_weights.size()}"
379
+ )
380
+
381
+ if attention_mask is not None:
382
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
383
+ raise ValueError(
384
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
385
+ )
386
+ attn_weights = attn_weights + attention_mask
387
+
388
+ # upcast attention to fp32
389
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
390
+ attn_output = torch.matmul(attn_weights, value_states)
391
+
392
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
393
+ raise ValueError(
394
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
395
+ f" {attn_output.size()}"
396
+ )
397
+
398
+ attn_output = attn_output.transpose(1, 2).contiguous()
399
+
400
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
401
+
402
+ if self.config.pretraining_tp > 1:
403
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
404
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
405
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
406
+ else:
407
+ attn_output = self.o_proj(attn_output)
408
+
409
+ if not output_attentions:
410
+ attn_weights = None
411
+
412
+ return attn_output, attn_weights, past_key_value
413
+
414
+
415
+ class LlamaFlashAttention2(LlamaAttention):
416
+ """
417
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
418
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
419
+ flash attention and deal with padding tokens in case the input contains any of them.
420
+ """
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.Tensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ padding_mask: Optional[torch.LongTensor] = None,
431
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
432
+ # LlamaFlashAttention2 attention does not support output_attentions
433
+ output_attentions = False
434
+
435
+ bsz, q_len, _ = hidden_states.size()
436
+
437
+ query_states = self.q_proj(hidden_states)
438
+ key_states = self.k_proj(hidden_states)
439
+ value_states = self.v_proj(hidden_states)
440
+
441
+ # Flash attention requires the input to have the shape
442
+ # batch_size x seq_length x num_heads x head_dim
443
+ # therefore we just need to keep the original shape
444
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
445
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
446
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
447
+
448
+ kv_seq_len = key_states.shape[-2]
449
+ if past_key_value is not None:
450
+ kv_seq_len += past_key_value[0].shape[-2]
451
+
452
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
453
+
454
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
455
+
456
+ if past_key_value is not None:
457
+ # reuse k, v, self_attention
458
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
459
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
460
+
461
+ past_key_value = (key_states, value_states) if use_cache else None
462
+
463
+ query_states = query_states.transpose(1, 2)
464
+ key_states = key_states.transpose(1, 2)
465
+ value_states = value_states.transpose(1, 2)
466
+
467
+ # TODO: llama does not have dropout in the config??
468
+ # It is recommended to use dropout with FA according to the docs
469
+ # when training.
470
+ dropout_rate = 0.0 # if not self.training else self.attn_dropout
471
+
472
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
473
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
474
+ # cast them back in float16 just to be sure everything works as expected.
475
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
476
+ # in fp32. (LlamaRMSNorm handles it correctly)
477
+ input_dtype = query_states.dtype
478
+ if input_dtype == torch.float32:
479
+ logger.warning_once(
480
+ "The input hidden states seem to have been silently cast to float32; this might be related to"
481
+ " the fact you have upcast embedding or layer norm layers in float32. We will cast the input back to"
482
+ " float16."
483
+ )
484
+
485
+ query_states = query_states.to(torch.float16)
486
+ key_states = key_states.to(torch.float16)
487
+ value_states = value_states.to(torch.float16)
488
+
489
+ attn_output = self._flash_attention_forward(
490
+ query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
491
+ )
492
+
493
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
494
+ attn_output = self.o_proj(attn_output)
495
+
496
+ if not output_attentions:
497
+ attn_weights = None
498
+
499
+ return attn_output, attn_weights, past_key_value
500
+
501
+ def _flash_attention_forward(
502
+ self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
503
+ ):
504
+ """
505
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
506
+ first unpads the input, then computes the attention scores and pads the final attention scores.
507
+
508
+ Args:
509
+ query_states (`torch.Tensor`):
510
+ Input query states to be passed to Flash Attention API
511
+ key_states (`torch.Tensor`):
512
+ Input key states to be passed to Flash Attention API
513
+ value_states (`torch.Tensor`):
514
+ Input value states to be passed to Flash Attention API
515
+ padding_mask (`torch.Tensor`):
516
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
517
+ position of padding tokens and 1 for the position of non-padding tokens.
518
+ dropout (`float`, *optional*):
519
+ Attention dropout
520
+ softmax_scale (`float`, *optional*):
521
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
522
+ """
523
+ # Contains at least one padding token in the sequence
524
+ if padding_mask is not None:
525
+ batch_size = query_states.shape[0]
526
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
527
+ query_states, key_states, value_states, padding_mask, query_length
528
+ )
529
+
530
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
531
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
532
+
533
+ attn_output_unpad = flash_attn_varlen_func(
534
+ query_states,
535
+ key_states,
536
+ value_states,
537
+ cu_seqlens_q=cu_seqlens_q,
538
+ cu_seqlens_k=cu_seqlens_k,
539
+ max_seqlen_q=max_seqlen_in_batch_q,
540
+ max_seqlen_k=max_seqlen_in_batch_k,
541
+ dropout_p=dropout,
542
+ softmax_scale=softmax_scale,
543
+ causal=True,
544
+ )
545
+
546
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
547
+ else:
548
+ attn_output = flash_attn_func(
549
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
550
+ )
551
+
552
+ return attn_output
553
+
554
+ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
555
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
556
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
557
+
558
+ key_layer = index_first_axis(
559
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
560
+ )
561
+ value_layer = index_first_axis(
562
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
563
+ )
564
+ if query_length == kv_seq_len:
565
+ query_layer = index_first_axis(
566
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
567
+ )
568
+ cu_seqlens_q = cu_seqlens_k
569
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
570
+ indices_q = indices_k
571
+ elif query_length == 1:
572
+ max_seqlen_in_batch_q = 1
573
+ cu_seqlens_q = torch.arange(
574
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
575
+ ) # There is a memcpy here, that is very bad.
576
+ indices_q = cu_seqlens_q[:-1]
577
+ query_layer = query_layer.squeeze(1)
578
+ else:
579
+ # The -q_len: slice assumes left padding.
580
+ padding_mask = padding_mask[:, -query_length:]
581
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
582
+
583
+ return (
584
+ query_layer,
585
+ key_layer,
586
+ value_layer,
587
+ indices_q,
588
+ (cu_seqlens_q, cu_seqlens_k),
589
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
590
+ )
591
+
592
+
593
+ class LlamaDecoderLayer(nn.Module):
594
+ def __init__(self, config: LlamaConfig):
595
+ super().__init__()
596
+ self.hidden_size = config.hidden_size
597
+ self.self_attn = (
598
+ LlamaAttention(config=config)
599
+ if not getattr(config, "_flash_attn_2_enabled", False)
600
+ else LlamaFlashAttention2(config=config)
601
+ )
602
+ self.mlp = LlamaMLP(config)
603
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
604
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.Tensor,
609
+ attention_mask: Optional[torch.Tensor] = None,
610
+ position_ids: Optional[torch.LongTensor] = None,
611
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
612
+ output_attentions: Optional[bool] = False,
613
+ use_cache: Optional[bool] = False,
614
+ padding_mask: Optional[torch.LongTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
616
+ """
617
+ Args:
618
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
619
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
620
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
621
+ output_attentions (`bool`, *optional*):
622
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
623
+ returned tensors for more detail.
624
+ use_cache (`bool`, *optional*):
625
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
626
+ (see `past_key_values`).
627
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
628
+ """
629
+
630
+ residual = hidden_states
631
+
632
+ hidden_states = self.input_layernorm(hidden_states)
633
+
634
+ # Self Attention
635
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
636
+ hidden_states=hidden_states,
637
+ attention_mask=attention_mask,
638
+ position_ids=position_ids,
639
+ past_key_value=past_key_value,
640
+ output_attentions=output_attentions,
641
+ use_cache=use_cache,
642
+ padding_mask=padding_mask,
643
+ )
644
+ hidden_states = residual + hidden_states
645
+
646
+ # Fully Connected
647
+ residual = hidden_states
648
+ hidden_states = self.post_attention_layernorm(hidden_states)
649
+ hidden_states = self.mlp(hidden_states)
650
+ hidden_states = residual + hidden_states
651
+
652
+ outputs = (hidden_states,)
653
+
654
+ if output_attentions:
655
+ outputs += (self_attn_weights,)
656
+
657
+ if use_cache:
658
+ outputs += (present_key_value,)
659
+
660
+ return outputs
661
+
662
+
663
+ LLAMA_START_DOCSTRING = r"""
664
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
665
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
666
+ etc.)
667
+
668
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
669
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
670
+ and behavior.
671
+
672
+ Parameters:
673
+ config ([`LlamaConfig`]):
674
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
675
+ load the weights associated with the model, only the configuration. Check out the
676
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
677
+ """
678
+
679
+
680
+ @add_start_docstrings(
681
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
682
+ LLAMA_START_DOCSTRING,
683
+ )
684
+ class LlamaPreTrainedModel(PreTrainedModel):
685
+ config_class = LlamaConfig
686
+ base_model_prefix = "model"
687
+ supports_gradient_checkpointing = True
688
+ _no_split_modules = ["LlamaDecoderLayer"]
689
+ _skip_keys_device_placement = "past_key_values"
690
+ _supports_flash_attn_2 = True
691
+
692
+ def _init_weights(self, module):
693
+ std = self.config.initializer_range
694
+ if isinstance(module, nn.Linear):
695
+ module.weight.data.normal_(mean=0.0, std=std)
696
+ if module.bias is not None:
697
+ module.bias.data.zero_()
698
+ elif isinstance(module, nn.Embedding):
699
+ module.weight.data.normal_(mean=0.0, std=std)
700
+ if module.padding_idx is not None:
701
+ module.weight.data[module.padding_idx].zero_()
702
+
703
+ def _set_gradient_checkpointing(self, module, value=False):
704
+ if isinstance(module, LlamaModel):
705
+ module.gradient_checkpointing = value
706
+
707
+
708
+ LLAMA_INPUTS_DOCSTRING = r"""
709
+ Args:
710
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
711
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
712
+ it.
713
+
714
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
715
+ [`PreTrainedTokenizer.__call__`] for details.
716
+
717
+ [What are input IDs?](../glossary#input-ids)
718
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
719
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
720
+
721
+ - 1 for tokens that are **not masked**,
722
+ - 0 for tokens that are **masked**.
723
+
724
+ [What are attention masks?](../glossary#attention-mask)
725
+
726
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
727
+ [`PreTrainedTokenizer.__call__`] for details.
728
+
729
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
730
+ `past_key_values`).
731
+
732
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
733
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
734
+ information on the default strategy.
735
+
736
+ - 1 indicates the head is **not masked**,
737
+ - 0 indicates the head is **masked**.
738
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
739
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
740
+ config.n_positions - 1]`.
741
+
742
+ [What are position IDs?](../glossary#position-ids)
743
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
744
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
745
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
746
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
747
+
748
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
749
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
750
+
751
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
752
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
753
+ of shape `(batch_size, sequence_length)`.
754
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
755
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
756
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
757
+ model's internal embedding lookup matrix.
758
+ use_cache (`bool`, *optional*):
759
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
760
+ `past_key_values`).
761
+ output_attentions (`bool`, *optional*):
762
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
763
+ tensors for more detail.
764
+ output_hidden_states (`bool`, *optional*):
765
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
766
+ more detail.
767
+ return_dict (`bool`, *optional*):
768
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
769
+ """
770
+
771
+
772
+ @add_start_docstrings(
773
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
774
+ LLAMA_START_DOCSTRING,
775
+ )
776
+ class LlamaModel(LlamaPreTrainedModel):
777
+ """
778
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
779
+
780
+ Args:
781
+ config: LlamaConfig
782
+ """
783
+
784
+ def __init__(self, config: LlamaConfig):
785
+ super().__init__(config)
786
+ self.padding_idx = config.pad_token_id
787
+ self.vocab_size = config.vocab_size
788
+
789
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
790
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
791
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
792
+
793
+ self.gradient_checkpointing = False
794
+ # Initialize weights and apply final processing
795
+ self.post_init()
796
+
797
+ def get_input_embeddings(self):
798
+ return self.embed_tokens
799
+
800
+ def set_input_embeddings(self, value):
801
+ self.embed_tokens = value
802
+
803
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
804
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
805
+ # create causal mask
806
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
807
+ combined_attention_mask = None
808
+ if input_shape[-1] > 1:
809
+ combined_attention_mask = _make_causal_mask(
810
+ input_shape,
811
+ inputs_embeds.dtype,
812
+ device=inputs_embeds.device,
813
+ past_key_values_length=past_key_values_length,
814
+ )
815
+
816
+ if attention_mask is not None:
817
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
818
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
819
+ inputs_embeds.device
820
+ )
821
+ combined_attention_mask = (
822
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
823
+ )
824
+
825
+ return combined_attention_mask
826
+
827
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
828
+ def forward(
829
+ self,
830
+ input_ids: torch.LongTensor = None,
831
+ attention_mask: Optional[torch.Tensor] = None,
832
+ position_ids: Optional[torch.LongTensor] = None,
833
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
834
+ inputs_embeds: Optional[torch.FloatTensor] = None,
835
+ use_cache: Optional[bool] = None,
836
+ output_attentions: Optional[bool] = None,
837
+ output_hidden_states: Optional[bool] = None,
838
+ return_dict: Optional[bool] = None,
839
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
840
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
841
+ output_hidden_states = (
842
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
843
+ )
844
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
845
+
846
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
847
+
848
+ # retrieve input_ids and inputs_embeds
849
+ if input_ids is not None and inputs_embeds is not None:
850
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
851
+ elif input_ids is not None:
852
+ batch_size, seq_length = input_ids.shape
853
+ elif inputs_embeds is not None:
854
+ batch_size, seq_length, _ = inputs_embeds.shape
855
+ else:
856
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
857
+
858
+ seq_length_with_past = seq_length
859
+ past_key_values_length = 0
860
+
861
+ if past_key_values is not None:
862
+ past_key_values_length = past_key_values[0][0].shape[2]
863
+ seq_length_with_past = seq_length_with_past + past_key_values_length
864
+
865
+ if position_ids is None:
866
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
867
+ position_ids = torch.arange(
868
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
869
+ )
870
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
871
+ else:
872
+ position_ids = position_ids.view(-1, seq_length).long()
873
+
874
+ if inputs_embeds is None:
875
+ inputs_embeds = self.embed_tokens(input_ids)
876
+ # embed positions
877
+ if attention_mask is None:
878
+ attention_mask = torch.ones(
879
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
880
+ )
881
+ padding_mask = None
882
+ else:
883
+ if 0 in attention_mask:
884
+ padding_mask = attention_mask
885
+ else:
886
+ padding_mask = None
887
+
888
+ attention_mask = self._prepare_decoder_attention_mask(
889
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
890
+ )
891
+
892
+ hidden_states = inputs_embeds
893
+
894
+ if self.gradient_checkpointing and self.training:
895
+ if use_cache:
896
+ logger.warning_once(
897
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
898
+ )
899
+ use_cache = False
900
+
901
+ # decoder layers
902
+ all_hidden_states = () if output_hidden_states else None
903
+ all_self_attns = () if output_attentions else None
904
+ next_decoder_cache = () if use_cache else None
905
+
906
+ for idx, decoder_layer in enumerate(self.layers):
907
+ if output_hidden_states:
908
+ all_hidden_states += (hidden_states,)
909
+
910
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
911
+
912
+ if self.gradient_checkpointing and self.training:
913
+
914
+ def create_custom_forward(module):
915
+ def custom_forward(*inputs):
916
+ # None for past_key_value
917
+ return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
918
+
919
+ return custom_forward
920
+
921
+ layer_outputs = torch.utils.checkpoint.checkpoint(
922
+ create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
923
+ )
924
+ else:
925
+ layer_outputs = decoder_layer(
926
+ hidden_states,
927
+ attention_mask=attention_mask,
928
+ position_ids=position_ids,
929
+ past_key_value=past_key_value,
930
+ output_attentions=output_attentions,
931
+ use_cache=use_cache,
932
+ padding_mask=padding_mask,
933
+ )
934
+
935
+ hidden_states = layer_outputs[0]
936
+
937
+ if use_cache:
938
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
939
+
940
+ if output_attentions:
941
+ all_self_attns += (layer_outputs[1],)
942
+
943
+ hidden_states = self.norm(hidden_states)
944
+
945
+ # add hidden states from the last decoder layer
946
+ if output_hidden_states:
947
+ all_hidden_states += (hidden_states,)
948
+
949
+ next_cache = next_decoder_cache if use_cache else None
950
+ if not return_dict:
951
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
952
+ return BaseModelOutputWithPast(
953
+ last_hidden_state=hidden_states,
954
+ past_key_values=next_cache,
955
+ hidden_states=all_hidden_states,
956
+ attentions=all_self_attns,
957
+ )
958
+
959
+
960
+ class LlamaForCausalLM(LlamaPreTrainedModel):
961
+ _tied_weights_keys = ["lm_head.weight"]
962
+
963
+ def __init__(self, config):
964
+ super().__init__(config)
965
+ self.model = LlamaModel(config)
966
+ self.vocab_size = config.vocab_size
967
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
968
+
969
+ # Initialize weights and apply final processing
970
+ self.post_init()
971
+
972
+ def get_input_embeddings(self):
973
+ return self.model.embed_tokens
974
+
975
+ def set_input_embeddings(self, value):
976
+ self.model.embed_tokens = value
977
+
978
+ def get_output_embeddings(self):
979
+ return self.lm_head
980
+
981
+ def set_output_embeddings(self, new_embeddings):
982
+ self.lm_head = new_embeddings
983
+
984
+ def set_decoder(self, decoder):
985
+ self.model = decoder
986
+
987
+ def get_decoder(self):
988
+ return self.model
989
+
990
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
991
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
992
+ def forward(
993
+ self,
994
+ input_ids: torch.LongTensor = None,
995
+ attention_mask: Optional[torch.Tensor] = None,
996
+ position_ids: Optional[torch.LongTensor] = None,
997
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
998
+ inputs_embeds: Optional[torch.FloatTensor] = None,
999
+ labels: Optional[torch.LongTensor] = None,
1000
+ use_cache: Optional[bool] = None,
1001
+ output_attentions: Optional[bool] = None,
1002
+ output_hidden_states: Optional[bool] = None,
1003
+ return_dict: Optional[bool] = None,
1004
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1005
+ r"""
1006
+ Args:
1007
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1008
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1009
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1010
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1011
+
1012
+ Returns:
1013
+
1014
+ Example:
1015
+
1016
+ ```python
1017
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1018
+
1019
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1020
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1021
+
1022
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1023
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1024
+
1025
+ >>> # Generate
1026
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1027
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1028
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1029
+ ```"""
1030
+
1031
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1032
+ output_hidden_states = (
1033
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1034
+ )
1035
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1036
+
1037
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1038
+ outputs = self.model(
1039
+ input_ids=input_ids,
1040
+ attention_mask=attention_mask,
1041
+ position_ids=position_ids,
1042
+ past_key_values=past_key_values,
1043
+ inputs_embeds=inputs_embeds,
1044
+ use_cache=use_cache,
1045
+ output_attentions=output_attentions,
1046
+ output_hidden_states=output_hidden_states,
1047
+ return_dict=return_dict,
1048
+ )
1049
+
1050
+ hidden_states = outputs[0]
1051
+ if self.config.pretraining_tp > 1:
1052
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1053
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1054
+ logits = torch.cat(logits, dim=-1)
1055
+ else:
1056
+ logits = self.lm_head(hidden_states)
1057
+ logits = logits.float()
1058
+
1059
+ loss = None
1060
+ if labels is not None:
1061
+ # Shift so that tokens < n predict n
1062
+ shift_logits = logits[..., :-1, :].contiguous()
1063
+ shift_labels = labels[..., 1:].contiguous()
1064
+ # Flatten the tokens
1065
+ loss_fct = CrossEntropyLoss()
1066
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1067
+ shift_labels = shift_labels.view(-1)
1068
+ # Enable model parallelism
1069
+ shift_labels = shift_labels.to(shift_logits.device)
1070
+ loss = loss_fct(shift_logits, shift_labels)
1071
+
1072
+ if not return_dict:
1073
+ output = (logits,) + outputs[1:]
1074
+ return (loss,) + output if loss is not None else output
1075
+
1076
+ return CausalLMOutputWithPast(
1077
+ loss=loss,
1078
+ logits=logits,
1079
+ past_key_values=outputs.past_key_values,
1080
+ hidden_states=outputs.hidden_states,
1081
+ attentions=outputs.attentions,
1082
+ )
1083
+
1084
+ def prepare_inputs_for_generation(
1085
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1086
+ ):
1087
+ if past_key_values:
1088
+ input_ids = input_ids[:, -1:]
1089
+
1090
+ position_ids = kwargs.get("position_ids", None)
1091
+ if attention_mask is not None and position_ids is None:
1092
+ # create position_ids on the fly for batch generation
1093
+ position_ids = attention_mask.long().cumsum(-1) - 1
1094
+ position_ids.masked_fill_(attention_mask == 0, 1)
1095
+ if past_key_values:
1096
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1097
+
1098
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1099
+ if inputs_embeds is not None and past_key_values is None:
1100
+ model_inputs = {"inputs_embeds": inputs_embeds}
1101
+ else:
1102
+ model_inputs = {"input_ids": input_ids}
1103
+
1104
+ model_inputs.update(
1105
+ {
1106
+ "position_ids": position_ids,
1107
+ "past_key_values": past_key_values,
1108
+ "use_cache": kwargs.get("use_cache"),
1109
+ "attention_mask": attention_mask,
1110
+ }
1111
+ )
1112
+ return model_inputs
1113
+
1114
+ @staticmethod
1115
+ def _reorder_cache(past_key_values, beam_idx):
1116
+ reordered_past = ()
1117
+ for layer_past in past_key_values:
1118
+ reordered_past += (
1119
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1120
+ )
1121
+ return reordered_past
1122
+
1123
+
1124
+ @add_start_docstrings(
1125
+ """
1126
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1127
+
1128
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1129
+ (e.g. GPT-2) do.
1130
+
1131
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1132
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1133
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1134
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1135
+ each row of the batch).
1136
+ """,
1137
+ LLAMA_START_DOCSTRING,
1138
+ )
1139
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1140
+ def __init__(self, config):
1141
+ super().__init__(config)
1142
+ self.num_labels = config.num_labels
1143
+ self.model = LlamaModel(config)
1144
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1145
+
1146
+ # Initialize weights and apply final processing
1147
+ self.post_init()
1148
+
1149
+ def get_input_embeddings(self):
1150
+ return self.model.embed_tokens
1151
+
1152
+ def set_input_embeddings(self, value):
1153
+ self.model.embed_tokens = value
1154
+
1155
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1156
+ def forward(
1157
+ self,
1158
+ input_ids: torch.LongTensor = None,
1159
+ attention_mask: Optional[torch.Tensor] = None,
1160
+ position_ids: Optional[torch.LongTensor] = None,
1161
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1162
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1163
+ labels: Optional[torch.LongTensor] = None,
1164
+ use_cache: Optional[bool] = None,
1165
+ output_attentions: Optional[bool] = None,
1166
+ output_hidden_states: Optional[bool] = None,
1167
+ return_dict: Optional[bool] = None,
1168
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1169
+ r"""
1170
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1171
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1172
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1173
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1174
+ """
1175
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1176
+
1177
+ transformer_outputs = self.model(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ past_key_values=past_key_values,
1182
+ inputs_embeds=inputs_embeds,
1183
+ use_cache=use_cache,
1184
+ output_attentions=output_attentions,
1185
+ output_hidden_states=output_hidden_states,
1186
+ return_dict=return_dict,
1187
+ )
1188
+ hidden_states = transformer_outputs[0]
1189
+ logits = self.score(hidden_states)
1190
+
1191
+ if input_ids is not None:
1192
+ batch_size = input_ids.shape[0]
1193
+ else:
1194
+ batch_size = inputs_embeds.shape[0]
1195
+
1196
+ if self.config.pad_token_id is None and batch_size != 1:
1197
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1198
+ if self.config.pad_token_id is None:
1199
+ sequence_lengths = -1
1200
+ else:
1201
+ if input_ids is not None:
1202
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
1203
+ logits.device
1204
+ )
1205
+ else:
1206
+ sequence_lengths = -1
1207
+
1208
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1209
+
1210
+ loss = None
1211
+ if labels is not None:
1212
+ labels = labels.to(logits.device)
1213
+ if self.config.problem_type is None:
1214
+ if self.num_labels == 1:
1215
+ self.config.problem_type = "regression"
1216
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1217
+ self.config.problem_type = "single_label_classification"
1218
+ else:
1219
+ self.config.problem_type = "multi_label_classification"
1220
+
1221
+ if self.config.problem_type == "regression":
1222
+ loss_fct = MSELoss()
1223
+ if self.num_labels == 1:
1224
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1225
+ else:
1226
+ loss = loss_fct(pooled_logits, labels)
1227
+ elif self.config.problem_type == "single_label_classification":
1228
+ loss_fct = CrossEntropyLoss()
1229
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1230
+ elif self.config.problem_type == "multi_label_classification":
1231
+ loss_fct = BCEWithLogitsLoss()
1232
+ loss = loss_fct(pooled_logits, labels)
1233
+ if not return_dict:
1234
+ output = (pooled_logits,) + transformer_outputs[1:]
1235
+ return ((loss,) + output) if loss is not None else output
1236
+
1237
+ return SequenceClassifierOutputWithPast(
1238
+ loss=loss,
1239
+ logits=pooled_logits,
1240
+ past_key_values=transformer_outputs.past_key_values,
1241
+ hidden_states=transformer_outputs.hidden_states,
1242
+ attentions=transformer_outputs.attentions,
1243
+ )
codeclm/models/llama/tokenization_llama.py ADDED
@@ -0,0 +1,426 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ """Tokenization classes for LLaMA."""
22
+ import os
23
+ from shutil import copyfile
24
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
25
+
26
+ import sentencepiece as spm
27
+
28
+ from transformers.convert_slow_tokenizer import import_protobuf
29
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
30
+ from transformers.utils import logging
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from transformers.tokenization_utils_base import TextInput
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
39
+
40
+ PRETRAINED_VOCAB_FILES_MAP = {
41
+ "vocab_file": {
42
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
43
+ },
44
+ "tokenizer_file": {
45
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
46
+ },
47
+ }
48
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
49
+ "hf-internal-testing/llama-tokenizer": 2048,
50
+ }
51
+ SPIECE_UNDERLINE = "▁"
52
+
53
+ B_INST, E_INST = "[INST]", "[/INST]"
54
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
55
+
56
+ # fmt: off
57
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
58
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
59
+ that your responses are socially unbiased and positive in nature.
60
+
61
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
62
+ correct. If you don't know the answer to a question, please don't share false information."""
63
+ # fmt: on
64
+
65
+
66
+ class LlamaTokenizer(PreTrainedTokenizer):
67
+ """
68
+ Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
69
+ no padding token in the original model.
70
+
71
+ Args:
72
+ vocab_file (`str`):
73
+ Path to the vocabulary file.
74
+ legacy (`bool`, *optional*):
75
+ Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
76
+ and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
77
+ example:
78
+
79
+ - `legacy=True`:
80
+ ```python
81
+ >>> from transformers import T5Tokenizer
82
+
83
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
84
+ >>> tokenizer.encode("Hello <extra_id_0>.")
85
+ [8774, 32099, 3, 5, 1]
86
+ ```
87
+ - `legacy=False`:
88
+ ```python
89
+ >>> from transformers import T5Tokenizer
90
+
91
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
92
+ >>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
93
+ [8774, 32099, 5, 1]
94
+ ```
95
+ Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
96
+
97
+ """
98
+
99
+ vocab_files_names = VOCAB_FILES_NAMES
100
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
101
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
102
+ model_input_names = ["input_ids", "attention_mask"]
103
+
104
+ def __init__(
105
+ self,
106
+ vocab_file,
107
+ unk_token="<unk>",
108
+ bos_token="<s>",
109
+ eos_token="</s>",
110
+ pad_token=None,
111
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
112
+ add_bos_token=True,
113
+ add_eos_token=False,
114
+ clean_up_tokenization_spaces=False,
115
+ use_default_system_prompt=True,
116
+ spaces_between_special_tokens=False,
117
+ legacy=None,
118
+ **kwargs,
119
+ ):
120
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
121
+ bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
122
+ eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
123
+ unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
124
+ pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
125
+
126
+ if legacy is None:
127
+ logger.warning_once(
128
+ f"You are using the default legacy behaviour of the {self.__class__}. This is"
129
+ " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
130
+ " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
131
+ " means, and thouroughly read the reason why this was added as explained in"
132
+ " https://github.com/huggingface/transformers/pull/24565"
133
+ )
134
+ legacy = True
135
+
136
+ self.legacy = legacy
137
+ self.vocab_file = vocab_file
138
+ self.add_bos_token = add_bos_token
139
+ self.add_eos_token = add_eos_token
140
+ self.use_default_system_prompt = use_default_system_prompt
141
+ self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
142
+
143
+ super().__init__(
144
+ bos_token=bos_token,
145
+ eos_token=eos_token,
146
+ unk_token=unk_token,
147
+ pad_token=pad_token,
148
+ add_bos_token=add_bos_token,
149
+ add_eos_token=add_eos_token,
150
+ sp_model_kwargs=self.sp_model_kwargs,
151
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
152
+ use_default_system_prompt=use_default_system_prompt,
153
+ spaces_between_special_tokens=spaces_between_special_tokens,
154
+ legacy=legacy,
155
+ **kwargs,
156
+ )
157
+
158
+ @property
159
+ def unk_token_length(self):
160
+ return len(self.sp_model.encode(str(self.unk_token)))
161
+
162
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
163
+ def get_spm_processor(self, from_slow=False):
164
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
165
+ if self.legacy or from_slow: # no dependency on protobuf
166
+ tokenizer.Load(self.vocab_file)
167
+ return tokenizer
168
+
169
+ with open(self.vocab_file, "rb") as f:
170
+ sp_model = f.read()
171
+ model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
172
+ model = model_pb2.ModelProto.FromString(sp_model)
173
+ normalizer_spec = model_pb2.NormalizerSpec()
174
+ normalizer_spec.add_dummy_prefix = False
175
+ model.normalizer_spec.MergeFrom(normalizer_spec)
176
+ sp_model = model.SerializeToString()
177
+ tokenizer.LoadFromSerializedProto(sp_model)
178
+ return tokenizer
179
+
180
+ def __getstate__(self):
181
+ state = self.__dict__.copy()
182
+ state["sp_model"] = None
183
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
184
+ return state
185
+
186
+ def __setstate__(self, d):
187
+ self.__dict__ = d
188
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
189
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
190
+
191
+ @property
192
+ def vocab_size(self):
193
+ """Returns vocab size"""
194
+ return self.sp_model.get_piece_size()
195
+
196
+ def get_vocab(self):
197
+ """Returns vocab as a dict"""
198
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
199
+ vocab.update(self.added_tokens_encoder)
200
+ return vocab
201
+
202
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
203
+ def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
204
+ """
205
+ Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
206
+ first token is special.
207
+ """
208
+ if self.legacy or len(text) == 0:
209
+ return super().tokenize(text, **kwargs)
210
+
211
+ tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
212
+
213
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
214
+ tokens = tokens[1:]
215
+ return tokens
216
+
217
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
218
+ def _tokenize(self, text, **kwargs):
219
+ """
220
+ Returns a tokenized string.
221
+
222
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
223
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
224
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
225
+ `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
226
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
227
+ """
228
+ tokens = self.sp_model.encode(text, out_type=str)
229
+ if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
230
+ return tokens
231
+
232
+ # 1. Encode string + prefix ex: "<unk> Hey"
233
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
234
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
235
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
236
+
237
+ def _convert_token_to_id(self, token):
238
+ """Converts a token (str) in an id using the vocab."""
239
+ return self.sp_model.piece_to_id(token)
240
+
241
+ def _convert_id_to_token(self, index):
242
+ """Converts an index (integer) in a token (str) using the vocab."""
243
+ token = self.sp_model.IdToPiece(index)
244
+ return token
245
+
246
+ def convert_tokens_to_string(self, tokens):
247
+ """Converts a sequence of tokens (string) in a single string."""
248
+ # since we manually add the prefix space, we have to remove it when decoding
249
+ if tokens[0].startswith(SPIECE_UNDERLINE):
250
+ tokens[0] = tokens[0][1:]
251
+
252
+ current_sub_tokens = []
253
+ out_string = ""
254
+ prev_is_special = False
255
+ for i, token in enumerate(tokens):
256
+ # make sure that special tokens are not decoded using sentencepiece model
257
+ if token in self.all_special_tokens:
258
+ if not prev_is_special and i != 0 and self.legacy:
259
+ out_string += " "
260
+ out_string += self.sp_model.decode(current_sub_tokens) + token
261
+ prev_is_special = True
262
+ current_sub_tokens = []
263
+ else:
264
+ current_sub_tokens.append(token)
265
+ prev_is_special = False
266
+ out_string += self.sp_model.decode(current_sub_tokens)
267
+ return out_string
268
+
269
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
270
+ """
271
+ Save the vocabulary and special tokens file to a directory.
272
+
273
+ Args:
274
+ save_directory (`str`):
275
+ The directory in which to save the vocabulary.
276
+
277
+ Returns:
278
+ `Tuple(str)`: Paths to the files saved.
279
+ """
280
+ if not os.path.isdir(save_directory):
281
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
282
+ return
283
+ out_vocab_file = os.path.join(
284
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
285
+ )
286
+
287
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
288
+ copyfile(self.vocab_file, out_vocab_file)
289
+ elif not os.path.isfile(self.vocab_file):
290
+ with open(out_vocab_file, "wb") as fi:
291
+ content_spiece_model = self.sp_model.serialized_model_proto()
292
+ fi.write(content_spiece_model)
293
+
294
+ return (out_vocab_file,)
295
+
296
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
297
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
298
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
299
+
300
+ output = bos_token_id + token_ids_0 + eos_token_id
301
+
302
+ if token_ids_1 is not None:
303
+ output = output + bos_token_id + token_ids_1 + eos_token_id
304
+
305
+ return output
306
+
307
+ def get_special_tokens_mask(
308
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
309
+ ) -> List[int]:
310
+ """
311
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
312
+ special tokens using the tokenizer `prepare_for_model` method.
313
+
314
+ Args:
315
+ token_ids_0 (`List[int]`):
316
+ List of IDs.
317
+ token_ids_1 (`List[int]`, *optional*):
318
+ Optional second list of IDs for sequence pairs.
319
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
320
+ Whether or not the token list is already formatted with special tokens for the model.
321
+
322
+ Returns:
323
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
324
+ """
325
+ if already_has_special_tokens:
326
+ return super().get_special_tokens_mask(
327
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
328
+ )
329
+
330
+ bos_token_id = [1] if self.add_bos_token else []
331
+ eos_token_id = [1] if self.add_eos_token else []
332
+
333
+ if token_ids_1 is None:
334
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
335
+ return (
336
+ bos_token_id
337
+ + ([0] * len(token_ids_0))
338
+ + eos_token_id
339
+ + bos_token_id
340
+ + ([0] * len(token_ids_1))
341
+ + eos_token_id
342
+ )
343
+
344
+ def create_token_type_ids_from_sequences(
345
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
346
+ ) -> List[int]:
347
+ """
348
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
349
+ sequence pair mask has the following format:
350
+
351
+ ```
352
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
353
+ | first sequence | second sequence |
354
+ ```
355
+
356
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
357
+
358
+ Args:
359
+ token_ids_0 (`List[int]`):
360
+ List of ids.
361
+ token_ids_1 (`List[int]`, *optional*):
362
+ Optional second list of IDs for sequence pairs.
363
+
364
+ Returns:
365
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
366
+ """
367
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
368
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
369
+
370
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
371
+
372
+ if token_ids_1 is not None:
373
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
374
+
375
+ return output
376
+
377
+ @property
378
+ def default_chat_template(self):
379
+ """
380
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
381
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
382
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
383
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
384
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
385
+ to fine-tune a model with more flexible role ordering!
386
+
387
+ The output should look something like:
388
+
389
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
390
+ <bos>[INST] Prompt [/INST]
391
+ """
392
+
393
+ template = (
394
+ "{% if messages[0]['role'] == 'system' %}"
395
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
396
+ "{% set system_message = messages[0]['content'] %}"
397
+ "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
398
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
399
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
400
+ "{% else %}"
401
+ "{% set loop_messages = messages %}"
402
+ "{% set system_message = false %}"
403
+ "{% endif %}"
404
+ "{% for message in loop_messages %}" # Loop over all non-system messages
405
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
406
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
407
+ "{% endif %}"
408
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
409
+ "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
410
+ "{% else %}"
411
+ "{% set content = message['content'] %}"
412
+ "{% endif %}"
413
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
414
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
415
+ "{% elif message['role'] == 'system' %}"
416
+ "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
417
+ "{% elif message['role'] == 'assistant' %}"
418
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
419
+ "{% endif %}"
420
+ "{% endfor %}"
421
+ )
422
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
423
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
424
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
425
+
426
+ return template
codeclm/models/llama/tokenization_llama_fast.py ADDED
@@ -0,0 +1,264 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ from shutil import copyfile
17
+ from typing import Optional, Tuple
18
+
19
+ from tokenizers import processors
20
+
21
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
22
+ from transformers.utils import is_sentencepiece_available, logging
23
+ from transformers.utils.versions import require_version
24
+
25
+
26
+ require_version("tokenizers>=0.13.3")
27
+
28
+ if is_sentencepiece_available():
29
+ from .tokenization_llama import LlamaTokenizer
30
+ else:
31
+ LlamaTokenizer = None
32
+
33
+ logger = logging.get_logger(__name__)
34
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
35
+
36
+ PRETRAINED_VOCAB_FILES_MAP = {
37
+ "vocab_file": {
38
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
39
+ },
40
+ "tokenizer_file": {
41
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
42
+ },
43
+ }
44
+ B_INST, E_INST = "[INST]", "[/INST]"
45
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
46
+
47
+ # fmt: off
48
+ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
49
+ answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
50
+ that your responses are socially unbiased and positive in nature.
51
+
52
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
53
+ correct. If you don't know the answer to a question, please don't share false information."""
54
+ # fmt: on
55
+
56
+
57
+ class LlamaTokenizerFast(PreTrainedTokenizerFast):
58
+ """
59
+ Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
60
+
61
+ Notably, this uses ByteFallback and no normalization.
62
+
63
+ ```
64
+ from transformers import LlamaTokenizerFast
65
+
66
+ tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
67
+ tokenizer.encode("Hello this is a test")
68
+ >>> [1, 15043, 445, 338, 263, 1243]
69
+ ```
70
+
71
+ If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
72
+ call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
73
+ values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
74
+ [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
75
+
76
+
77
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
78
+ refer to this superclass for more information regarding those methods.
79
+
80
+ Args:
81
+ vocab_file (`str`):
82
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
83
+ contains the vocabulary necessary to instantiate a tokenizer.
84
+ tokenizer_file (`str`):
85
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
86
+ contains everything needed to load the tokenizer.
87
+
88
+ clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
89
+ Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
90
+ spaces.
91
+
92
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
93
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
94
+
95
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
96
+ The end of sequence token.
97
+
98
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
99
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
100
+ token instead.
101
+ """
102
+
103
+ vocab_files_names = VOCAB_FILES_NAMES
104
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
105
+ slow_tokenizer_class = LlamaTokenizer
106
+ padding_side = "left"
107
+ model_input_names = ["input_ids", "attention_mask"]
108
+
109
+ def __init__(
110
+ self,
111
+ vocab_file=None,
112
+ tokenizer_file=None,
113
+ clean_up_tokenization_spaces=False,
114
+ unk_token="<unk>",
115
+ bos_token="<s>",
116
+ eos_token="</s>",
117
+ add_bos_token=True,
118
+ add_eos_token=False,
119
+ use_default_system_prompt=True,
120
+ **kwargs,
121
+ ):
122
+ super().__init__(
123
+ vocab_file=vocab_file,
124
+ tokenizer_file=tokenizer_file,
125
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
126
+ unk_token=unk_token,
127
+ bos_token=bos_token,
128
+ eos_token=eos_token,
129
+ use_default_system_prompt=use_default_system_prompt,
130
+ **kwargs,
131
+ )
132
+ self._add_bos_token = add_bos_token
133
+ self._add_eos_token = add_eos_token
134
+ self.update_post_processor()
135
+ self.use_default_system_prompt = use_default_system_prompt
136
+ self.vocab_file = vocab_file
137
+
138
+ @property
139
+ def can_save_slow_tokenizer(self) -> bool:
140
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
141
+
142
+ def update_post_processor(self):
143
+ """
144
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
145
+ """
146
+ bos = self.bos_token
147
+ bos_token_id = self.bos_token_id
148
+
149
+ eos = self.eos_token
150
+ eos_token_id = self.eos_token_id
151
+
152
+ single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
153
+ pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
154
+
155
+ special_tokens = []
156
+ if self.add_bos_token:
157
+ special_tokens.append((bos, bos_token_id))
158
+ if self.add_eos_token:
159
+ special_tokens.append((eos, eos_token_id))
160
+ self._tokenizer.post_processor = processors.TemplateProcessing(
161
+ single=single, pair=pair, special_tokens=special_tokens
162
+ )
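+ # Illustrative sketch (not part of the original file): with the default add_bos_token=True,
+ # add_eos_token=False and bos = "<s>", the template strings above evaluate to
+ #   single = "<s>:0 $A:0"
+ #   pair   = "<s>:0 $A:0 <s>:1 $B:1"
+ # so every encoded sequence is prefixed with BOS, pairs get a BOS before the second segment,
+ # and `special_tokens` maps "<s>" to its id for TemplateProcessing.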
163
+
164
+ @property
165
+ def add_eos_token(self):
166
+ return self._add_eos_token
167
+
168
+ @property
169
+ def add_bos_token(self):
170
+ return self._add_bos_token
171
+
172
+ @add_eos_token.setter
173
+ def add_eos_token(self, value):
174
+ self._add_eos_token = value
175
+ self.update_post_processor()
176
+
177
+ @add_bos_token.setter
178
+ def add_bos_token(self, value):
179
+ self._add_bos_token = value
180
+ self.update_post_processor()
181
+
182
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
183
+ if not self.can_save_slow_tokenizer:
184
+ raise ValueError(
185
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
186
+ "tokenizer."
187
+ )
188
+
189
+ if not os.path.isdir(save_directory):
190
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
191
+ return
192
+ out_vocab_file = os.path.join(
193
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
194
+ )
195
+
196
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
197
+ copyfile(self.vocab_file, out_vocab_file)
198
+
199
+ return (out_vocab_file,)
200
+
201
+ @property
202
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
203
+ def default_chat_template(self):
204
+ """
205
+ LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
206
+ Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
207
+ user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
208
+ rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
209
+ results in an unusual token ordering when it is present. This template should definitely be changed if you wish
210
+ to fine-tune a model with more flexible role ordering!
211
+
212
+ The output should look something like:
213
+
214
+ <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
215
+ <bos>[INST] Prompt [/INST]
216
+ """
217
+
218
+ template = (
219
+ "{% if messages[0]['role'] == 'system' %}"
220
+ "{% set loop_messages = messages[1:] %}" # Extract system message if it's present
221
+ "{% set system_message = messages[0]['content'] %}"
222
+ "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
223
+ "{% set loop_messages = messages %}" # Or use the default system message if the flag is set
224
+ "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
225
+ "{% else %}"
226
+ "{% set loop_messages = messages %}"
227
+ "{% set system_message = false %}"
228
+ "{% endif %}"
229
+ "{% for message in loop_messages %}" # Loop over all non-system messages
230
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
231
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
232
+ "{% endif %}"
233
+ "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
234
+ "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
235
+ "{% else %}"
236
+ "{% set content = message['content'] %}"
237
+ "{% endif %}"
238
+ "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
239
+ "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
240
+ "{% elif message['role'] == 'system' %}"
241
+ "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
242
+ "{% elif message['role'] == 'assistant' %}"
243
+ "{{ ' ' + content.strip() + ' ' + eos_token }}"
244
+ "{% endif %}"
245
+ "{% endfor %}"
246
+ )
247
+ template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
248
+ default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
249
+ template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
250
+
251
+ return template
252
+
253
+ # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
254
+ # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
255
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
256
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
257
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
258
+
259
+ output = bos_token_id + token_ids_0 + eos_token_id
260
+
261
+ if token_ids_1 is not None:
262
+ output = output + bos_token_id + token_ids_1 + eos_token_id
263
+
264
+ return output
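+ # Illustrative sketch (assumed token ids, default add_bos_token=True / add_eos_token=False):
+ #   tokenizer.build_inputs_with_special_tokens([15043, 445])   -> [1, 15043, 445]
+ #   tokenizer.build_inputs_with_special_tokens([15043], [338]) -> [1, 15043, 1, 338]
+ # i.e. BOS (id 1) is prepended to each segment and no EOS is appended unless add_eos_token is set.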
codeclm/models/lm_levo.py ADDED
@@ -0,0 +1,546 @@
1
+
2
+ import torch
3
+ import math
4
+ import random
5
+ import torch.nn as nn
6
+ import typing as tp
7
+ import torch.nn.functional as F
8
+ from dataclasses import dataclass
9
+ from codeclm.models.levo import CausalLM, LlamaConfig
10
+ from codeclm.modules.streaming import StreamingModule
11
+ from codeclm.modules.conditioners import (
12
+ ConditioningAttributes,
13
+ AudioCondition,
14
+ ConditionType,
15
+ ConditionerProvider,
16
+ ConditionFuser,
17
+ ClassifierFreeGuidanceDropoutInference,
18
+ ClassifierFreeGuidanceDropout,
19
+ AttributeDropout,
20
+ )
21
+ from codeclm.utils.utils import create_norm_fn, init_layer, sample_top_k, sample_top_p, multinomial
22
+ from codeclm.modules.pattern import CodebooksPatternProvider
23
+ ConditionTensors = tp.Dict[str, ConditionType]
24
+
25
+ @dataclass
26
+ class LMOutput:
27
+ # The logits are already re-aligned with the input codes
28
+ # hence no extra shift is required, e.g. when computing CE
29
+ logits: torch.Tensor # [B, K, T, card]
30
+ mask: torch.Tensor # [B, K, T]
31
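+ # Illustrative sketch (not from the original source): because the logits are already aligned
+ # with the input codes, a masked cross-entropy can be computed directly, e.g.:
+ #   out = model.compute_predictions(codes, condition_tensors)   # LMOutput
+ #   logits = out.logits[out.mask.bool()]                        # [N, card]
+ #   target = codes[out.mask.bool()]                             # [N]
+ #   loss = torch.nn.functional.cross_entropy(logits, target)
+ # (shapes follow the [B, K, T, card] / [B, K, T] comments above; boolean masking is one
+ # possible way to drop the invalid positions flagged by `mask`.)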
+
32
+
33
+ class LmModel(StreamingModule):
34
+ """Transformer-based language model on multiple streams of codes.
35
+
36
+ Args:
37
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
38
+ condition_provider (ConditioningProvider): Conditioning provider from metadata.
39
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
40
+ code_depth (int): Number of parallel streams to model.
41
+ code_size (int): Cardinality, vocabulary size.
42
+ dim (int): Dimension of the transformer encoder.
43
+ num_heads (int): Number of heads for the transformer encoder.
44
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
45
+ norm (str): Normalization method.
46
+ norm_first (bool): Use pre-norm instead of post-norm.
47
+ emb_lr (float, optional): Embedding-specific learning rate.
48
+ bias_proj (bool): Use bias for output projections.
49
+ weight_init (str, optional): Method for weight initialization.
50
+ depthwise_init (str, optional): Method for depthwise weight initialization.
51
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
52
+ cfg_dropout (float): Classifier-free guidance dropout.
53
+ cfg_coef (float): Classifier-free guidance coefficient.
54
+ attribute_dropout (dict): Attribute dropout probabilities.
55
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
56
+ **kwargs: Additional parameters for the transformer encoder.
57
+ """
58
+ def __init__(self,
59
+ pattern_provider: CodebooksPatternProvider,
60
+ condition_provider: ConditionerProvider,
61
+ fuser: ConditionFuser,
62
+ code_depth: int = 8,
63
+ code_size: int = 1024,
64
+ dim: int = 128,
65
+ intermediate_size: int = 4096,
66
+ num_heads: int = 8,
67
+ norm: str = 'layer_norm', norm_first: bool = False,
68
+ bias_proj: bool = True,
69
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
70
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
71
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
72
+ lm_type = 'Llama',
73
+ num_layers=16,
74
+ cfg = None,
75
+ **kwargs):
76
+ super().__init__()
77
+
78
+ self.cfg_coef = cfg_coef
79
+
80
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout,seed=random.randint(0, 9999))
81
+ self.att_dropout = AttributeDropout(p=attribute_dropout,seed=random.randint(0, 9999))
82
+ self.condition_provider = condition_provider
83
+ self.fuser = fuser
84
+ self.code_size = code_size + 1 # + EOS
85
+ input_emb_dim = code_size + 2 # EOP
86
+ self.code_depth = code_depth
87
+ self.dim = dim
88
+ self.cfg = cfg
89
+ self.pattern_provider = pattern_provider
90
+ self.emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)])
91
+ # if 'activation' in kwargs:
92
+ # kwargs['activation'] = get_activation_fn(kwargs['activation'])
93
+
94
+ model_cfg = LlamaConfig(
95
+ hidden_size=dim,
96
+ intermediate_size = intermediate_size,
97
+ num_attention_heads = num_heads,
98
+ num_hidden_layers = num_layers,
99
+ num_key_value_heads = num_heads,
100
+ vocab_size = self.code_size,
101
+ use_cache=False,
102
+ max_position_embeddings=8196,
103
+ _flash_attn_2_enabled=True,
104
+ rms_norm_eps= 1e-5,
105
+ rope_theta= 100000.0,
106
+ use_flash_attn_2=True,
107
+ attn_implementation="flash_attention_2"
108
+ )
109
+
110
+ self.transformer = CausalLM(model_cfg)
111
+ self.mlp = nn.Sequential(
112
+ nn.Linear(dim * 2, dim),
113
+ nn.GELU(),
114
+ nn.Linear(dim, dim)
115
+ )
116
+ self.layer2_emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim) #, lr=emb_lr)
117
+ for _ in range(self.code_depth)])
118
+ sub_model_cfg = LlamaConfig(
119
+ hidden_size=dim,
120
+ intermediate_size = intermediate_size,
121
+ num_attention_heads = num_heads,
122
+ num_hidden_layers = 12,
123
+ num_key_value_heads = num_heads,
124
+ vocab_size = self.code_size,
125
+ use_cache=False,
126
+ max_position_embeddings=10000,
127
+ rms_norm_eps= 1e-5,
128
+ rope_theta= 500000.0,
129
+ _flash_attn_2_enabled=True,
130
+ use_flash_attn_2=True,
131
+ attn_implementation="flash_attention_2"
132
+ )
133
+ self.transformer2 = CausalLM(sub_model_cfg)
134
+ self.out_norm: tp.Optional[nn.Module] = None
135
+ if norm_first:
136
+ self.out_norm = create_norm_fn(norm, dim)
137
+ # enable EOS prediction
138
+ if code_depth > 1:
139
+ self.linears = nn.ModuleList([nn.Linear(dim, self.code_size, bias=False)
140
+ for _ in range(code_depth - 1)])
141
+
142
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
143
+ self._fsdp: tp.Optional[nn.Module]
144
+ self.__dict__['_fsdp'] = None
145
+
146
+ self.reset_streaming()
147
+
148
+ def _init_weights(self, weight_init: tp.Optional[str],
149
+ depthwise_init: tp.Optional[str], zero_bias_init: bool):
150
+ """Initialization of the transformer module weights.
151
+
152
+ Args:
153
+ weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
154
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
155
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
156
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
157
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
158
+ """
159
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
160
+ assert depthwise_init is None or weight_init is not None, \
161
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
162
+ assert not zero_bias_init or weight_init is not None, \
163
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
164
+
165
+ if weight_init is None:
166
+ return
167
+
168
+ for emb_layer in self.emb:
169
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
170
+
171
+
172
+ @property
173
+ def special_token_id(self) -> int:
174
+ return self.code_size # 10001
175
+
176
+ @property
177
+ def eos_token_id(self) -> int:
178
+ return self.code_size-1 # 10000
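+ # Illustrative sketch (assuming the constructor is called with code_size=10000, as the
+ # `# 10000` / `# 10001` comments here suggest): self.code_size becomes 10001, so
+ #   eos_token_id     == 10000   (last entry of the enlarged vocabulary)
+ #   special_token_id == 10001   (pattern filler, only present in the input embedding tables)
+ # and the input embeddings are built with code_size + 2 = 10002 rows.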
179
+
180
+ @torch.no_grad()
181
+ def prepare_condition_tensors(self,
182
+ batch_size = 1,
183
+ text: tp.Optional[tp.List[str]] = None,
184
+ descriptions: tp.Optional[tp.List[str]] = None,
185
+ audio_qt_emb: tp.Optional[tp.List[torch.Tensor]] = None,
186
+ prepare_null_condition = False,
187
+ ):
188
+ if self.training:
189
+ attributes = []
190
+ for i in range(batch_size):
191
+ attr = ConditioningAttributes()
192
+ if 'description' in self.condition_provider.conditioners:
193
+ attr["text"]["description"] = ""
194
+ if text is not None:
195
+ attr["text"]["description"] = text[i]
196
+ if 'prompt_audio' in self.condition_provider.conditioners:
197
+ mask = (audio_qt_emb[[i], :, 0] == 16385).bool().unsqueeze(-1)
198
+ audio_qt_seq = torch.cat([torch.full_like(audio_qt_emb[i][None][:,:,0], self.eos_token_id).unsqueeze(-1), audio_qt_emb[i][None]], dim=-1)
199
+ mask = mask.repeat(1, 1, audio_qt_seq.shape[-1])
200
+ audio_qt_seq[mask] = 16385
201
+ attr["audio"]['prompt_audio'] = AudioCondition(
202
+ wav=audio_qt_seq.long(),
203
+ length=torch.Tensor([audio_qt_seq.shape[-1]]).long(),
204
+ sample_rate=[self.cfg.sample_rate],)
205
+ if 'type_info' in self.condition_provider.conditioners:
206
+ attr["text"]["type_info"] = ""
207
+ if descriptions is not None:
208
+ attr["text"]["type_info"] = descriptions[i]
209
+ attributes.append(attr)
210
+ # print("before cfg dropout", attributes)
211
+ attributes = self.cfg_dropout(attributes) # drop ALL conditions
212
+ # print("after cfg dropout", attributes)
213
+ attributes = self.att_dropout(attributes) # selectively drop some attributes (text, wav, or more fine-grained)
214
+ # print("after attribute dropout", attributes)
215
+ # attribute to discrete tokenized ids
216
+ tokenized = self.condition_provider.tokenize(attributes)
217
+ # print("after tokenize", attributes)
218
+ # discrete tokenized ids to continuous embeddings
219
+ condition_tensors = self.condition_provider(tokenized)
220
+ else:
221
+ conditions = []
222
+ for i in range(batch_size):
223
+ attr = ConditioningAttributes()
224
+ if 'description' in self.condition_provider.conditioners:
225
+ attr["text"]["description"] = ""
226
+ if text is not None:
227
+ attr["text"]["description"] = text[i]
228
+ if 'prompt_audio' in self.condition_provider.conditioners:
229
+ mask = (audio_qt_emb[[i], :, 0] == 16385).bool().unsqueeze(-1)
230
+ audio_qt_seq = torch.cat([torch.full_like(audio_qt_emb[i][None][:,:,0], self.eos_token_id).unsqueeze(-1), audio_qt_emb[i][None]], dim=-1)
231
+ mask = mask.repeat(1, 1, audio_qt_seq.shape[-1])
232
+ audio_qt_seq[mask] = 16385
233
+ attr["audio"]['prompt_audio'] = AudioCondition(
234
+ wav=audio_qt_seq.long().cuda(),
235
+ length=torch.Tensor([audio_qt_seq.shape[-1]]).long(),
236
+ sample_rate=[self.cfg.sample_rate],)
237
+ if 'type_info' in self.condition_provider.conditioners:
238
+ attr["text"]["type_info"] = ""
239
+ if descriptions is not None:
240
+ attr["text"]["type_info"] = descriptions[i]
241
+ conditions.append(attr)
242
+ print("conditions", conditions)
243
+ if prepare_null_condition:
244
+ cfg_inference = ClassifierFreeGuidanceDropoutInference()
245
+ null_conditions = cfg_inference(conditions, condition_types=["audio", "text"],
246
+ customized=None)
247
+ conditions = conditions + null_conditions
248
+ tokenized_conditions = self.condition_provider.tokenize(conditions)
249
+ condition_tensors = self.condition_provider(tokenized_conditions)
250
+ return condition_tensors
251
+
252
+ def forward(self,
253
+ sequence: torch.Tensor,
254
+ condition_tensors: ConditionTensors) -> torch.Tensor:
255
+ """Apply language model on sequence and conditions.
256
+ Given a sequence tensor of shape [B, K, S] with K the number of codebooks and
257
+ S the sequence steps, return the logits with shape [B, K, S, card].
258
+
259
+ Args:
260
+ sequence (torch.Tensor): Indices of the codes to model, of shape [B, K, S].
261
+ condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
262
+ tensors, see `conditions`.
263
+ Returns:
264
+ torch.Tensor: Logits.
265
+ """
266
+
267
+ # import pdb; pdb.set_trace()
268
+ B, K, S = sequence.shape
269
+ assert K == self.code_depth, "Sequence shape must match the specified number of codebooks"
270
+ input_1 = self.emb[0](sequence[:, 0])
271
+ input_2 = sum([self.layer2_emb[k](sequence[:, k]) for k in range(1, K)])
272
+ fused_input1, fused_input2 = self.fuser(input_1, input_2, condition_tensors)
273
+ output = self.transformer(inputs_embeds=fused_input1,
274
+ use_cache=self._is_streaming,
275
+ past_key_values=self._streaming_state.get('past_key_values_1', None))
276
+ if self._is_streaming:
277
+ self._streaming_state['past_key_values_1'] = output.past_key_values
278
+ logits = output.logits # [B, S, card]
279
+ logits = logits.unsqueeze(1) # [B, 1, S, card]
280
+
281
+ # if self.out_norm:
282
+ # out = self.out_norm(out.to(self.out_norm.weight.data.dtype))
283
+ if K > 1:
284
+ fused_input2 = torch.cat([fused_input2, output.hidden_states], dim=-1)
285
+ fused_input2 = self.mlp(fused_input2)
286
+ output2 = self.transformer2(inputs_embeds=fused_input2,
287
+ use_cache=self._is_streaming,
288
+ past_key_values=self._streaming_state.get('past_key_values_2', None))
289
+ if self._is_streaming:
290
+ self._streaming_state['past_key_values_2'] = output2.past_key_values
291
+
292
+ res_logits = torch.stack([self.linears[k](output2.hidden_states) for k in range(K - 1)], dim=1) # [B, K-1, S, card]
293
+ logits = torch.cat([logits, res_logits], dim=1) # [B, K, S, card]
294
+
295
+ # remove the prefix from the model outputs
296
+ if len(self.fuser.fuse2cond['prepend']) > 0:
297
+ logits = logits[:, :, -S:, :]
298
+
299
+ return logits # [B, K, S, card]
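+ # Illustrative shape walk-through (a sketch, ignoring any prepended condition tokens, which
+ # the slicing above trims back to the last S steps) for an input sequence of shape [B, K, S]:
+ #   input_1 (codebook-0 embedding)            : [B, S, dim]
+ #   input_2 (sum of codebooks 1..K-1)         : [B, S, dim]
+ #   output.logits -> unsqueeze(1)             : [B, 1, S, card]
+ #   cat(fused_input2, hidden_states) -> mlp   : [B, S, dim]
+ #   res_logits (stacked K-1 heads)            : [B, K-1, S, card]
+ #   returned logits                           : [B, K, S, card]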
300
+
301
+ def compute_predictions(self,
302
+ codes: torch.Tensor,
303
+ condition_tensors: tp.Optional[ConditionTensors] = None,
304
+ **kwargs,
305
+ ): # this function is called during training
306
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
307
+ forward using the specified codes interleaving pattern.
308
+
309
+ Args:
310
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
311
+ K the number of codebooks and T the number of timesteps.
312
+ condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
313
+ tensors, see `conditions`.
314
+ Returns:
315
+ LMOutput: Language model outputs
316
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
317
+ i.e. the first item corresponds to logits to predict the first code, meaning that
318
+ no additional shifting of codes and logits is required.
319
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
320
+ Given the specified interleaving strategies, parts of the logits and codes should
321
+ not be considered as valid predictions because of invalid context.
322
+ """
323
+ B, K, T = codes.shape
324
+ codes = codes.contiguous()
325
+ # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
326
+ pattern = self.pattern_provider.get_pattern(T)
327
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
328
+ codes, self.special_token_id, keep_only_valid_steps=False
329
+ )
330
+ model = self if self._fsdp is None else self._fsdp
331
+ logits = model(sequence_codes, condition_tensors) # [B, K, S, card]
332
+ # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
333
+ # and provide the corresponding mask over invalid positions of tokens
334
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
335
+ # note: we use nans as special token to make it obvious if we feed unexpected logits
336
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
337
+ logits, float('nan'), keep_only_valid_steps=False
338
+ )
339
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
340
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
341
+
342
+ return LMOutput(logits, logits_mask)
343
+
344
+ @torch.no_grad()
345
+ def generate(self, #
346
+ # conditions: tp.List[ConditioningAttributes] = [],
347
+ texts = None,
348
+ descriptions = None,
349
+ audio_qt_embs = None,
350
+ num_samples: tp.Optional[int] = None,
351
+ max_gen_len: int = 256,
352
+ use_sampling: bool = True,
353
+ temp: float = 1.0,
354
+ top_k: int = 250,
355
+ top_p: float = 0.0,
356
+ cfg_coef: tp.Optional[float] = None,
357
+ check: bool = False,
358
+ record_tokens: bool = True,
359
+ record_window: int = 150
360
+ ) -> torch.Tensor:
361
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
362
+ be performed in a greedy fashion or using sampling with top K and top P strategies.
363
+
364
+ Args:
365
+ prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
366
+ conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
367
+ num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
368
+ max_gen_len (int): Maximum generation length.
369
+ use_sampling (bool): Whether to use a sampling strategy or not.
370
+ temp (float): Sampling temperature.
371
+ top_k (int): K for "top-k" sampling.
372
+ top_p (float): P for "top-p" sampling.
373
+ cfg_coef (float, optional): Classifier-free guidance coefficient.
374
+ check (bool): Whether to apply further checks on generated sequence.
375
+ callback (Callback, optional): Callback function to report generation progress.
376
+ Returns:
377
+ torch.Tensor: Generated tokens.
378
+ """
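+ # Illustrative call (a sketch; the concrete argument values are assumptions, not from the source):
+ #   codes = lm.generate(texts=[lyric_string],
+ #                       descriptions=[style_description],
+ #                       audio_qt_embs=[prompt_audio_tokens],   # hypothetical prompt-audio token tensor
+ #                       max_gen_len=1500, temp=0.9, top_k=50, cfg_coef=1.5)
+ #   # codes: [B, K, T] token ids, decoded back to audio elsewhere in the pipeline.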
379
+ assert not self.training, "generation shouldn't be used in training mode."
380
+ first_param = next(iter(self.parameters()))
381
+ device = first_param.device
382
+
383
+ # 1) Check input shapes are consistent
384
+ possible_num_samples = []
385
+ if num_samples is not None:
386
+ possible_num_samples.append(num_samples)
387
+ elif texts:
388
+ possible_num_samples.append(len(texts))
389
+ elif audio_qt_embs:
390
+ possible_num_samples.append(len(audio_qt_embs))
391
+ else:
392
+ possible_num_samples.append(1)
393
+ assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
394
+ num_samples = possible_num_samples[0]
395
+ condition_tensors = self.prepare_condition_tensors(batch_size=1, text=texts, descriptions=descriptions, audio_qt_emb=audio_qt_embs, prepare_null_condition=True)
396
+ # 3) Prepare token pool
397
+ record_token_pool = None
398
+ if record_tokens:
399
+ record_token_pool = []
400
+
401
+ # 4) set up startoff patterns
402
+ start_offset = 0
403
+ assert start_offset < max_gen_len, f"{start_offset}, {max_gen_len}"
404
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
405
+ # this token is used as default value for codes that are not generated yet
406
+ unknown_token = -1
407
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
408
+ B = num_samples
409
+ gen_codes = torch.full((B, self.code_depth, max_gen_len),
410
+ unknown_token, dtype=torch.long, device=device)
411
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
412
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
413
+ output_codes = torch.full_like(gen_sequence, self.code_size)
414
+ # retrieve the start_offset in the sequence:
415
+ # it is the first sequence step that contains the `start_offset` timestep
416
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
417
+ assert start_offset_sequence is not None
418
+ is_end = torch.zeros((B, self.code_depth, 1)).bool().to(device)
419
+ ignore_tokens = audio_qt_embs[0][0]
420
+ # 5) auto-regressive sampling
421
+ with self.streaming():
422
+ gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
423
+ prev_offset = 0
424
+ for offset in range(start_offset_sequence, gen_sequence_len):
425
+ # get current sequence (note that the streaming API is providing the caching over previous offsets)
426
+ curr_sequence = gen_sequence[..., prev_offset:offset]
427
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
428
+ if check:
429
+ # check coherence between mask and sequence
430
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
431
+ # should never happen as gen_sequence is filled progressively
432
+ assert not (curr_sequence == unknown_token).any()
433
+ # sample next token from the model, next token shape is [B, K, 1]
434
+ next_token = self._sample_next_token(
435
+ curr_sequence, condition_tensors, use_sampling, temp, top_k, top_p,
436
+ cfg_coef=cfg_coef,
437
+ sampled_token_pool=record_token_pool[-record_window:] if record_tokens else None,
438
+ ignore_tokens = ignore_tokens
439
+ )
440
+ # ensure the tokens that should be masked are properly set to special_token_id
441
+ # as the model never output special_token_id
442
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
443
+ next_token[~valid_mask] = self.special_token_id
444
+ # check EOS: streams that have already ended keep emitting the special token
445
+ next_token[is_end] = self.special_token_id
446
+ is_end = is_end | (next_token == self.eos_token_id)
447
+ # ensure we don't overwrite prompt tokens, we only write over unknown tokens
448
+ # (then mask tokens should be left as is as well, which is correct)
449
+ gen_sequence[..., offset:offset+1] = torch.where(
450
+ gen_sequence[..., offset:offset+1] == unknown_token,
451
+ next_token, gen_sequence[..., offset:offset+1])
452
+
453
+ # record sampled tokens in a window
454
+ if record_tokens:
455
+ record_token_pool.append(next_token.squeeze())
456
+ if torch.all(is_end):
457
+ gen_sequence = gen_sequence[..., :offset+1]
458
+ break
459
+
460
+ prev_offset = offset
461
+
462
+ # ensure sequence has been entirely filled
463
+ assert not (gen_sequence == unknown_token).any()
464
+ max_gen_len = gen_sequence.shape[-1]
465
+ output_codes[..., :max_gen_len] = gen_sequence
466
+ # ensure gen_sequence pattern and mask are matching
467
+ # which means the gen_sequence is valid according to the pattern
468
+ # assert (gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence,
469
+ # self.special_token_id)
470
+ # ).all()
471
+ # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
472
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(output_codes, special_token=unknown_token)
473
+ # sanity checks over the returned codes and corresponding masks
474
+ assert (out_codes != unknown_token).all()
475
+ assert (out_mask == 1).all()
476
+ # ensure the returned codes are all valid
477
+ assert (out_codes >= 0).all() and (out_codes <= self.code_size).all()
478
+ return out_codes
479
+
480
+ def _sample_next_token(self,
481
+ sequence: torch.Tensor,
482
+ condition_tensors: ConditionTensors,
483
+ use_sampling: bool = False,
484
+ temp: float = 1.0,
485
+ top_k: int = 0,
486
+ top_p: float = 0.0,
487
+ cfg_coef: tp.Optional[float] = None,
488
+ sampled_token_pool: tp.Optional[list] = None,
489
+ ignore_tokens: tp.Optional[torch.tensor] = torch.tensor([])) -> torch.Tensor:
490
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
491
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
492
+
493
+ Args:
494
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
495
+ with K corresponding to the number of codebooks and S the number of sequence steps.
496
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
497
+ condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
498
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
499
+ use_sampling (bool): Whether to use a sampling strategy or not.
500
+ temp (float): Sampling temperature.
501
+ top_k (int): K for "top-k" sampling.
502
+ top_p (float): P for "top-p" sampling.
503
+ cfg_coef (float, optional): classifier free guidance coefficient
504
+ Returns:
505
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
506
+ """
507
+ # import pdb; pdb.set_trace()
508
+ B = sequence.shape[0]
509
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
510
+ model = self if self._fsdp is None else self._fsdp
511
+
512
+ # Preparing for CFG, predicting both conditional and unconditional logits.
513
+ sequence = torch.cat([sequence, sequence], dim=0)
514
+ all_logits = model(sequence, condition_tensors=condition_tensors)
515
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
516
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
517
+
518
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
519
+ logits = logits[..., -1] # [B x K x card]
520
+
521
+ # add punishment to pre-sampled tokens
522
+ if sampled_token_pool is not None and len(sampled_token_pool) > 0:
523
+ sampled_token_pool = torch.stack(sampled_token_pool, -1) # [K, T]
524
+ for q in range(self.code_depth):
525
+ # q_count = torch.bincount(sampled_token_pool)
526
+ q_count = torch.bincount(torch.unique(sampled_token_pool[q]))
527
+ tmp = min(q_count.shape[-1], self.code_size - 1)
528
+ logits[:, q, :tmp] /= (1.1 ** q_count[:tmp])
529
+
530
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
531
+ if ignore_tokens is not None:
532
+ logits[0][0][ignore_tokens.to(torch.int)] = float('-inf')
533
+ if use_sampling and temp > 0.0:
534
+ probs = torch.softmax(logits / temp, dim=-1)
535
+ if top_p > 0.0:
536
+ next_token = sample_top_p(probs, p=top_p)
537
+ elif top_k > 0:
538
+ next_token_first = sample_top_k(probs[:,[0],:], k=top_k)
539
+ next_token_res = sample_top_k(probs[:,1:,:], k=1)
540
+ next_token = torch.cat([next_token_first,next_token_res], dim = 1)
541
+ else:
542
+ next_token = multinomial(probs, num_samples=1)
543
+ else:
544
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
545
+
546
+ return next_token
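+ # Illustrative note (a sketch, not from the original source): the guidance combination above is
+ # the usual classifier-free-guidance extrapolation
+ #   logits = logits_uncond + cfg_coef * (logits_cond - logits_uncond)
+ # so cfg_coef = 1.0 reproduces the conditional logits, and larger values push further towards the
+ # conditioning. The penalty block divides the logits of tokens seen in the recent sampling window
+ # by 1.1 (each token counts at most once per codebook because of the `unique` call) before the
+ # temperature softmax and top-k / top-p sampling.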
codeclm/modules/conditioners.py ADDED
@@ -0,0 +1,883 @@
1
+
2
+ import typing as tp
3
+ import torch
4
+ import torch.nn as nn
5
+ from dataclasses import dataclass, field, fields
6
+ from itertools import chain
7
+ import warnings
8
+ import torch.nn.functional as F
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from codeclm.utils.utils import length_to_mask, collate
11
+ from codeclm.modules.streaming import StreamingModule
12
+ from collections import defaultdict
13
+ from copy import deepcopy
14
+ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
15
+
16
+ # ================================================================
17
+ # Condition and Condition attributes definitions
18
+ # ================================================================
19
+ class AudioCondition(tp.NamedTuple):
20
+ wav: torch.Tensor
21
+ length: torch.Tensor
22
+ sample_rate: tp.List[int]
23
+ path: tp.List[tp.Optional[str]] = []
24
+ seek_time: tp.List[tp.Optional[float]] = []
25
+
26
+ @dataclass
27
+ class ConditioningAttributes:
28
+ text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
29
+ audio: tp.Dict[str, AudioCondition] = field(default_factory=dict)
30
+
31
+ def __getitem__(self, item):
32
+ return getattr(self, item)
33
+
34
+ @property
35
+ def text_attributes(self):
36
+ return self.text.keys()
37
+
38
+ @property
39
+ def audio_attributes(self):
40
+ return self.audio.keys()
41
+
42
+ @property
43
+ def attributes(self):
44
+ return {
45
+ "text": self.text_attributes,
46
+ "audio": self.audio_attributes,
47
+ }
48
+
49
+ def to_flat_dict(self):
50
+ return {
51
+ **{f"text.{k}": v for k, v in self.text.items()},
52
+ **{f"audio.{k}": v for k, v in self.audio.items()},
53
+ }
54
+
55
+ @classmethod
56
+ def from_flat_dict(cls, x):
57
+ out = cls()
58
+ for k, v in x.items():
59
+ kind, att = k.split(".")
60
+ out[kind][att] = v
61
+ return out
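+ # Illustrative round-trip (a sketch; the attribute names are examples):
+ #   attr = ConditioningAttributes(text={"description": "pop ballad"},
+ #                                 audio={"prompt_audio": some_audio_condition})
+ #   flat = attr.to_flat_dict()
+ #   # {"text.description": "pop ballad", "audio.prompt_audio": some_audio_condition}
+ #   assert ConditioningAttributes.from_flat_dict(flat).text == attr.text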
62
+
63
+ # ================================================================
64
+ # Conditioner (tokenize and encode raw conditions) definitions
65
+ # ================================================================
66
+
67
+ class BaseConditioner(nn.Module):
68
+ """Base model for all conditioner modules.
69
+ We allow the output dim to be different than the hidden dim for two reasons:
70
+ 1) keep our LUTs small when the vocab is large;
71
+ 2) make all condition dims consistent.
72
+
73
+ Args:
74
+ dim (int): Hidden dim of the model.
75
+ output_dim (int): Output dim of the conditioner.
76
+ """
77
+ def __init__(self, dim: int, output_dim: int, input_token = False, padding_idx=0):
78
+ super().__init__()
79
+ self.dim = dim
80
+ self.output_dim = output_dim
81
+ if input_token:
82
+ self.output_proj = nn.Embedding(dim, output_dim, padding_idx)
83
+ else:
84
+ self.output_proj = nn.Linear(dim, output_dim)
85
+
86
+ def tokenize(self, *args, **kwargs) -> tp.Any:
87
+ """Should be any part of the processing that will lead to a synchronization
88
+ point, e.g. BPE tokenization with transfer to the GPU.
89
+
90
+ The returned value will be saved and returned later when calling forward().
91
+ """
92
+ raise NotImplementedError()
93
+
94
+ def forward(self, inputs: tp.Any) -> ConditionType:
95
+ """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
96
+ Outputs a ConditionType, after the input data was embedded as a dense vector.
97
+
98
+ Returns:
99
+ ConditionType:
100
+ - A tensor of size [B, T, D] where B is the batch size, T is the length of the
101
+ output embedding and D is the dimension of the embedding.
102
+ - And a mask indicating where the padding tokens are.
103
+ """
104
+ raise NotImplementedError()
105
+
106
+ class TextConditioner(BaseConditioner):
107
+ ...
108
+
109
+
110
+ class PhonemeTokenizerConditioner(TextConditioner):
111
+ def __init__(self,
112
+ output_dim: int,
113
+ vocab_list,
114
+ max_len = 600,
115
+ max_sentence_per_structure = 50,
116
+ structure_tokens=None,
117
+ structure_split_tokens=[','],
118
+ sentence_split_tokens=['.'],
119
+ mode='sum',
120
+ structure_output_dim = 64,
121
+ sentence_output_dim = 64,
122
+ max_duration = 120,
123
+ ):
124
+
125
+ self.vocab_list = vocab_list
126
+ self.max_len = max_len
127
+ self.mode = mode
128
+ self.max_sentence_per_structure = max_sentence_per_structure
129
+ voc_size = len(self.vocab_list)
130
+
131
+ if structure_tokens is None:
132
+ structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
133
+ self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
134
+ self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
135
+ self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]
136
+
137
+ # here initialize a output_proj (nn.Embedding) layer
138
+ # By default the first vocab is "" (null)
139
+ if mode == 'sum':
140
+ content_output_dim = output_dim
141
+ sentence_output_dim = output_dim
142
+ structure_output_dim = output_dim
143
+ else: # concat'
144
+ raise NotImplementedError("concat mode is not implemented yet")
145
+ # content_output_dim = output_dim - sentence_output_dim - structure_output_dim # by default
146
+
147
+ super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
148
+ self.special_emb = nn.Embedding(voc_size, structure_output_dim, padding_idx=0)
149
+
150
+ self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)
151
+
152
+ # the first index is "empty structure" token
153
+ self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
154
+ self.sentence_reidx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
155
+
156
+ print("max_len", self.max_len)
157
+ print(self.structure_token_ids)
158
+
159
+ self.resolution = max_duration / max_len # e.g., 120 / 600 = 0.2s
160
+ print(self.__class__, f"resolution = {self.resolution}")
161
+
162
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
163
+ inputs = []
164
+ for xx in x:
165
+ xx = '' if xx is None else xx
166
+ vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
167
+ inputs.append(torch.tensor(vocab_id).long()) # [T]
168
+ return inputs
169
+
170
+
171
+ def forward(self, batch_tokens: tp.List, structure_dur = None) -> ConditionType:
172
+ """
173
+ Encode token_id into three types of embeddings:
174
+ 1) content embedding: phoneme only (or meaningful contents to be sung out)
175
+ 2) structure embedding: structure / separation embeddings, including structures (verse/chorus/...), separators (. / ,)
176
+ The two above share the same embedding layer, can be changed to separate embedding layers.
177
+ 3) sentence_idx embedding (per structure):
178
+ """
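+ # Illustrative sketch (assumed vocabulary; not from the original source). For a token stream
+ # like "[verse] p1 p2 . p3 p4 ." the per-position streams built below are roughly:
+ #   content_tokens             : [verse] p1 p2 .  p3 p4 .
+ #   special_tokens (structure) : [verse] repeated for every position in the segment
+ #   sentence_idx_in_structure  : 0       1  1  1  2  2  2
+ # i.e. the sentence counter restarts at 1 after each structure token and increments at every
+ # sentence split token ('.'); the reverse counter filled in afterwards counts sentences from
+ # the end of the structure.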
179
+ embeds_batch = []
180
+ for b in range(len(batch_tokens)):
181
+ tokens = batch_tokens[b]
182
+ content_tokens = torch.zeros_like(tokens)
183
+ special_tokens = torch.zeros_like(tokens)
184
+ sentence_idx_in_structure_tokens = torch.zeros_like(tokens)
185
+ sentence_reidx_in_structure_tokens = torch.zeros_like(tokens)
186
+
187
+ current_sentence_in_structure_idx = 1
188
+ current_structure = 0
189
+ for i in range(tokens.shape[-1]):
190
+ token = tokens[i]
191
+ if token in self.structure_token_ids: # structure token
192
+ # only update structure token, leave content and sentence index token null (default 0)
193
+ special_tokens[i] = token
194
+ content_tokens[i] = token
195
+ current_structure = token
196
+ current_sentence_in_structure_idx = 1
197
+ sentence_idx_in_structure_tokens[i] = 0
198
+
199
+ elif token in self.sentence_split_token_ids: # utterance split token
200
+ # only update structure token, leave content and sentence index token null (default 0)
201
+ # add up sentence index
202
+ special_tokens[i] = current_structure
203
+ content_tokens[i] = token
204
+ sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
205
+ current_sentence_in_structure_idx += 1
206
+
207
+ elif token in self.structure_split_token_ids: # structure split token
208
+ # update structure token (current structure), content token (current token),
209
+ # blank index token
210
+ content_tokens[i] = token
211
+ special_tokens[i] = current_structure
212
+ sentence_idx_in_structure_tokens[i] = sentence_idx_in_structure_tokens[i-1]
213
+ else: # content tokens
214
+ content_tokens[i] = token
215
+ special_tokens[i] = current_structure
216
+ sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
217
+ # backward pass: derive reverse sentence indices (counting from the end of each structure)
218
+ current_sentence_num = sentence_idx_in_structure_tokens[-1]
219
+ for i in range(tokens.shape[-1]-1,-1,-1):
220
+ if current_sentence_num != 0:
221
+ sentence_reidx_in_structure_tokens[i] = min(current_sentence_num + 1 - sentence_idx_in_structure_tokens[i], self.max_sentence_per_structure - 1)
222
+ if sentence_idx_in_structure_tokens[i] == 0 and i > 0:
223
+ current_sentence_num = sentence_idx_in_structure_tokens[i-1]
224
+
225
+ # print("tokens", tokens.max(), tokens.min())
226
+ # print("special tokens", special_tokens.max(), special_tokens.min())
227
+ # print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min())
228
+ device = self.output_proj.weight.device
229
+
230
+ # import pdb; pdb.set_trace()
231
+ content_embeds = self.output_proj(content_tokens.to(device)) # [T, N]
232
+ structure_embeds = self.output_proj(special_tokens.to(device))
233
+ # sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))
234
+ sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device)) + self.sentence_reidx_in_structure_emb(sentence_reidx_in_structure_tokens.to(device))
235
+
236
+ if self.mode == 'sum':
237
+ embeds = content_embeds + structure_embeds + sentence_idx_embeds
238
+ else:
239
+ embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1) # [T, N]
240
+ embeds_batch.append(embeds)
241
+
242
+ # set batch_size = 1, [B, T, N]
243
+ if self.max_len is not None:
244
+ max_len = self.max_len
245
+ else:
246
+ max_len = max([e.shape[0] for e in embeds_batch])
247
+ embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)
248
+
249
+ return embeds, embeds, mask
250
+
251
+
252
+ def pad_2d_tensor(self, xs, max_len):
253
+ new_tensor = []
254
+ new_mask = []
255
+ for x in xs:
256
+ seq_len, dim = x.size()
257
+ pad_len = max_len - seq_len
258
+
259
+ if pad_len > 0:
260
+ pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device) # T, D
261
+ padded_tensor = torch.cat([x, pad_tensor], dim=0)
262
+ mask = torch.cat((torch.ones_like(x[:, 0]),
263
+ torch.zeros_like(pad_tensor[:, 0])), 0) # T
264
+ elif pad_len < 0:
265
+ padded_tensor = x[:max_len]
266
+ mask = torch.ones_like(padded_tensor[:, 0])
267
+ else:
268
+ padded_tensor = x
269
+ mask = torch.ones_like(x[:, 0])
270
+
271
+ new_tensor.append(padded_tensor)
272
+ new_mask.append(mask)
273
+ # [B, T, D] & [B, T]
274
+ return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
275
+
276
+
277
+ class QwTokenizerConditioner(TextConditioner):
278
+ def __init__(self, output_dim: int,
279
+ token_path = "",
280
+ max_len = 300,
281
+ add_token_list=[]): #""
282
+ from transformers import Qwen2Tokenizer
283
+ self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
284
+ if add_token_list != []:
285
+ self.text_tokenizer.add_tokens(add_token_list, special_tokens=True)
286
+ voc_size = len(self.text_tokenizer.get_vocab())
287
+ # here initialize a output_proj (nn.Embedding) layer
288
+ super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
289
+ self.max_len = max_len
290
+ self.padding_idx = ' <|endoftext|>'
291
+
292
+ vocab = self.text_tokenizer.get_vocab()
293
+ # struct_tokens holds all structure tokens (the added tokens wrapped in [...])
294
+ struct_tokens = [i for i in add_token_list if i[0]=='[' and i[-1]==']']
295
+ self.struct_token_ids = [vocab[i] for i in struct_tokens]
296
+ self.pad_token_idx = 151643
297
+
298
+ self.structure_emb = nn.Embedding(200, output_dim, padding_idx=0)
299
+ # self.split_token_id = vocab["."]
300
+ print("all structure tokens: ", {self.text_tokenizer.convert_ids_to_tokens(i):i for i in self.struct_token_ids})
301
+
302
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
303
+ x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
304
+ # x = [xi if xi is not None else "" for xi in x]
305
+ inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
306
+ return inputs
307
+
308
+ def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
309
+ """
310
+ Add structure embeddings of {verse, chorus, bridge} to text/lyric tokens that
311
+ belong to these structures accordingly,
312
+ then delete or keep these structure embeddings.
313
+ """
314
+ mask = inputs['attention_mask']
315
+ tokens = inputs['input_ids']
316
+ B = tokens.shape[0]
317
+ is_sp_embed = torch.any(torch.stack([tokens == i for i in self.struct_token_ids], dim=-1),dim=-1)
318
+
319
+ tp_cover_range = torch.zeros_like(tokens)
320
+ for b, is_sp in enumerate(is_sp_embed):
321
+ sp_list = torch.where(is_sp)[0].tolist()
322
+ sp_list.append(mask[b].sum())
323
+ for i, st in enumerate(sp_list[:-1]):
324
+ tp_cover_range[b, st: sp_list[i+1]] = tokens[b, st] - 151645
325
+
326
+ if self.max_len is not None:
327
+ if inputs['input_ids'].shape[-1] > self.max_len:
328
+ warnings.warn(f"Max len limit ({self.max_len}) exceeded! \
329
+ {[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
330
+ tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
331
+ mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
332
+ tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
333
+ device = self.output_proj.weight.device
334
+ content_embeds = self.output_proj(tokens.to(device))
335
+ structure_embeds = self.structure_emb(tp_cover_range.to(device))
336
+
337
+ embeds = content_embeds + structure_embeds
338
+ return embeds, embeds, mask
339
+
340
+ def pad_2d_tensor(self, x, max_len, pad_id):
341
+ batch_size, seq_len = x.size()
342
+ pad_len = max_len - seq_len
343
+
344
+ if pad_len > 0:
345
+ pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
346
+ padded_tensor = torch.cat([x, pad_tensor], dim=1)
347
+ elif pad_len < 0:
348
+ padded_tensor = x[:, :max_len]
349
+ else:
350
+ padded_tensor = x
351
+
352
+ return padded_tensor
353
+
354
+
355
+ class QwTextConditioner(TextConditioner):
356
+ def __init__(self, output_dim: int,
357
+ token_path = "",
358
+ max_len = 300): #""
359
+
360
+ from transformers import Qwen2Tokenizer
361
+ self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
362
+ voc_size = len(self.text_tokenizer.get_vocab())
363
+ # here initialize a output_proj (nn.Embedding) layer
364
+ super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
365
+
366
+ self.max_len = max_len
367
+
368
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
369
+ x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
370
+ inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
371
+ return inputs
372
+
373
+ def forward(self, inputs: tp.Dict[str, torch.Tensor], structure_dur = None) -> ConditionType:
374
+ """
375
+ Embed the Qwen2 token ids with the output projection (plain text conditioning;
+ unlike QwTokenizerConditioner, no structure embeddings are added here).
378
+ """
379
+ mask = inputs['attention_mask']
380
+ tokens = inputs['input_ids']
381
+
382
+ if self.max_len is not None:
383
+ if inputs['input_ids'].shape[-1] > self.max_len:
384
+ warnings.warn(f"Max len limit ({self.max_len}) exceeded! \
385
+ {[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
386
+ tokens = self.pad_2d_tensor(tokens, self.max_len, 151643).to(self.output_proj.weight.device)
387
+ mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
388
+
389
+ embeds = self.output_proj(tokens)
390
+ return embeds, embeds, mask
391
+
392
+ def pad_2d_tensor(self, x, max_len, pad_id):
393
+ batch_size, seq_len = x.size()
394
+ pad_len = max_len - seq_len
395
+
396
+ if pad_len > 0:
397
+ pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
398
+ padded_tensor = torch.cat([x, pad_tensor], dim=1)
399
+ elif pad_len < 0:
400
+ padded_tensor = x[:, :max_len]
401
+ else:
402
+ padded_tensor = x
403
+
404
+ return padded_tensor
405
+
406
+
407
+ class AudioConditioner(BaseConditioner):
408
+ ...
409
+
410
+ class QuantizedEmbeddingConditioner(AudioConditioner):
411
+ def __init__(self, dim: int,
412
+ code_size: int,
413
+ code_depth: int,
414
+ max_len: int,
415
+ **kwargs):
416
+ super().__init__(dim, dim, input_token=True)
417
+ self.code_depth = code_depth
418
+ # add 1 for <s> token
419
+ self.emb = nn.ModuleList([nn.Embedding(code_size+2, dim, padding_idx=code_size+1) for _ in range(code_depth)])
420
+ # add End-Of-Text embedding
421
+ self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
422
+ self.layer2_EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
423
+ self.output_proj = None
424
+ self.max_len = max_len
425
+ self.vocab_size = code_size
426
+
427
+ def tokenize(self, x: AudioCondition) -> AudioCondition:
428
+ """no extra ops"""
429
+ # wav, length, sample_rate, path, seek_time = x
430
+ # assert length is not None
431
+ return x #AudioCondition(wav, length, sample_rate, path, seek_time)
432
+
433
+ def forward(self, x: AudioCondition):
434
+ wav, lengths, *_ = x
435
+ B = wav.shape[0]
436
+ wav = wav.reshape(B, self.code_depth, -1).long()
437
+ if wav.shape[2] < self.max_len - 1:
438
+ wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size+1)
439
+ else:
440
+ wav = wav[:, :, :self.max_len-1]
441
+ embeds1 = self.emb[0](wav[:, 0])
442
+ embeds1 = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1),
443
+ embeds1), dim=1)
444
+ embeds2 = sum([self.emb[k](wav[:, k]) for k in range(1, self.code_depth)]) # B,T,D
445
+ embeds2 = torch.cat((self.layer2_EOT_emb.unsqueeze(0).repeat(B, 1, 1),
446
+ embeds2), dim=1)
447
+ lengths = lengths + 1
448
+ lengths = torch.clamp(lengths, max=self.max_len)
449
+
450
+ if lengths is not None:
451
+ mask = length_to_mask(lengths, max_len=embeds1.shape[1]).int() # type: ignore
452
+ else:
453
+ mask = torch.ones((B, self.code_depth), device=embeds1.device, dtype=torch.int)
454
+ return embeds1, embeds2, mask
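+ # Illustrative shape sketch (not from the original source): for a prompt batch of B items with
+ # code_depth K and max_len L,
+ #   wav     -> [B, K, L-1] after padding/truncation (pad value = code_size + 1)
+ #   embeds1 -> [B, L, dim]  (EOT embedding prepended to the codebook-0 embeddings)
+ #   embeds2 -> [B, L, dim]  (EOT embedding prepended to the sum of codebooks 1..K-1)
+ #   mask    -> [B, L]       (1 for the first `length + 1` positions, clamped to L)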
455
+
456
+
457
+ # ================================================================
458
+ # Aggregate all conditions and corresponding conditioners
459
+ # ================================================================
460
+ class ConditionerProvider(nn.Module):
461
+ """Prepare and provide conditions given all the supported conditioners.
462
+
463
+ Args:
464
+ conditioners (dict): Dictionary of conditioners.
465
+ device (torch.device or str, optional): Device for conditioners and output condition types.
466
+ """
467
+ def __init__(self, conditioners: tp.Dict[str, BaseConditioner]):
468
+ super().__init__()
469
+ self.conditioners = nn.ModuleDict(conditioners)
470
+
471
+ @property
472
+ def text_conditions(self):
473
+ return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
474
+
475
+ @property
476
+ def audio_conditions(self):
477
+ return [k for k, v in self.conditioners.items() if isinstance(v, AudioConditioner)]
478
+
479
+ @property
480
+ def has_audio_condition(self):
481
+ return len(self.audio_conditions) > 0
482
+
483
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
484
+ """Match attributes/audios with existing conditioners in self, and tokenize them accordingly.
485
+ This should be called before starting any real GPU work to avoid synchronization points.
486
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
487
+
488
+ Args:
489
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
490
+ text and audio conditions.
491
+ """
492
+ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
493
+ "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
494
+ f" but types were {set([type(x) for x in inputs])}")
495
+
496
+ output = {}
497
+ text = self._collate_text(inputs)
498
+ audios = self._collate_audios(inputs)
499
+
500
+ assert set(text.keys() | audios.keys()).issubset(set(self.conditioners.keys())), (
501
+ f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
502
+ f"got {text.keys(), audios.keys()}")
503
+
504
+ for attribute, batch in chain(text.items(), audios.items()):
505
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
506
+ return output
507
+
508
+ def forward(self, tokenized: tp.Dict[str, tp.Any], structure_dur = None) -> tp.Dict[str, ConditionType]:
509
+ """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
510
+ The output is for example:
511
+ {
512
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
513
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
514
+ ...
515
+ }
516
+
517
+ Args:
518
+ tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
519
+ """
520
+ output = {}
521
+ for attribute, inputs in tokenized.items():
522
+ if attribute == 'description' and structure_dur is not None:
523
+ condition1, condition2, mask = self.conditioners[attribute](inputs, structure_dur = structure_dur)
524
+ else:
525
+ condition1, condition2, mask = self.conditioners[attribute](inputs)
526
+ output[attribute] = (condition1, condition2, mask)
527
+ return output
528
+
529
+ def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
530
+ """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
531
+ are the attributes and the values are the aggregated input per attribute.
532
+ For example:
533
+ Input:
534
+ [
535
+ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
536
+ ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, audio=...),
537
+ ]
538
+ Output:
539
+ {
540
+ "genre": ["Rock", "Hip-hop"],
541
+ "description": ["A rock song with a guitar solo", "A hip-hop verse"]
542
+ }
543
+
544
+ Args:
545
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
546
+ Returns:
547
+ dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
548
+ """
549
+ out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
550
+ texts = [x.text for x in samples]
551
+ for text in texts:
552
+ for condition in self.text_conditions:
553
+ out[condition].append(text[condition])
554
+ return out
555
+
556
+ def _collate_audios(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, AudioCondition]:
557
+ """Generate a dict where the keys are attributes by which we fetch similar audios,
558
+ and the values are Tensors of audios according to said attributes.
559
+
560
+ *Note*: by the time the samples reach this function, each sample should have some audios
561
+ inside the "audio" attribute. It should be either:
562
+ 1. A real audio
563
+ 2. A null audio due to the sample having no similar audios (nullified by the dataset)
564
+ 3. A null audio due to it being dropped in a dropout module (nullified by dropout)
565
+
566
+ Args:
567
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
568
+ Returns:
569
+ dict[str, AudioCondition]: A dictionary mapping an attribute name to its batched AudioCondition.
570
+ """
571
+ # import pdb; pdb.set_trace()
572
+ wavs = defaultdict(list)
573
+ lengths = defaultdict(list)
574
+ sample_rates = defaultdict(list)
575
+ paths = defaultdict(list)
576
+ seek_times = defaultdict(list)
577
+ out: tp.Dict[str, AudioCondition] = {}
578
+
579
+ for sample in samples:
580
+ for attribute in self.audio_conditions:
581
+ wav, length, sample_rate, path, seek_time = sample.audio[attribute]
582
+ assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
583
+ assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
584
+ wavs[attribute].append(wav.flatten()) # [C*T]
585
+ lengths[attribute].append(length)
586
+ sample_rates[attribute].extend(sample_rate)
587
+ paths[attribute].extend(path)
588
+ seek_times[attribute].extend(seek_time)
589
+
590
+ # stack all wavs to a single tensor
591
+ for attribute in self.audio_conditions:
592
+ stacked_wav, _ = collate(wavs[attribute], dim=0)
593
+ out[attribute] = AudioCondition(
594
+ stacked_wav.unsqueeze(1),
595
+ torch.cat(lengths[attribute]), sample_rates[attribute],
596
+ paths[attribute], seek_times[attribute])
597
+
598
+ return out
599
+
600
+
601
+ class ConditionFuser(StreamingModule):
602
+ """Condition fuser handles the logic to combine the different conditions
603
+ to the actual model input.
604
+
605
+ Args:
606
+ fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
607
+ each condition. For example:
608
+ {
609
+ "prepend": ["description"],
610
+ "sum": ["genre", "bpm"],
611
+ }
612
+ """
613
+ FUSING_METHODS = ["sum", "prepend"] #, "cross", "input_interpolate"] (not support in this simplest version)
614
+
615
+ def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]]):
616
+ super().__init__()
617
+ assert all([k in self.FUSING_METHODS for k in fuse2cond.keys()]
618
+ ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
619
+ self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
620
+ self.cond2fuse: tp.Dict[str, str] = {}
621
+ for fuse_method, conditions in fuse2cond.items():
622
+ for condition in conditions:
623
+ self.cond2fuse[condition] = fuse_method
624
+
625
+ def forward(
626
+ self,
627
+ input1: torch.Tensor,
628
+ input2: torch.Tensor,
629
+ conditions: tp.Dict[str, ConditionType]
630
+ ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
631
+ """Fuse the conditions to the provided model input.
632
+
633
+ Args:
634
+ input1 (torch.Tensor): First transformer input stream.
+ input2 (torch.Tensor): Second transformer input stream.
635
+ conditions (dict[str, ConditionType]): Dict of conditions.
636
+ Returns:
637
+ tuple[torch.Tensor, torch.Tensor]: The two transformer input streams after the
+ conditions have been fused (cross-attention conditioning is not supported in
+ this simplified version, so no cross-attention tensor is returned).
640
+ """
641
+ #import pdb; pdb.set_trace()
642
+ B, T, _ = input1.shape
643
+
644
+ if 'offsets' in self._streaming_state:
645
+ first_step = False
646
+ offsets = self._streaming_state['offsets']
647
+ else:
648
+ first_step = True
649
+ offsets = torch.zeros(input1.shape[0], dtype=torch.long, device=input1.device)
650
+
651
+ assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
652
+ f"given conditions contain unknown attributes for fuser, " \
653
+ f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
654
+
655
+ # if 'prepend' mode is used,
656
+ # the concatenation order will be the SAME with the conditions in config:
657
+ # prepend: ['description', 'prompt_audio'] (then goes the input)
658
+ fused_input_1 = input1
659
+ fused_input_2 = input2
660
+ for fuse_op in self.fuse2cond.keys():
661
+ fuse_op_conditions = self.fuse2cond[fuse_op]
662
+ if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
663
+ for cond in fuse_op_conditions:
664
+ this_cond_1, this_cond_2, cond_mask = conditions[cond]
665
+ fused_input_1 += this_cond_1
666
+ fused_input_2 += this_cond_2
667
+ elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
668
+ if not first_step:
669
+ continue
670
+ reverse_list = deepcopy(fuse_op_conditions)
671
+ reverse_list.reverse()
672
+ for cond in reverse_list:
673
+ this_cond_1, this_cond_2, cond_mask = conditions[cond]
674
+ fused_input_1 = torch.cat((this_cond_1, fused_input_1), dim=1) # concat along T dim
675
+ fused_input_2 = torch.cat((this_cond_2, fused_input_2), dim=1) # concat along T dim
676
+ elif fuse_op not in self.FUSING_METHODS:
677
+ raise ValueError(f"unknown op ({fuse_op})")
678
+
679
+ if self._is_streaming:
680
+ self._streaming_state['offsets'] = offsets + T
681
+
682
+ return fused_input_1, fused_input_2
683
+
684
+
685
+
686
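To make the fusing rules above concrete, here is a minimal sketch (not part of the commit) of how a `prepend` condition lengthens both input streams; the import path is assumed to follow this repository's layout.

import torch
from codeclm.modules.conditioners import ConditionFuser  # assumed import path

fuser = ConditionFuser({"prepend": ["description"], "sum": []})
x1 = torch.randn(2, 10, 16)   # [B, T, D] first input stream
x2 = torch.randn(2, 10, 16)   # [B, T, D] second input stream
conditions = {
    # (condition_1, condition_2, mask) as produced by ConditioningProvider.forward
    "description": (torch.randn(2, 4, 16), torch.randn(2, 4, 16), torch.ones(2, 4)),
}
y1, y2 = fuser(x1, x2, conditions)
print(y1.shape, y2.shape)     # both [2, 14, 16]: 4 prepended condition steps + 10 input steps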
+ # ================================================================
687
+ # Condition Dropout
688
+ # ================================================================
689
+ class DropoutModule(nn.Module):
690
+ """Base module for all dropout modules."""
691
+ def __init__(self, seed: int = 1234):
692
+ super().__init__()
693
+ self.rng = torch.Generator()
694
+ self.rng.manual_seed(seed)
695
+
696
+
697
+
698
+ class ClassifierFreeGuidanceDropout(DropoutModule):
699
+ """Classifier Free Guidance dropout.
700
+ All attributes are dropped with the same probability.
701
+
702
+ Args:
703
+ p (float): Probability to apply condition dropout during training.
704
+ seed (int): Random seed.
705
+ """
706
+ def __init__(self, p: float, seed: int = 1234):
707
+ super().__init__(seed=seed)
708
+ self.p = p
709
+
710
+ def check(self, sample, condition_type, condition):
711
+
712
+ if condition_type not in ['text', 'audio']:
713
+ raise ValueError("dropout_condition got an unexpected condition type!"
714
+ f" expected 'text', 'audio' but got '{condition_type}'")
715
+
716
+ if condition not in getattr(sample, condition_type):
717
+ raise ValueError(
718
+ "dropout_condition received an unexpected condition!"
719
+ f" expected audio={sample.audio.keys()} and text={sample.text.keys()}"
720
+ f" but got '{condition}' of type '{condition_type}'!")
721
+
722
+
723
+ def get_null_wav(self, wav, sr=48000) -> AudioCondition:
724
+ out = wav * 0 + 16385
725
+ return AudioCondition(
726
+ wav=out,
727
+ length=torch.Tensor([0]).long(),
728
+ sample_rate=[sr],)
729
+
730
+ def dropout_condition(self,
731
+ sample: ConditioningAttributes,
732
+ condition_type: str,
733
+ condition: str) -> ConditioningAttributes:
734
+ """Utility function for nullifying an attribute inside an ConditioningAttributes object.
735
+ If the condition is of type "wav", then nullify it using `nullify_condition` function.
736
+ If the condition is of any other type, set its value to None.
737
+ Works in-place.
738
+ """
739
+ self.check(sample, condition_type, condition)
740
+
741
+ if condition_type == 'audio':
742
+ audio_cond = sample.audio[condition]
743
+ depth = audio_cond.wav.shape[1]
744
+ sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
745
+ else:
746
+ sample.text[condition] = None
747
+
748
+ return sample
749
+
750
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
751
+ """
752
+ Args:
753
+ samples (list[ConditioningAttributes]): List of conditions.
754
+ Returns:
755
+ list[ConditioningAttributes]: List of conditions after all attributes were set to None.
756
+ """
757
+ # decide on which attributes to drop in a batched fashion
758
+ # drop = torch.rand(1, generator=self.rng).item() < self.p
759
+ # if not drop:
760
+ # return samples
761
+
762
+ # nullify conditions of all attributes
763
+ samples = deepcopy(samples)
764
+
765
+ for sample in samples:
766
+ drop = torch.rand(1, generator=self.rng).item()
767
+ if drop<self.p:
768
+ for condition_type in ["audio", "text"]:
769
+ for condition in sample.attributes[condition_type]:
770
+ self.dropout_condition(sample, condition_type, condition)
771
+ return samples
772
+
773
+ def __repr__(self):
774
+ return f"ClassifierFreeGuidanceDropout(p={self.p})"
775
+
776
+
777
+ class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
778
+ """Classifier Free Guidance dropout during inference.
779
+ All attributes are dropped with the same probability.
780
+
781
+ Args:
782
+ p (float): Probability to apply condition dropout during training.
783
+ seed (int): Random seed.
784
+ """
785
+ def __init__(self, seed: int = 1234):
786
+ super().__init__(p=1, seed=seed)
787
+
788
+ def dropout_condition_customized(self,
789
+ sample: ConditioningAttributes,
790
+ condition_type: str,
791
+ condition: str,
792
+ customized: list = None) -> ConditioningAttributes:
793
+ """Utility function for nullifying an attribute inside an ConditioningAttributes object.
794
+ If the condition is of type "audio", then nullify it using `nullify_condition` function.
795
+ If the condition is of any other type, set its value to None.
796
+ Works in-place.
797
+ """
798
+ self.check(sample, condition_type, condition)
799
+
800
+ if condition_type == 'audio':
801
+ audio_cond = sample.audio[condition]
802
+ depth = audio_cond.wav.shape[1]
803
+ sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
804
+ else:
805
+ if customized is None:
806
+ sample.text[condition] = None
807
+ else:
808
+ text_cond = deepcopy(sample.text[condition])
809
+ if "structure" in customized:
810
+ for _s in ['[inst]', '[outro]', '[intro]', '[verse]', '[chorus]', '[bridge]']:
811
+ text_cond = text_cond.replace(_s, "")
812
+ text_cond = text_cond.replace(' , ', '')
813
+ text_cond = text_cond.replace(" ", " ")
814
+ if '.' in customized:
815
+ text_cond = text_cond.replace(" . ", " ")
816
+ text_cond = text_cond.replace(".", " ")
817
+
818
+ sample.text[condition] = text_cond
819
+
820
+ return sample
821
+
822
+ def forward(self, samples: tp.List[ConditioningAttributes],
823
+ condition_types=["audio", "text"],
824
+ customized=None,
825
+ ) -> tp.List[ConditioningAttributes]:
826
+ """
827
+ 100% dropout of some condition attributes (e.g. description, prompt_audio) or condition types (text, audio) of
828
+ samples during inference.
829
+
830
+ Args:
831
+ samples (list[ConditioningAttributes]): List of conditions.
832
+ Returns:
833
+ list[ConditioningAttributes]: List of conditions after all attributes were set to None.
834
+ """
835
+ new_samples = deepcopy(samples)
836
+ for condition_type in condition_types:
837
+ for sample in new_samples:
838
+ for condition in sample.attributes[condition_type]:
839
+ self.dropout_condition_customized(sample, condition_type, condition, customized)
840
+ return new_samples
841
+
842
+ class AttributeDropout(ClassifierFreeGuidanceDropout):
843
+ """Dropout with a given probability per attribute.
844
+ This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
845
+ to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
846
+ This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
847
+ must also be dropped.
848
+
849
+ Args:
850
+ p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping each condition type to a dict of
+ attribute dropout probabilities. For example:
+ {
+ "text": {"genre": 0.1, "artist": 0.5},
+ "audio": {"prompt_audio": 0.25},
+ }
856
+ active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
857
+ seed (int, optional): Random seed.
858
+ """
859
+ def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
860
+ super().__init__(p=p, seed=seed)
861
+ self.active_on_eval = active_on_eval
862
+ # construct dict that return the values from p otherwise 0
863
+ self.p = {}
864
+ for condition_type, probs in p.items():
865
+ self.p[condition_type] = defaultdict(lambda: 0, probs)
866
+
867
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
868
+ """
869
+ Args:
870
+ samples (list[ConditioningAttributes]): List of conditions.
871
+ Returns:
872
+ list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
873
+ """
874
+ if not self.training and not self.active_on_eval:
875
+ return samples
876
+
877
+ samples = deepcopy(samples)
878
+ for condition_type, ps in self.p.items(): # for condition types [text, wav]
879
+ for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
880
+ if torch.rand(1, generator=self.rng).item() < p:
881
+ for sample in samples:
882
+ self.dropout_condition(sample, condition_type, condition)
883
+ return samples
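As a usage note (not part of the commit): during training these dropout modules are applied to the batch of `ConditioningAttributes` before tokenization. A rough sketch, assuming `attributes` is built by the data pipeline and the import path matches this repository's layout:

from codeclm.modules.conditioners import AttributeDropout, ClassifierFreeGuidanceDropout  # assumed import path

att_dropout = AttributeDropout(p={"text": {"description": 0.1}})   # drop single attributes independently
cfg_dropout = ClassifierFreeGuidanceDropout(p=0.15)                # drop all conditions of a sample together

def apply_condition_dropout(attributes):
    """Per-attribute dropout first, then whole-sample CFG dropout."""
    return cfg_dropout(att_dropout(attributes))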
codeclm/modules/pattern.py ADDED
@@ -0,0 +1,351 @@
1
+ from collections import namedtuple
2
+ from dataclasses import dataclass
3
+ from functools import lru_cache
4
+ import logging
5
+ import typing as tp
6
+
7
+ from abc import ABC, abstractmethod
8
+ import torch
9
+
10
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
11
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class Pattern:
17
+ """Base implementation of a pattern over a sequence with multiple codebooks.
18
+
19
+ The codebook pattern consists in a layout, defining for each sequence step
20
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
21
+ The first item of the pattern is always an empty list in order to properly insert a special token
22
+ to start with. For convenience, we also keep track of ``code_depth`` the number of codebooks used for the pattern
23
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
24
+
25
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
26
+ ``build_pattern_sequence`` maps a dense input tensor of a multi-codebook sequence from [B, K, T]
+ to the interleaved sequence of shape [B, K, S] by applying the pattern, with B being the batch size,
28
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
29
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
30
+ is returned along with a mask indicating valid tokens.
31
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
32
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
33
+ to fill and specify invalid positions if needed.
34
+ See the dedicated methods for more details.
35
+ """
36
+ # Pattern layout, for each sequence step, we have a list of coordinates
37
+ # corresponding to the original codebook timestep and position.
38
+ # The first list is always an empty list in order to properly insert
39
+ # a special token to start with.
40
+ layout: PatternLayout
41
+ timesteps: int
42
+ code_depth: int
43
+
44
+ def __post_init__(self):
45
+ assert len(self.layout) > 0
46
+ assert self.layout[0] == []
47
+ self._validate_layout()
48
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
49
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
50
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
51
+
52
+ def _validate_layout(self):
53
+ """Runs checks on the layout to ensure a valid pattern is defined.
54
+ A pattern is considered invalid if:
55
+ - Multiple timesteps for a same codebook are defined in the same sequence step
56
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
57
+ (this would mean that we have future timesteps before past timesteps).
58
+ """
59
+ q_timesteps = {q: 0 for q in range(self.code_depth)}
60
+ for s, seq_coords in enumerate(self.layout):
61
+ if len(seq_coords) > 0:
62
+ qs = set()
63
+ for coord in seq_coords:
64
+ qs.add(coord.q)
65
+ last_q_timestep = q_timesteps[coord.q]
66
+ # assert coord.t >= last_q_timestep, \
67
+ # f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
68
+ q_timesteps[coord.q] = coord.t
69
+ # each sequence step contains at max 1 coordinate per codebook
70
+ assert len(qs) == len(seq_coords), \
71
+ f"Multiple entries for a same codebook are found at step {s}"
72
+
73
+ @property
74
+ def num_sequence_steps(self):
75
+ return len(self.layout) - 1
76
+
77
+ @property
78
+ def max_delay(self):
79
+ max_t_in_seq_coords = 0
80
+ for seq_coords in self.layout[1:]:
81
+ for coords in seq_coords:
82
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
83
+ return max_t_in_seq_coords - self.timesteps
84
+
85
+ @property
86
+ def valid_layout(self):
87
+ valid_step = len(self.layout) - self.max_delay
88
+ return self.layout[:valid_step]
89
+
90
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
91
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
92
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
93
+ and the actual codebook coordinates.
94
+ """
95
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
96
+ if q is not None:
97
+ assert q <= self.code_depth, "provided number of codebooks is greater than the pattern's number of codebooks"
98
+ coords = []
99
+ for s, seq_codes in enumerate(self.layout):
100
+ for code in seq_codes:
101
+ if code.t == t and (q is None or code.q == q):
102
+ coords.append((s, code))
103
+ return coords
104
+
105
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
106
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
107
+
108
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
109
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
110
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
111
+
112
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int,
113
+ code_depth: int,
114
+ keep_only_valid_steps: bool,
115
+ device: tp.Union[torch.device, str] = 'cpu'):
116
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
117
+
118
+ Args:
119
+ timesteps (int): Maximum number of timesteps steps to consider.
120
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
121
+ device (torch.device or str): Device for created tensors.
122
+ Returns:
123
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
124
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
125
+ """
126
+ assert code_depth == self.code_depth, f"invalid number of codebooks for the sequence and the pattern: {code_depth} != {self.code_depth}"
127
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
128
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
129
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
130
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
131
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
132
+ indexes = torch.zeros(code_depth, len(ref_layout), dtype=torch.long).numpy()
133
+ mask = torch.zeros(code_depth, len(ref_layout), dtype=torch.bool).numpy()
134
+ # fill indexes with last sequence step value that will correspond to our special token
135
+ # the last value is code_depth * timesteps as we have flattened z and append special token as the last token
136
+ # which will correspond to the index: code_depth * timesteps
137
+ indexes[:] = code_depth * timesteps
138
+ # iterate over the pattern and fill scattered indexes and mask
139
+ for s, sequence_coords in enumerate(ref_layout):
140
+ for coords in sequence_coords:
141
+ if coords.t < timesteps:
142
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
143
+ mask[coords.q, s] = 1
144
+ indexes = torch.from_numpy(indexes).to(device)
145
+ mask = torch.from_numpy(mask).to(device)
146
+ return indexes, mask
147
+
148
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
149
+ """Build sequence corresponding to the pattern from the input tensor z.
150
+ The sequence is built using up to sequence_steps if specified, and non-pattern
151
+ coordinates are filled with the special token.
152
+
153
+ Args:
154
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
155
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
156
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
157
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
158
+ Returns:
159
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
160
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
161
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
162
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
163
+ """
164
+ B, K, T = z.shape
165
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
166
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
167
+ )
168
+ z = z.reshape(B, -1)
169
+ # we append the special token as the last index of our flattened z tensor
170
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
171
+ values = z[:, indexes.view(-1)]
172
+ values = values.view(B, K, indexes.shape[-1])
173
+ # import pdb; pdb.set_trace()
174
+ return values, indexes, mask
175
+
176
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, code_depth: int,
177
+ keep_only_valid_steps: bool = False,
178
+ is_model_output: bool = False,
179
+ device: tp.Union[torch.device, str] = 'cpu'):
180
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
181
+ from interleaving pattern.
182
+
183
+ Args:
184
+ sequence_steps (int): Sequence steps.
185
+ code_depth (int): Number of codebooks.
186
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
187
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
188
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
189
+ device (torch.device or str): Device for created tensors.
190
+ Returns:
191
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
192
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
193
+ """
194
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
195
+ timesteps = self.timesteps
196
+ assert code_depth == self.code_depth, f"invalid number of codebooks for the sequence and the pattern: {code_depth} != {self.code_depth}"
197
+ assert sequence_steps <= len(ref_layout), \
198
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
199
+
200
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
201
+ if is_model_output:
202
+ ref_layout = ref_layout[1:]
203
+
204
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
205
+ indexes = torch.zeros(code_depth, timesteps, dtype=torch.long).numpy()
206
+ mask = torch.zeros(code_depth, timesteps, dtype=torch.bool).numpy()
207
+ # fill indexes with last sequence step value that will correspond to our special token
208
+ indexes[:] = code_depth * sequence_steps
209
+ for s, sequence_codes in enumerate(ref_layout):
210
+ if s < sequence_steps:
211
+ for code in sequence_codes:
212
+ if code.t < timesteps:
213
+ indexes[code.q, code.t] = s + code.q * sequence_steps
214
+ mask[code.q, code.t] = 1
215
+ indexes = torch.from_numpy(indexes).to(device)
216
+ mask = torch.from_numpy(mask).to(device)
217
+ return indexes, mask
218
+
219
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
220
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
221
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
222
+ are filled with the special token.
223
+
224
+ Args:
225
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
226
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
227
+ Returns:
228
+ values (torch.Tensor): Reverted (de-interleaved) sequence, of shape [B, K, T] with T
229
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
230
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
231
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
232
+ """
233
+ B, K, S = s.shape
234
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
235
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
236
+ )
237
+ s = s.view(B, -1)
238
+ # we append the special token as the last index of our flattened z tensor
239
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
240
+ values = s[:, indexes.view(-1)]
241
+ values = values.view(B, K, indexes.shape[-1])
242
+ return values, indexes, mask
243
+
244
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
245
+ """Revert model logits obtained on a sequence built from the pattern
246
+ back to a tensor matching the original sequence.
247
+
248
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
249
+ 1. It is designed to work with the extra cardinality dimension
250
+ 2. We return the logits for the first sequence item that matches the special_token and
251
+ which matching target in the original sequence is the first item of the sequence,
252
+ while we skip the last logits as there is no matching target
253
+ """
254
+ B, card, K, S = logits.shape
255
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
256
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
257
+ )
258
+ logits = logits.reshape(B, card, -1)
259
+ # we append the special token as the last index of our flattened z tensor
260
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
261
+ values = logits[:, :, indexes.view(-1)]
262
+
263
+ values = values.view(B, card, K, indexes.shape[-1])
264
+ return values, indexes, mask
265
+
266
+
267
+
268
+ class CodebooksPatternProvider(ABC):
269
+ """Abstraction around providing pattern for interleaving codebooks.
270
+
271
+ The CodebooksPatternProvider abstraction allows implementing various strategies to
272
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
273
+ number of codebooks `code_depth`, the pattern provider can generate a specified pattern
274
+ corresponding to a sequence of `T` timesteps with `code_depth` parallel codebooks. This pattern
275
+ can be used to construct a new sequence from the original codes respecting the specified
276
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
277
+ being a tuple with the original timestep and codebook to build the new sequence.
278
+ Note that all patterns must start with an empty list that is then used to insert a first
279
+ sequence step of special tokens in the newly generated sequence.
280
+
281
+ Args:
282
+ code_depth (int): number of codebooks.
283
+ cached (bool): if True, patterns for a given length are cached. In general
284
+ that should be true for efficiency reason to avoid synchronization points.
285
+ """
286
+ def __init__(self, code_depth: int, cached: bool = True):
287
+ assert code_depth > 0
288
+ self.code_depth = code_depth
289
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
290
+
291
+ @abstractmethod
292
+ def get_pattern(self, timesteps: int) -> Pattern:
293
+ """Builds pattern with specific interleaving between codebooks.
294
+
295
+ Args:
296
+ timesteps (int): Total number of timesteps.
297
+ """
298
+ raise NotImplementedError()
299
+
300
+
301
+ class DelayedPatternProvider(CodebooksPatternProvider):
302
+ """Provider for delayed pattern across delayed codebooks.
303
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
304
+ from different timesteps.
305
+
306
+ Example:
307
+ Taking timesteps=4 and code_depth=3, delays=None, the multi-codebook sequence:
308
+ [[1, 2, 3, 4],
309
+ [1, 2, 3, 4],
310
+ [1, 2, 3, 4]]
311
+ The resulting sequence obtained from the returned pattern is:
312
+ [[S, 1, 2, 3, 4],
313
+ [S, S, 1, 2, 3],
314
+ [S, S, S, 1, 2]]
315
+ (with S being a special token)
316
+
317
+ Args:
318
+ code_depth (int): Number of codebooks.
319
+ delays (list of int, optional): Delay for each of the codebooks.
320
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
321
+ flatten_first (int): Flatten the first N timesteps.
322
+ empty_initial (int): Prepend with N empty list of coordinates.
323
+ """
324
+ def __init__(self, code_depth: int, delays: tp.Optional[tp.List[int]] = None,
325
+ flatten_first: int = 0, empty_initial: int = 0):
326
+ super().__init__(code_depth)
327
+ if delays is None:
328
+ delays = list(range(code_depth))
329
+ self.delays = delays
330
+ self.flatten_first = flatten_first
331
+ self.empty_initial = empty_initial
332
+ assert len(self.delays) == self.code_depth
333
+ assert sorted(self.delays) == self.delays
334
+
335
+ def get_pattern(self, timesteps: int) -> Pattern:
336
+ out: PatternLayout = [[]]
337
+ max_delay = max(self.delays)
338
+ if self.empty_initial:
339
+ out += [[] for _ in range(self.empty_initial)]
340
+ if self.flatten_first:
341
+ for t in range(min(timesteps, self.flatten_first)):
342
+ for q in range(self.code_depth):
343
+ out.append([LayoutCoord(t, q)])
344
+ for t in range(self.flatten_first, timesteps + max_delay):
345
+ v = []
346
+ for q, delay in enumerate(self.delays):
347
+ t_for_q = t - delay
348
+ if t_for_q >= self.flatten_first:
349
+ v.append(LayoutCoord(t_for_q, q))
350
+ out.append(v)
351
+ return Pattern(out, code_depth=self.code_depth, timesteps=timesteps)
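A minimal round-trip sketch (not part of the commit), assuming the import path matches this repository's layout, showing how the delayed pattern interleaves and then reverts a [B, K, T] code tensor:

import torch
from codeclm.modules.pattern import DelayedPatternProvider  # assumed import path

provider = DelayedPatternProvider(code_depth=3)
pattern = provider.get_pattern(timesteps=4)

z = torch.arange(1, 13).reshape(1, 3, 4)                  # [B=1, K=3, T=4] fake codes
seq, _, seq_mask = pattern.build_pattern_sequence(z, special_token=0)
print(seq.shape)                                          # [1, 3, 7]: 1 empty step + 4 timesteps + max delay 2
z_rec, _, rec_mask = pattern.revert_pattern_sequence(seq, special_token=0)
assert torch.equal(z_rec[rec_mask.unsqueeze(0)], z[rec_mask.unsqueeze(0)])  # original codes recovered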
codeclm/modules/streaming.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ Streaming module API that should be implemented by all streaming components.
3
+ """
4
+
5
+ from contextlib import contextmanager
6
+ import typing as tp
7
+ from torch import nn
8
+ import torch
9
+
10
+
11
+ State = tp.Dict[str, torch.Tensor]
12
+
13
+ class StreamingModule(nn.Module):
14
+ """Common API for streaming components.
15
+
16
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
17
+ By convention, the first dim of each tensor must be the batch size.
18
+ Don't use dots in the key names, as this would clash with submodules
19
+ (like in state_dict).
20
+
21
+ If `self._is_streaming` is True, the component should use and remember
22
+ the proper state inside `self._streaming_state`.
23
+
24
+ To set a streaming component in streaming state, use
25
+
26
+ with module.streaming():
27
+ ...
28
+
29
+ This will automatically reset the streaming state when exiting the context manager.
30
+ This also automatically propagates to all streaming children module.
31
+
32
+ Some module might also implement the `StreamingModule.flush` method, although
33
+ this one is trickier, as all parents module must be StreamingModule and implement
34
+ it as well for it to work properly. See `StreamingSequential` after.
35
+ """
36
+ def __init__(self) -> None:
37
+ super().__init__()
38
+ self._streaming_state: State = {}
39
+ self._is_streaming = False
40
+
41
+ def _apply_named_streaming(self, fn: tp.Any):
42
+ for name, module in self.named_modules():
43
+ if isinstance(module, StreamingModule):
44
+ fn(name, module)
45
+
46
+ def _set_streaming(self, streaming: bool):
47
+ def _set_streaming(name, module):
48
+ module._is_streaming = streaming
49
+ self._apply_named_streaming(_set_streaming)
50
+
51
+ @contextmanager
52
+ def streaming(self):
53
+ """Context manager to enter streaming mode. Reset streaming state on exit."""
54
+ self._set_streaming(True)
55
+ try:
56
+ yield
57
+ finally:
58
+ self._set_streaming(False)
59
+ self.reset_streaming()
60
+
61
+ def reset_streaming(self):
62
+ """Reset the streaming state."""
63
+ def _reset(name: str, module: StreamingModule):
64
+ module._streaming_state.clear()
65
+
66
+ self._apply_named_streaming(_reset)
67
+
68
+ def get_streaming_state(self) -> State:
69
+ """Return the streaming state, including that of sub-modules."""
70
+ state: State = {}
71
+
72
+ def _add(name: str, module: StreamingModule):
73
+ if name:
74
+ name += "."
75
+ for key, value in module._streaming_state.items():
76
+ state[name + key] = value
77
+
78
+ self._apply_named_streaming(_add)
79
+ return state
80
+
81
+ def set_streaming_state(self, state: State):
82
+ """Set the streaming state, including that of sub-modules."""
83
+ state = dict(state)
84
+
85
+ def _set(name: str, module: StreamingModule):
86
+ if name:
87
+ name += "."
88
+ module._streaming_state.clear()
89
+ for key, value in list(state.items()):
90
+ # complexity is not ideal here, but probably fine.
91
+ if key.startswith(name):
92
+ local_key = key[len(name):]
93
+ if '.' not in local_key:
94
+ module._streaming_state[local_key] = value
95
+ del state[key]
96
+
97
+ self._apply_named_streaming(_set)
98
+ assert len(state) == 0, list(state.keys())
99
+
100
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
101
+ """Flush any remaining outputs that were waiting for completion.
102
+ Typically, for convolutions, this will add the final padding
103
+ and process the last buffer.
104
+
105
+ This should take an optional argument `x`, which will be provided
106
+ if a module before this one in the streaming pipeline has already
107
+ spitted out a flushed out buffer.
108
+ """
109
+ if x is None:
110
+ return None
111
+ else:
112
+ return self(x)
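A small illustration (not part of the commit) of the streaming API described above; `RunningSum` is a made-up toy module used only to show how `_streaming_state` behaves inside the `streaming()` context.

import torch
from codeclm.modules.streaming import StreamingModule  # assumed import path

class RunningSum(StreamingModule):
    """Toy module: accumulates a running sum across calls while streaming."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._is_streaming:
            total = self._streaming_state.get("sum", torch.zeros_like(x)) + x
            self._streaming_state["sum"] = total
            return total
        return x

m = RunningSum()
with m.streaming():
    m(torch.ones(1, 4))
    out = m(torch.ones(1, 4))       # tensor of 2s: the state was carried over
assert not m.get_streaming_state()  # state is cleared when the context exits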
codeclm/tokenizer/Flow1dVAE/audio.py ADDED
@@ -0,0 +1,304 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : audio.py
5
+ @Time : 2023/8/8 7:18 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Audio
10
+ """
11
+ import json
12
+ import subprocess as sp
13
+ import typing as tp
14
+ from pathlib import Path
15
+
16
+ import lameenc
17
+ import julius
18
+ import torch
19
+ import numpy as np
20
+ import torchaudio as ta
21
+ from contextlib import contextmanager
22
+ import tempfile
23
+ import os
24
+
25
+ @contextmanager
26
+ def temp_filenames(count: int, delete=True):
27
+ names = []
28
+ try:
29
+ for _ in range(count):
30
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
31
+ yield names
32
+ finally:
33
+ if delete:
34
+ for name in names:
35
+ os.unlink(name)
36
+
37
+
38
+ def _read_info(path):
39
+ stdout_data = sp.check_output([
40
+ 'ffprobe', "-loglevel", "panic",
41
+ str(path), '-print_format', 'json', '-show_format', '-show_streams'
42
+ ])
43
+ return json.loads(stdout_data.decode('utf-8'))
44
+
45
+
46
+ class AudioFile:
47
+ """
48
+ Allows reading audio from any format supported by ffmpeg, as well as resampling or
49
+ converting to mono on the fly. See :method:`read` for more details.
50
+ """
51
+ def __init__(self, path: Path):
52
+ self.path = Path(path)
53
+ self._info = None
54
+
55
+ def __repr__(self):
56
+ features = [("path", self.path)]
57
+ features.append(("samplerate", self.samplerate()))
58
+ features.append(("channels", self.channels()))
59
+ features.append(("streams", len(self)))
60
+ features_str = ", ".join(f"{name}={value}" for name, value in features)
61
+ return f"AudioFile({features_str})"
62
+
63
+ @property
64
+ def info(self):
65
+ if self._info is None:
66
+ self._info = _read_info(self.path)
67
+ return self._info
68
+
69
+ @property
70
+ def duration(self):
71
+ return float(self.info['format']['duration'])
72
+
73
+ @property
74
+ def _audio_streams(self):
75
+ return [
76
+ index for index, stream in enumerate(self.info["streams"])
77
+ if stream["codec_type"] == "audio"
78
+ ]
79
+
80
+ def __len__(self):
81
+ return len(self._audio_streams)
82
+
83
+ def channels(self, stream=0):
84
+ return int(self.info['streams'][self._audio_streams[stream]]['channels'])
85
+
86
+ def samplerate(self, stream=0):
87
+ return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
88
+
89
+ def read(self,
90
+ seek_time=None,
91
+ duration=None,
92
+ streams=slice(None),
93
+ samplerate=None,
94
+ channels=None):
95
+ """
96
+ Slightly more efficient implementation than stempeg: in particular, this extracts
+ all stems at once rather than looping over the file multiple times, once per stream.
100
+
101
+ Args:
102
+ seek_time (float): seek time in seconds or None if no seeking is needed.
103
+ duration (float): duration in seconds to extract or None to extract until the end.
104
+ streams (slice, int or list): streams to extract, can be a single int, a list or
105
+ a slice. If it is a slice or list, the output will be of size [S, C, T]
106
+ with S the number of streams, C the number of channels and T the number of samples.
107
+ If it is an int, the output will be [C, T].
108
+ samplerate (int): if provided, will resample on the fly. If None, no resampling will
109
+ be done. Original sampling rate can be obtained with :method:`samplerate`.
110
+ channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
111
+ as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
112
+ See https://sound.stackexchange.com/a/42710.
113
+ Our definition of mono is simply the average of the two channels. Any other
114
+ value will be ignored.
115
+ """
116
+ streams = np.array(range(len(self)))[streams]
117
+ single = not isinstance(streams, np.ndarray)
118
+ if single:
119
+ streams = [streams]
120
+
121
+ if duration is None:
122
+ target_size = None
123
+ query_duration = None
124
+ else:
125
+ target_size = int((samplerate or self.samplerate()) * duration)
126
+ query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
127
+
128
+ with temp_filenames(len(streams)) as filenames:
129
+ command = ['ffmpeg', '-y']
130
+ command += ['-loglevel', 'panic']
131
+ if seek_time:
132
+ command += ['-ss', str(seek_time)]
133
+ command += ['-i', str(self.path)]
134
+ for stream, filename in zip(streams, filenames):
135
+ command += ['-map', f'0:{self._audio_streams[stream]}']
136
+ if query_duration is not None:
137
+ command += ['-t', str(query_duration)]
138
+ command += ['-threads', '1']
139
+ command += ['-f', 'f32le']
140
+ if samplerate is not None:
141
+ command += ['-ar', str(samplerate)]
142
+ command += [filename]
143
+
144
+ sp.run(command, check=True)
145
+ wavs = []
146
+ for filename in filenames:
147
+ wav = np.fromfile(filename, dtype=np.float32)
148
+ wav = torch.from_numpy(wav)
149
+ wav = wav.view(-1, self.channels()).t()
150
+ if channels is not None:
151
+ wav = convert_audio_channels(wav, channels)
152
+ if target_size is not None:
153
+ wav = wav[..., :target_size]
154
+ wavs.append(wav)
155
+ wav = torch.stack(wavs, dim=0)
156
+ if single:
157
+ wav = wav[0]
158
+ return wav
159
+
160
+
161
+ def convert_audio_channels(wav, channels=2):
162
+ """Convert audio to the given number of channels."""
163
+ *shape, src_channels, length = wav.shape
164
+ if src_channels == channels:
165
+ pass
166
+ elif channels == 1:
167
+ # Case 1:
168
+ # The caller asked 1-channel audio, but the stream have multiple
169
+ # channels, downmix all channels.
170
+ wav = wav.mean(dim=-2, keepdim=True)
171
+ elif src_channels == 1:
172
+ # Case 2:
173
+ # The caller asked for multiple channels, but the input file has
+ # a single channel; replicate the audio over all channels.
175
+ wav = wav.expand(*shape, channels, length)
176
+ elif src_channels >= channels:
177
+ # Case 3:
178
+ # The caller asked for multiple channels, and the input file has
+ # more channels than requested. In that case, return the first `channels` channels.
180
+ wav = wav[..., :channels, :]
181
+ else:
182
+ # Case 4: What is a reasonable choice here?
183
+ raise ValueError('The audio file has less channels than requested but is not mono.')
184
+ return wav
185
+
186
+
187
+ def convert_audio(wav, from_samplerate, to_samplerate, channels):
188
+ """Convert audio from a given samplerate to a target one and target number of channels."""
189
+ wav = convert_audio_channels(wav, channels)
190
+ return julius.resample_frac(wav, from_samplerate, to_samplerate)
191
+
192
+
193
+ def i16_pcm(wav):
194
+ """Convert audio to 16 bits integer PCM format."""
195
+ if wav.dtype.is_floating_point:
196
+ return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
197
+ else:
198
+ return wav
199
+
200
+
201
+ def f32_pcm(wav):
202
+ """Convert audio to float 32 bits PCM format."""
203
+ if wav.dtype.is_floating_point:
204
+ return wav
205
+ else:
206
+ return wav.float() / (2**15 - 1)
207
+
208
+
209
+ def as_dtype_pcm(wav):
210
+ """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
211
+ if wav.dtype.is_floating_point:
212
+ return f32_pcm(wav)
213
+ else:
214
+ return i16_pcm(wav)
215
+
216
+
217
+ def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
218
+ """Save given audio as mp3. This should work on all OSes."""
219
+ c, _ = wav.shape
220
+ wav = i16_pcm(wav)
221
+ encoder = lameenc.Encoder()
222
+ encoder.set_bit_rate(bitrate)
223
+ encoder.set_in_sample_rate(samplerate)
224
+ encoder.set_channels(c)
225
+ encoder.set_quality(2) # 2-highest, 7-fastest
226
+ if not verbose:
227
+ encoder.silence()
228
+ wav = wav.data.cpu()
229
+ wav = wav.transpose(0, 1).numpy()
230
+ mp3_data = encoder.encode(wav.tobytes())
231
+ mp3_data += encoder.flush()
232
+ with open(path, "wb") as f:
233
+ f.write(mp3_data)
234
+
235
+
236
+ def prevent_clip(wav, mode='rescale'):
237
+ """
238
+ different strategies for avoiding raw clipping.
239
+ """
240
+ if mode is None or mode == 'none':
241
+ return wav
242
+ assert wav.dtype.is_floating_point, "too late for clipping"
243
+ if mode == 'rescale':
244
+ wav = wav / max(1.01 * wav.abs().max(), 1)
245
+ elif mode == 'clamp':
246
+ wav = wav.clamp(-0.99, 0.99)
247
+ elif mode == 'tanh':
248
+ wav = torch.tanh(wav)
249
+ else:
250
+ raise ValueError(f"Invalid mode {mode}")
251
+ return wav
252
+
253
+
254
+ def save_audio(wav: torch.Tensor,
255
+ path: tp.Union[str, Path],
256
+ samplerate: int,
257
+ bitrate: int = 320,
258
+ clip: str = 'rescale',
+ bits_per_sample: int = 16,
260
+ as_float: bool = False):
261
+ """Save audio file, automatically preventing clipping if necessary
262
+ based on the given `clip` strategy. If the path ends in `.mp3`, this
263
+ will save as mp3 with the given `bitrate`.
264
+ """
265
+ wav = prevent_clip(wav, mode=clip)
266
+ path = Path(path)
267
+ suffix = path.suffix.lower()
268
+ if suffix == ".mp3":
269
+ encode_mp3(wav, path, samplerate, bitrate, verbose=True)
270
+ elif suffix == ".wav":
271
+ if as_float:
272
+ bits_per_sample = 32
273
+ encoding = 'PCM_F'
274
+ else:
275
+ encoding = 'PCM_S'
276
+ ta.save(str(path), wav, sample_rate=samplerate,
277
+ encoding=encoding, bits_per_sample=bits_per_sample)
278
+ elif suffix == ".flac":
279
+ ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
280
+ else:
281
+ raise ValueError(f"Invalid suffix for path: {suffix}")
282
+
283
+
284
+ def load_track(track, audio_channels, samplerate):
285
+ errors = {}
286
+ wav = None
287
+
288
+ try:
289
+ wav = AudioFile(track).read(
290
+ streams=0,
291
+ samplerate=samplerate,
292
+ channels=audio_channels)
293
+ except sp.CalledProcessError:
294
+ errors['ffmpeg'] = 'FFmpeg could not read the file.'
295
+
296
+ if wav is None:
297
+ try:
298
+ wav, sr = ta.load(str(track))
299
+ except RuntimeError as err:
300
+ errors['torchaudio'] = err.args[0]
301
+ else:
302
+ wav = convert_audio(wav, sr, samplerate, audio_channels)
303
+
304
+ return wav, errors
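A short usage sketch (not part of the commit) for the helpers above; the import path is assumed to match this repository's layout and the output filename is only illustrative.

import torch
from codeclm.tokenizer.Flow1dVAE.audio import convert_audio, save_audio  # assumed import path

wav = torch.randn(2, 48_000)                               # 1 s of stereo audio at 48 kHz
mono_32k = convert_audio(wav, 48_000, 32_000, channels=1)  # downmix to mono and resample
save_audio(mono_32k, "preview.wav", samplerate=32_000)     # clipping prevented via the default 'rescale'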
codeclm/tokenizer/Flow1dVAE/cal_token_stat.py ADDED
@@ -0,0 +1,19 @@
1
+ import kaldiio
2
+ from tqdm import tqdm
3
+ import torch
4
+
5
+ if __name__ == "__main__":
6
+ bar = torch.zeros(1, 16384)
7
+ with open('token.scp', 'r') as f:
8
+ for item_idx, line in tqdm(enumerate(f)):
9
+ idx, pos = line.strip().split()
10
+ codes = kaldiio.load_mat(pos)
11
+ for i0 in range(codes.shape[-1]):
12
+ bar[0, codes[0, 0, i0]] += 1
13
+ if(item_idx % 1000 == 0):
14
+ print("=========")
15
+ print(1 - (bar[0]==0).sum() / bar.shape[-1])
16
+ print("=========")
17
+ print("=========")
18
+ print(1 - (bar[0]==0).sum() / bar.shape[-1])
19
+ print("=========")
codeclm/tokenizer/Flow1dVAE/compare_model_weight.py ADDED
@@ -0,0 +1,13 @@
1
+ import torch
2
+ import sys
3
+ from safetensors.torch import load_file
4
+
5
+ if __name__ == "__main__":
6
+ m0, m1 = sys.argv[1], sys.argv[2]
7
+ m0 = load_file(m0)
8
+ m1 = load_file(m1)
9
+
10
+ ks = [k for k in m0.keys() if 'bestrq' in k]
11
+ for k in ks:
12
+ print(k, (m0[k] - m1[k]).abs().sum())
13
+
codeclm/tokenizer/Flow1dVAE/configs/models/transformer2D_wocross_inch112_1x4_multi_large.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "_class_name": "Transformer2DModel",
3
+ "_diffusers_version": "0.22.0.dev0",
4
+ "activation_fn": "gelu-approximate",
5
+ "attention_bias": true,
6
+ "attention_head_dim": 72,
7
+ "attention_type": "default",
8
+ "cross_attention_dim": null,
9
+ "double_self_attention": false,
10
+ "dropout": 0.0,
11
+ "in_channels": 96,
12
+ "norm_elementwise_affine": false,
13
+ "norm_eps": 1e-06,
14
+ "norm_num_groups": 32,
15
+ "norm_type": "ada_norm_single",
16
+ "num_attention_heads": 22,
17
+ "num_embeds_ada_norm": 1000,
18
+ "num_layers": 24,
19
+ "num_vector_embeds": null,
20
+ "only_cross_attention": false,
21
+ "out_channels": 32,
22
+ "patch_size": 2,
23
+ "sample_size": 384,
24
+ "upcast_attention": false,
25
+ "use_linear_projection": false
26
+ }
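A hedged sketch (not part of the commit) of how a config like this might be instantiated with `diffusers`; whether the surrounding code actually loads it this way is not shown here.

import json
from diffusers import Transformer2DModel  # assumed dependency

with open("configs/models/transformer2D_wocross_inch112_1x4_multi_large.json") as f:
    cfg = json.load(f)
model = Transformer2DModel.from_config(cfg)   # builds the architecture only, no pretrained weights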
codeclm/tokenizer/Flow1dVAE/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.8.0",
4
+ "beta_end": 0.02,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.0015,
7
+ "clip_sample": false,
8
+ "num_train_timesteps": 1000,
9
+ "prediction_type": "sample",
10
+ "set_alpha_to_one": false,
11
+ "skip_prk_steps": true,
12
+ "steps_offset": 1,
13
+ "trained_betas": null
14
+ }
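Likewise, a hedged sketch (not part of the commit) of how this scheduler config could be loaded with `diffusers`:

import json
from diffusers import DDIMScheduler  # assumed dependency

with open("configs/scheduler/stable_diffusion_2.1_largenoise_sample.json") as f:
    scheduler = DDIMScheduler.from_config(json.load(f))
scheduler.set_timesteps(50)   # e.g. 50 inference steps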
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py ADDED
1
+ import torch,torchaudio
2
+ import os,sys,json
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+
6
+ #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
7
+ from generate_septoken import Tango as Tango_sep
8
+ from generate_2rvq import Tango as Tango_1x2
9
+ import kaldiio
10
+ from kaldiio import WriteHelper
11
+ from audio import AudioFile
12
+
13
+ from demucs.models.pretrained import get_model_from_yaml
14
+ from filelock import FileLock
15
+
16
+ # os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
17
+ class Separator:
18
+ def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
19
+ if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
20
+ self.device = torch.device(f"cuda:{gpu_id}")
21
+ else:
22
+ self.device = torch.device("cpu")
23
+ self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
24
+
25
+ def init_demucs_model(self, model_path, config_path):
26
+ model = get_model_from_yaml(config_path, model_path)
27
+ model.to(self.device)
28
+ model.eval()
29
+ return model
30
+
31
+ def load_audio(self, f):
32
+ a, fs = torchaudio.load(f)
33
+ if (fs != 48000):
34
+ a = torchaudio.functional.resample(a, fs, 48000)
35
+ # if a.shape[-1] >= 48000*10:
36
+ # a = a[..., :48000*10]
37
+ # else:
38
+ # a = torch.cat([a, a], -1)
39
+ # return a[:, 0:48000*10]
40
+ return a
41
+
42
+ def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"):
43
+ name, _ = os.path.splitext(os.path.split(audio_path)[-1])
44
+ output_paths = []
45
+ # lock_path = os.path.join(output_dir, f"{name}.lock")
46
+ # with FileLock(lock_path): # use a lock to avoid deadlocks when multiple GPUs access the same file
47
+ for stem in self.demucs_model.sources:
48
+ output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
49
+ if os.path.exists(output_path):
50
+ output_paths.append(output_path)
51
+ if len(output_paths) == 1: # only one stem file found cached (out of the original 4 stems); assume it is the vocal stem
52
+ # drums_path, bass_path, other_path, vocal_path = output_paths
53
+ vocal_path = output_paths[0]
54
+ else:
55
+ lock_path = os.path.join(output_dir, f"{name}_separate.lock")
56
+ with FileLock(lock_path):
57
+ drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
58
+ full_audio = self.load_audio(audio_path)
59
+ vocal_audio = self.load_audio(vocal_path)
60
+ minlen = min(full_audio.shape[-1], vocal_audio.shape[-1])
61
+ # bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen]
62
+ bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path)
63
+ for path in [drums_path, bass_path, other_path, vocal_path]:
64
+ os.remove(path)
65
+ return full_audio, vocal_audio, bgm_audio
66
+
67
+ def read_wav(fname, sample_rate=48_000):
68
+ try:
69
+ orig_samples, fs = torchaudio.load(fname)
70
+ except:
71
+ af = AudioFile(fname)
72
+ orig_samples = af.read()
73
+ fs = af.samplerate()
74
+ orig_samples = orig_samples[0]
75
+ if(fs!=sample_rate):
76
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
77
+ fs = sample_rate
78
+ if orig_samples.shape[0] == 1:
79
+ orig_samples = torch.cat([orig_samples, orig_samples], 0)
80
+ return orig_samples
81
+
82
+ if __name__ == "__main__":
83
+ # Define Model
84
+ json_path = sys.argv[1]
85
+
86
+ mus_infos = []
87
+ with open(json_path) as f:
88
+ for line in f:
89
+ item = json.loads(line)
90
+ mus_infos.append(item)
91
+
92
+ tango_sep = Tango_sep(model_path="./saved/model_septoken/model_2.safetensors")
93
+ tango_1x2 = Tango_1x2(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
94
+ separator = Separator()
95
+
96
+ # Feature extraction loop
97
+ # for i in tqdm(range(2000)):
98
+ first_time = True
99
+ for item in tqdm(mus_infos):
100
+ if(os.path.exists(item['path'])):
101
+ full_path = item['path']
102
+ else:
103
+ full_path = '/mnt/share/' + item['path']
104
+
105
+ full_tensor, vocal_tensor, bgm_tensor = separator.run(full_path)
106
+
107
+ # full_tensor = read_wav(full_path)
108
+ # vocal_tensor = read_wav(vocal_path)
109
+ # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
110
+ # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
111
+ # bgm_tensor = full_tensor - vocal_tensor
112
+ codes_1x2 = tango_1x2.sound2code(full_tensor)
113
+ codes_vocal, codes_bgm = tango_sep.sound2code(vocal_tensor, bgm_tensor)
114
+ codes = torch.cat([codes_1x2[:,[0],:], codes_vocal, codes_bgm], 1).cpu().numpy()
115
+ save_path = full_path.replace('.wav', '.1x1_and_sep.npy').replace('.mp3', '.1x1_and_sep.npy').replace('.flac', '.1x1_and_sep.npy').replace('.ogg', '.1x1_and_sep.npy')
116
+ assert save_path != full_path, (save_path, full_path)
117
+ np.save(save_path, codes)
118
+
119
+ if(first_time):
120
+ first_time = False
121
+ print(codes_vocal.shape, codes_bgm.shape)
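A note on the saved layout (not part of the commit): channel 0 of the saved array holds the first RVQ layer of the full mix, followed by the vocal codes and then the bgm codes, in the order concatenated above. A hedged sketch of reading one file back (the filename is only illustrative):

import numpy as np

codes = np.load("song.1x1_and_sep.npy")   # shape [1, 1 + K_vocal + K_bgm, T]
mix_layer0 = codes[:, 0:1, :]             # first RVQ layer of the full mix
vocal_and_bgm = codes[:, 1:, :]           # vocal codes followed by bgm codes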
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch,torchaudio
2
+ import os,sys,json
3
+ from tqdm import tqdm
4
+
5
+ #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
+ from generate_septoken import Tango
7
+ import kaldiio
8
+ from kaldiio import WriteHelper
9
+ from audio import AudioFile
10
+
11
+ def read_wav(fname, sample_rate=48_000):
12
+ try:
13
+ orig_samples, fs = torchaudio.load(fname)
14
+ except:
15
+ af = AudioFile(fname)
16
+ orig_samples = af.read()
17
+ fs = af.samplerate()
18
+ orig_samples = orig_samples[0]
19
+ if(fs!=sample_rate):
20
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
21
+ fs = sample_rate
22
+ if orig_samples.shape[0] == 1:
23
+ orig_samples = torch.cat([orig_samples, orig_samples], 0)
24
+ return orig_samples
25
+
26
+ if __name__ == "__main__":
27
+ # Define Model
28
+ json_path = sys.argv[1]
29
+ outdir = sys.argv[2]
30
+
31
+ mus_infos = []
32
+ with open(json_path) as f:
33
+ for line in f:
34
+ item = json.loads(line)
35
+ mus_infos.append(item)
36
+
37
+ tango = Tango(model_path="./saved/model_septoken/model_2.safetensors")
38
+
39
+
40
+ # Feature extraction loop
41
+ # for i in tqdm(range(2000)):
42
+ first_time = True
43
+ with WriteHelper('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir), write_function="pickle") as writer_vocal, WriteHelper('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir), write_function="pickle") as writer_bgm:
44
+ print('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir))
45
+ print('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir))
46
+ for item in tqdm(mus_infos):
47
+ try:
48
+ # if True:
49
+ idx = item['idx']
50
+ # print(idx)
51
+ if(os.path.exists(item['path'])):
52
+ full_path = item['path']
53
+ else:
54
+ full_path = '/mnt/share/' + item['path']
55
+ if(os.path.exists(item['vocal_path'])):
56
+ vocal_path = item['vocal_path']
57
+ bgm_paths = item['bgm_path']
58
+ else:
59
+ vocal_path = '/mnt/share/' + item['vocal_path']
60
+ bgm_paths = ['/mnt/share/' + p for p in item['bgm_path']]
61
+ vocal_tensor = read_wav(vocal_path)
62
+ # full_tensor = read_wav(full_path)
63
+ # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
64
+ # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
65
+ # bgm_tensor = full_tensor - vocal_tensor
66
+ bgm_tensor = sum([read_wav(p) for p in bgm_paths])
67
+ codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
68
+ writer_vocal(str(idx), codes_vocal.cpu())
69
+ writer_bgm(str(idx), codes_bgm.cpu())
70
+ if(first_time):
71
+ first_time = False
72
+ print(codes_vocal.shape, codes_bgm.shape)
73
+ except:
74
+ print(item['vocal_path'])
75
+ print(item['bgm_path'])
76
+ continue
77
+
78
+ # idx = item['idx']
79
+ # # print(idx)
80
+ # full_path = item['path']
81
+ # vocal_path = item['vocal_path']
82
+ # bgm_paths = item['bgm_path']
83
+ # full_tensor = read_wav(full_path)
84
+ # vocal_tensor = read_wav(vocal_path)
85
+ # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
86
+ # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
87
+ # bgm_tensor = full_tensor - vocal_tensor
88
+ # codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
89
+ # writer_vocal(str(idx), codes_vocal.cpu())
90
+ # writer_bgm(str(idx), codes_bgm.cpu())
91
+ # if(first_time):
92
+ # first_time = False
93
+ # print(codes_vocal.shape, codes_bgm.shape)
94
+
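Note: the script above rebuilds the BGM track by summing the `bgm_path` stems directly, which assumes the stems are sample-aligned and equal in length. A minimal sketch of a safer variant that trims every stem to the shortest length before summing (the helper name is illustrative):

```python
import torch

def sum_stems(stems):
    """stems: list of (channels, samples) tensors -> one summed (channels, samples) tensor."""
    min_len = min(s.shape[-1] for s in stems)
    return torch.stack([s[..., :min_len] for s in stems], dim=0).sum(dim=0)
```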
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch,torchaudio
2
+ import os,sys,json
3
+ from tqdm import tqdm
4
+
5
+ #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
+ from generate_2rvq import Tango
7
+ import kaldiio
8
+ from kaldiio import WriteHelper
9
+ import torch
10
+ import subprocess
11
+ import time
12
+ import sys
13
+
14
+ def get_gpu_memory():
15
+ _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
16
+
17
+ ACCEPTABLE_AVAILABLE_MEMORY = 1024
18
+ COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
19
+ memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
20
+ memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
21
+ return memory_free_values
22
+
23
+ if __name__ == "__main__":
24
+ # Define Model
25
+ json_path = sys.argv[1]
26
+ outdir = sys.argv[2]
27
+
28
+ gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
29
+ while True:
30
+ free_mem = get_gpu_memory()
31
+ free_mem = free_mem[gpu_idx]
32
+ if(free_mem > 25_000):
33
+ print("GPU memory {}, run matrix cal".format(free_mem))
34
+ break
35
+ else:
36
+ print("GPU memory {}, sleep 1min".format(free_mem))
37
+ time.sleep(60)
38
+
39
+ mus_infos = []
40
+ with open(json_path) as f:
41
+ for line in f:
42
+ item = json.loads(line)
43
+ mus_infos.append(item)
44
+
45
+ tango = Tango(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
46
+
47
+
48
+ # Feature extraction loop
49
+ # for i in tqdm(range(2000)):
50
+ with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
51
+ print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
52
+ for item in tqdm(mus_infos):
53
+ try:
54
+ # if True:
55
+ idx = item['idx']
56
+ # print(idx)
57
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
58
+ if(os.path.exists(item['path'])):
59
+ codes = tango.file2code(item['path'])
60
+ else:
61
+ codes = tango.file2code('/mnt/share/' + item['path'])
62
+ writer(str(idx), codes.cpu())
63
+ except:
64
+ print(item['path'])
65
+ continue
66
+ # idx = item['idx']
67
+ # # print(idx)
68
+ # with torch.autocast(device_type="cuda", dtype=torch.float16):
69
+ # codes = tango.file2code(item['path'])
70
+ # writer(str(idx), codes.cpu())
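Note: the busy-wait loop above gates the job on the free GPU memory reported by nvidia-smi. A minimal standalone sketch of the same idea, assuming `nvidia-smi` is on PATH and reports MiB (the `csv,noheader,nounits` format only simplifies parsing; function names are illustrative):

```python
import subprocess
import time

def free_gpu_memory_mib():
    """Return the free memory of every visible GPU, in MiB."""
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"]
    ).decode("ascii")
    return [int(line) for line in out.strip().splitlines()]

def wait_for_gpu(gpu_idx, min_free_mib=25_000, poll_seconds=60):
    """Block until the chosen GPU has at least min_free_mib MiB free."""
    while free_gpu_memory_mib()[gpu_idx] < min_free_mib:
        time.sleep(poll_seconds)
```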
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch,torchaudio
2
+ import os,sys,json
3
+ from tqdm import tqdm
4
+
5
+ #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
+ from generate_4rvq import Tango
7
+ import kaldiio
8
+ from kaldiio import WriteHelper
9
+
10
+ if __name__ == "__main__":
11
+ # Define Model
12
+ json_path = sys.argv[1]
13
+ outdir = sys.argv[2]
14
+
15
+ mus_infos = []
16
+ with open(json_path) as f:
17
+ for line in f:
18
+ item = json.loads(line)
19
+ mus_infos.append(item)
20
+
21
+ tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
22
+
23
+
24
+ # Feature extraction loop
25
+ # for i in tqdm(range(2000)):
26
+ with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
27
+ print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
28
+ for item in tqdm(mus_infos):
29
+ try:
30
+ # if True:
31
+ idx = item['idx']
32
+ # print(idx)
33
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
34
+ if(os.path.exists(item['path'])):
35
+ codes = tango.file2code(item['path'])
36
+ else:
37
+ codes = tango.file2code('/mnt/share/' + item['path'])
38
+ writer(str(idx), codes.cpu())
39
+ except:
40
+ print(item['path'])
41
+ continue
42
+ # idx = item['idx']
43
+ # # print(idx)
44
+ # with torch.autocast(device_type="cuda", dtype=torch.float16):
45
+ # codes = tango.file2code(item['path'])
46
+ # writer(str(idx), codes.cpu())
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch,torchaudio
2
+ import os,sys,json
3
+ from tqdm import tqdm
4
+
5
+ #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
+ from generate_4rvq import Tango
7
+ import kaldiio
8
+ from kaldiio import WriteHelper
9
+ import torch
10
+ import subprocess
11
+ import time
12
+ import sys
13
+
14
+ def get_gpu_memory():
15
+ _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
16
+
17
+ ACCEPTABLE_AVAILABLE_MEMORY = 1024
18
+ COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
19
+ memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
20
+ memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
21
+ return memory_free_values
22
+
23
+ if __name__ == "__main__":
24
+ # Define Model
25
+ json_path = sys.argv[1]
26
+ outdir = sys.argv[2]
27
+ ds = int(sys.argv[3])
28
+
29
+ gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
30
+ while True:
31
+ free_mem = get_gpu_memory()
32
+ free_mem = free_mem[gpu_idx]
33
+ if(free_mem > 25_000):
34
+ print("GPU memory {}, run matrix cal".format(free_mem))
35
+ break
36
+ else:
37
+ print("GPU memory {}, sleep 1min".format(free_mem))
38
+ time.sleep(60)
39
+
40
+ mus_infos = []
41
+ with open(json_path) as f:
42
+ for line in f:
43
+ item = json.loads(line)
44
+ mus_infos.append(item)
45
+
46
+ tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
47
+
48
+
49
+ # Feature extraction loop
50
+ # for i in tqdm(range(2000)):
51
+ with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
52
+ print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
53
+ bar = torch.zeros(4, 16384)
54
+ for item_idx, item in tqdm(enumerate(mus_infos)):
55
+ try:
56
+ # if True:
57
+ idx = item['idx']
58
+ # print(idx)
59
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
60
+ if(os.path.exists(item['path'])):
61
+ codes = tango.file2code_ds(item['path'], ds)
62
+ else:
63
+ codes = tango.file2code_ds('/mnt/share/' + item['path'], ds)
64
+ codes = codes.cpu()
65
+ writer(str(idx), codes)
66
+ for i0 in range(codes.shape[-1]):
67
+ bar[0, codes[0, 0, i0]] += 1
68
+ bar[1, codes[0, 1, i0]] += 1
69
+ bar[2, codes[0, 2, i0]] += 1
70
+ bar[3, codes[0, 3, i0]] += 1
71
+ except Exception as e:
72
+ print(item['path'])
73
+ # print(e.message, e.args)
74
+ # exit(1)
75
+ continue
76
+
77
+ if(item_idx % 1000 == 0):
78
+ print("=========")
79
+ print(1 - (bar[0]==0).sum() / bar.shape[-1])
80
+ print("=========")
81
+
82
+ # idx = item['idx']
83
+ # # print(idx)
84
+ # with torch.autocast(device_type="cuda", dtype=torch.float16):
85
+ # codes = tango.file2code(item['path'])
86
+ # writer(str(idx), codes.cpu())
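Note: the `bar` tensor above accumulates how often each codebook entry is emitted and periodically prints the fraction of entries that have been used. A minimal sketch of that codebook-utilization statistic, assuming `codes` has shape `(1, n_rvq, T)` with indices in `[0, codebook_size)`:

```python
import torch

def codebook_usage(codes, n_rvq=4, codebook_size=16384):
    """Return, per RVQ level, the fraction of codebook entries seen at least once."""
    hist = torch.zeros(n_rvq, codebook_size)
    for q in range(n_rvq):
        ids, counts = codes[0, q].unique(return_counts=True)
        hist[q, ids] += counts.float()
    return 1.0 - (hist == 0).float().mean(dim=-1)
```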
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py ADDED
@@ -0,0 +1,283 @@
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ from model_1rvq import PromptCondAudioDiffusion
5
+ from diffusers import DDIMScheduler, DDPMScheduler
6
+ import torchaudio
7
+ import librosa
8
+ import os
9
+ import math
10
+ import numpy as np
11
+ from tools.get_1dvae_large import get_model
12
+ import tools.torch_tools as torch_tools
13
+ from safetensors.torch import load_file
14
+
15
+ class Tango:
16
+ def __init__(self, \
17
+ model_path, \
18
+ vae_config="",
19
+ vae_model="",
20
+ layer_num=6, \
21
+ device="cuda:0"):
22
+
23
+ self.sample_rate = 48000
24
+ scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
25
+ self.device = device
26
+
27
+ self.vae = get_model(vae_config, vae_model)
28
+ self.vae = self.vae.to(device)
29
+ self.vae=self.vae.eval()
30
+ self.layer_num = layer_num
31
+
32
+ self.MAX_DURATION = 360
33
+ main_config = {
34
+ "num_channels":32,
35
+ "unet_model_name":None,
36
+ "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
37
+ "snr_gamma":None,
38
+ }
39
+ self.model = PromptCondAudioDiffusion(**main_config).to(device)
40
+ if model_path.endswith(".safetensors"):
41
+ main_weights = load_file(model_path)
42
+ else:
43
+ main_weights = torch.load(model_path, map_location=device)
44
+ self.model.load_state_dict(main_weights, strict=False)
45
+ print ("Successfully loaded checkpoint from:", model_path)
46
+
47
+ self.model.eval()
48
+ self.model.init_device_dtype(torch.device(device), torch.float32)
49
+ print("scaling factor: ", self.model.normfeat.std)
50
+
51
+ # self.scheduler = DDIMScheduler.from_pretrained( \
52
+ # scheduler_name, subfolder="scheduler")
53
+ # self.scheduler = DDPMScheduler.from_pretrained( \
54
+ # scheduler_name, subfolder="scheduler")
55
+ # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
56
+
57
+ # def sound2sound(self, orig_samples, lyric, st_et, batch_size=1, duration=40.96, steps=200, disable_progress=False,scenario = "start_seg"):
58
+ # """ Generate audio without condition. """
59
+ # with torch.no_grad():
60
+ # if(orig_samples.shape[-1]<int(duration*48000)+480):
61
+ # orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000+480)-orig_samples.shape[-1], \
62
+ # dtype=orig_samples.dtype, device=orig_samples.device)], -1)
63
+
64
+ # orig_samples = orig_samples.to(self.device)
65
+ # saved_samples = orig_samples[:,0:40*48000].clamp(-1,1)
66
+ # orig_samples = orig_samples[:,0:40*48000].clamp(-1,1)
67
+ # max_volume = orig_samples.abs().max(dim=-1)[0]
68
+ # orig_samples = orig_samples/max_volume.unsqueeze(-1)
69
+ # print("orig_samples.shape", orig_samples.shape)
70
+
71
+ # latent_length = int((st_et[1] - st_et[0]) * 48000) // 1920 + 1
72
+
73
+ # true_latents = self.vae.encode_audio(orig_samples).permute(0,2,1)
74
+
75
+ # print("true_latents.shape", true_latents.shape)
76
+ # latents = self.model.inference(orig_samples.repeat(batch_size, 1), [lyric, ]*batch_size, true_latents, latent_length, additional_feats=[], guidance_scale=1.5, num_steps = steps, disable_progress=disable_progress,layer=6, scenario = scenario)
77
+ # print("latents.shape", latents.shape)
78
+ # print("latent_length", latent_length)
79
+
80
+ # latents = latents[:,:,:latent_length]
81
+ # audio = self.vae.decode_audio(latents)
82
+ # print("audio.shape:",audio.shape)
83
+ # audio = torch.cat((audio, torch.zeros(audio.shape[0],audio.shape[1], 48000*40 - audio.shape[-1], dtype=audio.dtype, device=audio.device)), dim=-1)
84
+ # print("audio.shape:",audio.shape)
85
+ # # audio = audio.reshape(audio.shape[0]//2, 2, -1)
86
+ # # audio = torch.from_numpy(audio)
87
+
88
+ # if(saved_samples.shape[-1]<audio.shape[-1]):
89
+ # saved_samples = torch.cat([saved_samples, torch.zeros(saved_samples.shape[0], audio.shape[-1]-saved_samples.shape[-1], dtype=saved_samples.dtype, device=saved_samples.device)],-1)
90
+ # else:
91
+ # saved_samples = saved_samples[:,0:audio.shape[-1]]
92
+ # output = torch.cat([saved_samples.detach().cpu(),audio[0].detach().cpu()],0)
93
+ # return output
94
+
95
+ @torch.no_grad()
96
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
97
+ def sound2code(self, orig_samples, batch_size=3):
98
+ if(orig_samples.ndim == 2):
99
+ audios = orig_samples.unsqueeze(0).to(self.device)
100
+ elif(orig_samples.ndim == 3):
101
+ audios = orig_samples.to(self.device)
102
+ else:
103
+ assert orig_samples.ndim in (2,3), orig_samples.shape
104
+ audios = self.preprocess_audio(audios)
105
+ audios = audios.squeeze(0)
106
+ orig_length = audios.shape[-1]
107
+ min_samples = int(40 * self.sample_rate)
108
+ # 25 tokens per second of audio (40 s -> 1000 tokens)
109
+ output_len = int(orig_length / float(self.sample_rate) * 25) + 1
110
+ print("output_len: ", output_len)
111
+
112
+ while(audios.shape[-1] < min_samples):
113
+ audios = torch.cat([audios, audios], -1)
114
+ int_max_len=audios.shape[-1]//min_samples+1
115
+ audios = torch.cat([audios, audios], -1)
116
+ audios=audios[:,:int(int_max_len*(min_samples))]
117
+ codes_list=[]
118
+
119
+ audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
120
+
121
+ for audio_inx in range(0, audio_input.shape[0], batch_size):
122
+ # import pdb; pdb.set_trace()
123
+ codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
124
+ codes_list.append(torch.cat(codes, 1))
125
+ # print("codes_list",codes_list[0].shape)
126
+
127
+ codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
128
+ codes=codes[:,:,:output_len]
129
+
130
+ return codes
131
+
132
+ @torch.no_grad()
133
+ def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
134
+ codes = codes.to(self.device)
135
+
136
+ min_samples = int(duration * 25) # 40ms per frame
137
+ hop_samples = min_samples // 4 * 3
138
+ ovlp_samples = min_samples - hop_samples
139
+ hop_frames = hop_samples
140
+ ovlp_frames = ovlp_samples
141
+ first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
142
+ first_latent_length = 0
143
+ first_latent_codes_length = 0
144
+
145
+ if(isinstance(prompt, torch.Tensor)):
146
+ # prepare prompt
147
+ prompt = prompt.to(self.device)
148
+ if(prompt.ndim == 3):
149
+ assert prompt.shape[0] == 1, prompt.shape
150
+ prompt = prompt[0]
151
+ elif(prompt.ndim == 1):
152
+ prompt = prompt.unsqueeze(0).repeat(2,1)
153
+ elif(prompt.ndim == 2):
154
+ if(prompt.shape[0] == 1):
155
+ prompt = prompt.repeat(2,1)
156
+
157
+ if(prompt.shape[-1] < int(30 * self.sample_rate)):
158
+ # if less than 30s, just choose the first 10s
159
+ prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
160
+ else:
161
+ # else choose from 20.48s, which might include a verse or chorus
162
+ prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
163
+
164
+ true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
165
+ # print("true_latent.shape", true_latent.shape)
166
+ # print("first_latent.shape", first_latent.shape)
167
+ #true_latent.shape torch.Size([1, 250, 64])
168
+ # first_latent.shape torch.Size([1, 1000, 64])
169
+
170
+ first_latent[:,0:true_latent.shape[1],:] = true_latent
171
+ first_latent_length = true_latent.shape[1]
172
+ first_latent_codes = self.sound2code(prompt)
173
+ first_latent_codes_length = first_latent_codes.shape[-1]
174
+ codes = torch.cat([first_latent_codes, codes], -1)
175
+
176
+
177
+
178
+
179
+ codes_len= codes.shape[-1]
180
+ target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
181
+ # target_len = int(codes_len / 100 * 4 * self.sample_rate)
182
+ # code repeat
183
+ if(codes_len < min_samples):
184
+ while(codes.shape[-1] < min_samples):
185
+ codes = torch.cat([codes, codes], -1)
186
+ codes = codes[:,:,0:min_samples]
187
+ codes_len = codes.shape[-1]
188
+ if((codes_len - ovlp_samples) % hop_samples > 0):
189
+ len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
190
+ while(codes.shape[-1] < len_codes):
191
+ codes = torch.cat([codes, codes], -1)
192
+ codes = codes[:,:,0:len_codes]
193
+ latent_length = min_samples
194
+ latent_list = []
195
+ spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
196
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
197
+ for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
198
+ codes_input=[]
199
+ codes_input.append(codes[:,:,sinx:sinx+min_samples])
200
+ if(sinx == 0):
201
+ # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
202
+ incontext_length = first_latent_length
203
+ latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
204
+ latent_list.append(latents)
205
+ else:
206
+ # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
207
+ true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
208
+ print("true_latent.shape", true_latent.shape)
209
+ len_add_to_1000 = min_samples - true_latent.shape[-2]
210
+ # print("len_add_to_1000", len_add_to_1000)
211
+ # exit()
212
+ incontext_length = true_latent.shape[-2]
213
+ true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
214
+ latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
215
+ latent_list.append(latents)
216
+
217
+ latent_list = [l.float() for l in latent_list]
218
+ latent_list[0] = latent_list[0][:,:,first_latent_length:]
219
+ min_samples = int(min_samples * self.sample_rate // 1000 * 40)
220
+ hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
221
+ ovlp_samples = min_samples - hop_samples
222
+ with torch.no_grad():
223
+ output = None
224
+ for i in range(len(latent_list)):
225
+ latent = latent_list[i]
226
+ cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
227
+
228
+ if output is None:
229
+ output = cur_output
230
+ else:
231
+ ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
232
+ ov_win = torch.cat([ov_win, 1 - ov_win], -1)
233
+ print("output.shape", output.shape)
234
+ print("ov_win.shape", ov_win.shape)
235
+ output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
236
+ output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
237
+ output = output[:, 0:target_len]
238
+ return output
239
+
240
+ @torch.no_grad()
241
+ def preprocess_audio(self, input_audios, threshold=0.8):
242
+ assert len(input_audios.shape) == 3, input_audios.shape
243
+ nchan = input_audios.shape[1]
244
+ input_audios = input_audios.reshape(input_audios.shape[0], -1)
245
+ norm_value = torch.ones_like(input_audios[:,0])
246
+ max_volume = input_audios.abs().max(dim=-1)[0]
247
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
248
+ return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
249
+
250
+ @torch.no_grad()
251
+ def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
252
+ codes = self.sound2code(sound)
253
+ # print(codes.shape)
254
+ wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
255
+ # print(fname, wave.shape)
256
+ return wave
257
+
258
+ @torch.no_grad()
259
+ def sound2sound_vae(self, sound, prompt=None, steps=50, disable_progress=False):
260
+ min_samples = int(40 * 25) # 40ms per frame
261
+ hop_samples = min_samples // 4 * 3
262
+ ovlp_samples = min_samples - hop_samples
263
+ dur = 20
264
+
265
+ latent_list = []
266
+ for i in range(0, sound.shape[-1], dur*48000):
267
+ if(i+dur*2*48000 > sound.shape[-1]):
268
+ latent = self.vae.encode_audio(sound.cuda()[None,:,i:])
269
+ latent_list.append(latent)
+ break
270
+ else:
271
+ latent = self.vae.encode_audio(sound.cuda()[None,:,i:i+dur*48000])
272
+ latent_list.append(latent)
273
+
274
+ output = None
275
+ for i in range(len(latent_list)):
276
+ print(i)
277
+ latent = latent_list[i]
278
+ cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
279
+ if output is None:
280
+ output = cur_output
281
+ else:
282
+ output = torch.cat([output, cur_output], -1)
283
+ return output
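Note: `code2sound` above decodes overlapping 40 s windows and stitches them with a linear cross-fade over the overlap region. A minimal sketch of that overlap-add step, assuming `chunks` is a list of decoded `(channels, samples)` tensors that overlap by `ovlp` samples:

```python
import torch

def crossfade_concat(chunks, ovlp):
    """Stitch overlapping chunks with a linear fade over the shared ovlp samples."""
    fade_in = torch.linspace(0.0, 1.0, ovlp)
    fade_out = 1.0 - fade_in
    out = chunks[0].clone()
    for nxt in chunks[1:]:
        out[..., -ovlp:] = out[..., -ovlp:] * fade_out + nxt[..., :ovlp] * fade_in
        out = torch.cat([out, nxt[..., ovlp:]], dim=-1)
    return out
```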
codeclm/tokenizer/Flow1dVAE/generate_2rvq.py ADDED
@@ -0,0 +1,294 @@
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ from model_2rvq import PromptCondAudioDiffusion
5
+ from diffusers import DDIMScheduler, DDPMScheduler
6
+ import torchaudio
7
+ import librosa
8
+ import os
9
+ import math
10
+ import numpy as np
11
+ # from tools.get_mulan import get_mulan
12
+ from tools.get_1dvae_large import get_model
13
+ import tools.torch_tools as torch_tools
14
+ from safetensors.torch import load_file
15
+ from audio import AudioFile
16
+ import kaldiio
17
+
18
+ class Tango:
19
+ def __init__(self, \
20
+ model_path, \
21
+ layer_num=6, \
22
+ rvq_num=1, \
23
+ device="cuda:0"):
24
+
25
+ self.sample_rate = 48000
26
+ scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
27
+ self.device = device
28
+
29
+ self.vae = get_model()
30
+ self.vae = self.vae.to(device)
31
+ self.vae=self.vae.eval()
32
+ self.layer_num = layer_num
33
+
34
+ self.MAX_DURATION = 360
35
+ main_config = {
36
+ "num_channels":32,
37
+ "unet_model_name":None,
38
+ "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
39
+ "snr_gamma":None,
40
+ }
41
+ self.rvq_num = rvq_num
42
+ # print("rvq_num: ", self.rvq_num)
43
+ # exit()
44
+ self.model = PromptCondAudioDiffusion(**main_config).to(device)
45
+ if model_path.endswith(".safetensors"):
46
+ main_weights = load_file(model_path)
47
+ else:
48
+ main_weights = torch.load(model_path, map_location=device)
49
+ self.model.load_state_dict(main_weights, strict=False)
50
+ print ("Successfully loaded checkpoint from:", model_path)
51
+
52
+ self.model.eval()
53
+ self.model.init_device_dtype(torch.device(device), torch.float32)
54
+ print("scaling factor: ", self.model.normfeat.std)
55
+
56
+ # self.scheduler = DDIMScheduler.from_pretrained( \
57
+ # scheduler_name, subfolder="scheduler")
58
+ # self.scheduler = DDPMScheduler.from_pretrained( \
59
+ # scheduler_name, subfolder="scheduler")
60
+ print("Successfully loaded inference scheduler from {}".format(scheduler_name))
61
+
62
+
63
+
64
+ @torch.no_grad()
65
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
66
+ def sound2code(self, orig_samples, batch_size=8):
67
+ if(orig_samples.ndim == 2):
68
+ audios = orig_samples.unsqueeze(0).to(self.device)
69
+ elif(orig_samples.ndim == 3):
70
+ audios = orig_samples.to(self.device)
71
+ else:
72
+ assert orig_samples.ndim in (2,3), orig_samples.shape
73
+ audios = self.preprocess_audio(audios)
74
+ audios = audios.squeeze(0)
75
+ orig_length = audios.shape[-1]
76
+ min_samples = int(40 * self.sample_rate)
77
+ # 25 tokens per second of audio (40 s -> 1000 tokens)
78
+ output_len = int(orig_length / float(self.sample_rate) * 25) + 1
79
+ # print("output_len: ", output_len)
80
+
81
+ while(audios.shape[-1] < min_samples):
82
+ audios = torch.cat([audios, audios], -1)
83
+ int_max_len=audios.shape[-1]//min_samples+1
84
+ audios = torch.cat([audios, audios], -1)
85
+ audios=audios[:,:int(int_max_len*(min_samples))]
86
+ codes_list=[]
87
+
88
+ audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
89
+
90
+ for audio_inx in range(0, audio_input.shape[0], batch_size):
91
+ # import pdb; pdb.set_trace()
92
+ codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
93
+ # print("codes",codes[0].shape)
94
+
95
+ codes_list.append(torch.cat(codes, 1))
96
+ # print("codes_list",codes_list[0].shape)
97
+
98
+ codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
99
+ codes=codes[:,:,:output_len]
100
+
101
+ return codes
102
+
103
+ @torch.no_grad()
104
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
105
+ def sound2code_ds(self, orig_samples, ds, batch_size=8):
106
+ if(orig_samples.ndim == 2):
107
+ audios = orig_samples.unsqueeze(0).to(self.device)
108
+ elif(orig_samples.ndim == 3):
109
+ audios = orig_samples.to(self.device)
110
+ else:
111
+ assert orig_samples.ndim in (2,3), orig_samples.shape
112
+ audios = self.preprocess_audio(audios)
113
+ audios = audios.squeeze(0)
114
+ orig_length = audios.shape[-1]
115
+ min_samples = int(40 * self.sample_rate)
116
+ # 25 tokens per second of audio (40 s -> 1000 tokens)
117
+ output_len = int(orig_length / float(self.sample_rate) * 25) + 1
118
+ # print("output_len: ", output_len)
119
+
120
+ while(audios.shape[-1] < min_samples):
121
+ audios = torch.cat([audios, audios], -1)
122
+ int_max_len=audios.shape[-1]//min_samples+1
123
+ audios = torch.cat([audios, audios], -1)
124
+ audios=audios[:,:int(int_max_len*(min_samples))]
125
+ codes_list=[]
126
+
127
+ audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
128
+
129
+ for audio_inx in range(0, audio_input.shape[0], batch_size):
130
+ # import pdb; pdb.set_trace()
131
+ codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
132
+ # print("codes",codes[0].shape)
133
+
134
+ codes_list.append(torch.cat(codes, 1))
135
+ # print("codes_list",codes_list[0].shape)
136
+
137
+ codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
138
+ codes=codes[:,:,:output_len]
139
+
140
+ return codes
141
+
142
+ @torch.no_grad()
143
+ def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
144
+ codes = codes.to(self.device)
145
+
146
+ min_samples = duration * 25 # 40ms per frame
147
+ hop_samples = min_samples // 4 * 3
148
+ ovlp_samples = min_samples - hop_samples
149
+ hop_frames = hop_samples
150
+ ovlp_frames = ovlp_samples
151
+ first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
152
+ first_latent_length = 0
153
+ first_latent_codes_length = 0
154
+
155
+ if(isinstance(prompt, torch.Tensor)):
156
+ # prepare prompt
157
+ prompt = prompt.to(self.device)
158
+ if(prompt.ndim == 3):
159
+ assert prompt.shape[0] == 1, prompt.shape
160
+ prompt = prompt[0]
161
+ elif(prompt.ndim == 1):
162
+ prompt = prompt.unsqueeze(0).repeat(2,1)
163
+ elif(prompt.ndim == 2):
164
+ if(prompt.shape[0] == 1):
165
+ prompt = prompt.repeat(2,1)
166
+
167
+ if(prompt.shape[-1] < int(30 * self.sample_rate)):
168
+ # if less than 30s, just choose the first 10s
169
+ prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
170
+ else:
171
+ # else choose from 20.48s, which might include a verse or chorus
172
+ prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
173
+
174
+ true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
175
+ # print("true_latent.shape", true_latent.shape)
176
+ # print("first_latent.shape", first_latent.shape)
177
+ #true_latent.shape torch.Size([1, 250, 64])
178
+ # first_latent.shape torch.Size([1, 1000, 64])
179
+
180
+ first_latent[:,0:true_latent.shape[1],:] = true_latent
181
+ first_latent_length = true_latent.shape[1]
182
+ first_latent_codes = self.sound2code(prompt)
183
+ first_latent_codes_length = first_latent_codes.shape[-1]
184
+ codes = torch.cat([first_latent_codes, codes], -1)
185
+
186
+ codes_len= codes.shape[-1]
187
+ target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
188
+ # target_len = int(codes_len / 100 * 4 * self.sample_rate)
189
+ # code repeat
190
+ if(codes_len < min_samples):
191
+ while(codes.shape[-1] < min_samples):
192
+ codes = torch.cat([codes, codes], -1)
193
+ codes = codes[:,:,0:min_samples]
194
+ codes_len = codes.shape[-1]
195
+ if((codes_len - ovlp_samples) % hop_samples > 0):
196
+ len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
197
+ while(codes.shape[-1] < len_codes):
198
+ codes = torch.cat([codes, codes], -1)
199
+ codes = codes[:,:,0:len_codes]
200
+ latent_length = min_samples
201
+ latent_list = []
202
+ spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
203
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
204
+ for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
205
+ codes_input=[]
206
+ codes_input.append(codes[:,:,sinx:sinx+min_samples])
207
+ if(sinx == 0):
208
+ # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
209
+ incontext_length = first_latent_length
210
+ latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
211
+ latent_list.append(latents)
212
+ else:
213
+ # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
214
+ true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
215
+ print("true_latent.shape", true_latent.shape)
216
+ len_add_to_1000 = 1000 - true_latent.shape[-2]
217
+ # print("len_add_to_1000", len_add_to_1000)
218
+ # exit()
219
+ incontext_length = true_latent.shape[-2]
220
+ true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
221
+ latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
222
+ latent_list.append(latents)
223
+
224
+ latent_list = [l.float() for l in latent_list]
225
+ latent_list[0] = latent_list[0][:,:,first_latent_length:]
226
+ min_samples = int(min_samples * self.sample_rate // 1000 * 40)
227
+ hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
228
+ ovlp_samples = min_samples - hop_samples
229
+ with torch.no_grad():
230
+ output = None
231
+ for i in range(len(latent_list)):
232
+ latent = latent_list[i]
233
+ cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
234
+
235
+ if output is None:
236
+ output = cur_output
237
+ else:
238
+ ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
239
+ ov_win = torch.cat([ov_win, 1 - ov_win], -1)
240
+ print("output.shape", output.shape)
241
+ print("ov_win.shape", ov_win.shape)
242
+ output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
243
+ output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
244
+ output = output[:, 0:target_len]
245
+ return output
246
+
247
+ @torch.no_grad()
248
+ def preprocess_audio(self, input_audios, threshold=0.8):
249
+ assert len(input_audios.shape) == 3, input_audios.shape
250
+ nchan = input_audios.shape[1]
251
+ input_audios = input_audios.reshape(input_audios.shape[0], -1)
252
+ norm_value = torch.ones_like(input_audios[:,0])
253
+ max_volume = input_audios.abs().max(dim=-1)[0]
254
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
255
+ return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
256
+
257
+ @torch.no_grad()
258
+ def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
259
+ codes = self.sound2code(sound)
260
+ # print(codes.shape)
261
+ # exit()
262
+ wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
263
+ # print(fname, wave.shape)
264
+ return wave
265
+
266
+ def file2code(self, fname):
267
+ try:
268
+ orig_samples, fs = torchaudio.load(fname)
269
+ except:
270
+ af = AudioFile(fname)
271
+ orig_samples = af.read()
272
+ fs = af.samplerate()
273
+ orig_samples = orig_samples[0]
274
+ if(fs!=self.sample_rate):
275
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
276
+ fs = self.sample_rate
277
+ if orig_samples.shape[0] == 1:
278
+ orig_samples = torch.cat([orig_samples, orig_samples], 0)
279
+ return self.sound2code(orig_samples)
280
+
281
+ def file2code_ds(self, fname, ds):
282
+ try:
283
+ orig_samples, fs = torchaudio.load(fname)
284
+ except:
285
+ af = AudioFile(fname)
286
+ orig_samples = af.read()
287
+ fs = af.samplerate()
288
+ orig_samples = orig_samples[0]
289
+ if(fs!=self.sample_rate):
290
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
291
+ fs = self.sample_rate
292
+ if orig_samples.shape[0] == 1:
293
+ orig_samples = torch.cat([orig_samples, orig_samples], 0)
294
+ return self.sound2code_ds(orig_samples, ds)
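Note: `sound2code` above tiles the stereo waveform to a multiple of the 40 s window and reshapes it into a batch of windows before tokenization. A minimal sketch of that chunking step; it zero-pads the tail for simplicity, whereas the script itself repeats the audio to fill the last window:

```python
import torch
import torch.nn.functional as F

def chunk_stereo(audio, sample_rate=48_000, window_sec=40):
    """(2, N) stereo waveform -> (n_windows, 2, window) batch of fixed-length windows."""
    win = window_sec * sample_rate
    n_win = (audio.shape[-1] + win - 1) // win
    audio = F.pad(audio, (0, n_win * win - audio.shape[-1]))   # zero-pad the tail
    return audio.reshape(2, n_win, win).permute(1, 0, 2)
```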
codeclm/tokenizer/Flow1dVAE/generate_4rvq.py ADDED
@@ -0,0 +1,293 @@
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ from model_4rvq import PromptCondAudioDiffusion
5
+ from diffusers import DDIMScheduler, DDPMScheduler
6
+ import torchaudio
7
+ import librosa
8
+ import os
9
+ import math
10
+ import numpy as np
11
+ # from tools.get_mulan import get_mulan
12
+ from tools.get_1dvae_large import get_model
13
+ import tools.torch_tools as torch_tools
14
+ from safetensors.torch import load_file
15
+ from audio import AudioFile
16
+
17
+ class Tango:
18
+ def __init__(self, \
19
+ model_path, \
20
+ layer_num=6, \
21
+ rvq_num=1, \
22
+ device="cuda:0"):
23
+
24
+ self.sample_rate = 48000
25
+ scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
26
+ self.device = device
27
+
28
+ self.vae = get_model()
29
+ self.vae = self.vae.to(device)
30
+ self.vae=self.vae.eval()
31
+ self.layer_num = layer_num
32
+
33
+ self.MAX_DURATION = 360
34
+ main_config = {
35
+ "num_channels":32,
36
+ "unet_model_name":None,
37
+ "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
38
+ "snr_gamma":None,
39
+ }
40
+ self.rvq_num = rvq_num
41
+ # print("rvq_num: ", self.rvq_num)
42
+ # exit()
43
+ self.model = PromptCondAudioDiffusion(**main_config).to(device)
44
+ if model_path.endswith(".safetensors"):
45
+ main_weights = load_file(model_path)
46
+ else:
47
+ main_weights = torch.load(model_path, map_location=device)
48
+ self.model.load_state_dict(main_weights, strict=False)
49
+ print ("Successfully loaded checkpoint from:", model_path)
50
+
51
+ self.model.eval()
52
+ self.model.init_device_dtype(torch.device(device), torch.float32)
53
+ print("scaling factor: ", self.model.normfeat.std)
54
+
55
+ # self.scheduler = DDIMScheduler.from_pretrained( \
56
+ # scheduler_name, subfolder="scheduler")
57
+ # self.scheduler = DDPMScheduler.from_pretrained( \
58
+ # scheduler_name, subfolder="scheduler")
59
+ print("Successfully loaded inference scheduler from {}".format(scheduler_name))
60
+
61
+
62
+
63
+ @torch.no_grad()
64
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
65
+ def sound2code(self, orig_samples, batch_size=8):
66
+ if(orig_samples.ndim == 2):
67
+ audios = orig_samples.unsqueeze(0).to(self.device)
68
+ elif(orig_samples.ndim == 3):
69
+ audios = orig_samples.to(self.device)
70
+ else:
71
+ assert orig_samples.ndim in (2,3), orig_samples.shape
72
+ audios = self.preprocess_audio(audios)
73
+ audios = audios.squeeze(0)
74
+ orig_length = audios.shape[-1]
75
+ min_samples = int(40 * self.sample_rate)
76
+ # 25 tokens per second of audio (40 s -> 1000 tokens)
77
+ output_len = int(orig_length / float(self.sample_rate) * 25) + 1
78
+ # print("output_len: ", output_len)
79
+
80
+ while(audios.shape[-1] < min_samples):
81
+ audios = torch.cat([audios, audios], -1)
82
+ int_max_len=audios.shape[-1]//min_samples+1
83
+ audios = torch.cat([audios, audios], -1)
84
+ audios=audios[:,:int(int_max_len*(min_samples))]
85
+ codes_list=[]
86
+
87
+ audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
88
+
89
+ for audio_inx in range(0, audio_input.shape[0], batch_size):
90
+ # import pdb; pdb.set_trace()
91
+ codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
92
+ # print("codes",codes[0].shape)
93
+
94
+ codes_list.append(torch.cat(codes, 1))
95
+ # print("codes_list",codes_list[0].shape)
96
+
97
+ codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
98
+ codes=codes[:,:,:output_len]
99
+
100
+ return codes
101
+
102
+ @torch.no_grad()
103
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
104
+ def sound2code_ds(self, orig_samples, ds, batch_size=6):
105
+ if(orig_samples.ndim == 2):
106
+ audios = orig_samples.unsqueeze(0).to(self.device)
107
+ elif(orig_samples.ndim == 3):
108
+ audios = orig_samples.to(self.device)
109
+ else:
110
+ assert orig_samples.ndim in (2,3), orig_samples.shape
111
+ audios = self.preprocess_audio(audios)
112
+ audios = audios.squeeze(0)
113
+ orig_length = audios.shape[-1]
114
+ min_samples = int(40 * self.sample_rate)
115
+ # 25 tokens per second of audio (40 s -> 1000 tokens)
116
+ output_len = int(orig_length / float(self.sample_rate) * 25) + 1
117
+ # print("output_len: ", output_len)
118
+
119
+ while(audios.shape[-1] < min_samples):
120
+ audios = torch.cat([audios, audios], -1)
121
+ int_max_len=audios.shape[-1]//min_samples+1
122
+ audios = torch.cat([audios, audios], -1)
123
+ audios=audios[:,:int(int_max_len*(min_samples))]
124
+ codes_list=[]
125
+
126
+ audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
127
+
128
+ for audio_inx in range(0, audio_input.shape[0], batch_size):
129
+ # import pdb; pdb.set_trace()
130
+ codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
131
+ # print("codes",codes[0].shape)
132
+
133
+ codes_list.append(torch.cat(codes, 1))
134
+ # print("codes_list",codes_list[0].shape)
135
+
136
+ codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
137
+ codes=codes[:,:,:output_len]
138
+
139
+ return codes
140
+
141
+ @torch.no_grad()
142
+ def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
143
+ codes = codes.to(self.device)
144
+
145
+ min_samples = duration * 25 # 40ms per frame
146
+ hop_samples = min_samples // 4 * 3
147
+ ovlp_samples = min_samples - hop_samples
148
+ hop_frames = hop_samples
149
+ ovlp_frames = ovlp_samples
150
+ first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
151
+ first_latent_length = 0
152
+ first_latent_codes_length = 0
153
+
154
+ if(isinstance(prompt, torch.Tensor)):
155
+ # prepare prompt
156
+ prompt = prompt.to(self.device)
157
+ if(prompt.ndim == 3):
158
+ assert prompt.shape[0] == 1, prompt.shape
159
+ prompt = prompt[0]
160
+ elif(prompt.ndim == 1):
161
+ prompt = prompt.unsqueeze(0).repeat(2,1)
162
+ elif(prompt.ndim == 2):
163
+ if(prompt.shape[0] == 1):
164
+ prompt = prompt.repeat(2,1)
165
+
166
+ if(prompt.shape[-1] < int(30 * self.sample_rate)):
167
+ # if less than 30s, just choose the first 10s
168
+ prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
169
+ else:
170
+ # else choose from 20.48s, which might include a verse or chorus
171
+ prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
172
+
173
+ true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
174
+ # print("true_latent.shape", true_latent.shape)
175
+ # print("first_latent.shape", first_latent.shape)
176
+ #true_latent.shape torch.Size([1, 250, 64])
177
+ # first_latent.shape torch.Size([1, 1000, 64])
178
+
179
+ first_latent[:,0:true_latent.shape[1],:] = true_latent
180
+ first_latent_length = true_latent.shape[1]
181
+ first_latent_codes = self.sound2code(prompt)
182
+ first_latent_codes_length = first_latent_codes.shape[-1]
183
+ codes = torch.cat([first_latent_codes, codes], -1)
184
+
185
+ codes_len= codes.shape[-1]
186
+ target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
187
+ # target_len = int(codes_len / 100 * 4 * self.sample_rate)
188
+ # code repeat
189
+ if(codes_len < min_samples):
190
+ while(codes.shape[-1] < min_samples):
191
+ codes = torch.cat([codes, codes], -1)
192
+ codes = codes[:,:,0:min_samples]
193
+ codes_len = codes.shape[-1]
194
+ if((codes_len - ovlp_samples) % hop_samples > 0):
195
+ len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
196
+ while(codes.shape[-1] < len_codes):
197
+ codes = torch.cat([codes, codes], -1)
198
+ codes = codes[:,:,0:len_codes]
199
+ latent_length = min_samples
200
+ latent_list = []
201
+ spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
202
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
203
+ for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
204
+ codes_input=[]
205
+ codes_input.append(codes[:,:,sinx:sinx+min_samples])
206
+ if(sinx == 0):
207
+ # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
208
+ incontext_length = first_latent_length
209
+ latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
210
+ latent_list.append(latents)
211
+ else:
212
+ # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
213
+ true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
214
+ print("true_latent.shape", true_latent.shape)
215
+ len_add_to_1000 = 1000 - true_latent.shape[-2]
216
+ # print("len_add_to_1000", len_add_to_1000)
217
+ # exit()
218
+ incontext_length = true_latent.shape[-2]
219
+ true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
220
+ latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
221
+ latent_list.append(latents)
222
+
223
+ latent_list = [l.float() for l in latent_list]
224
+ latent_list[0] = latent_list[0][:,:,first_latent_length:]
225
+ min_samples = int(min_samples * self.sample_rate // 1000 * 40)
226
+ hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
227
+ ovlp_samples = min_samples - hop_samples
228
+ with torch.no_grad():
229
+ output = None
230
+ for i in range(len(latent_list)):
231
+ latent = latent_list[i]
232
+ cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
233
+
234
+ if output is None:
235
+ output = cur_output
236
+ else:
237
+ ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
238
+ ov_win = torch.cat([ov_win, 1 - ov_win], -1)
239
+ print("output.shape", output.shape)
240
+ print("ov_win.shape", ov_win.shape)
241
+ output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
242
+ output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
243
+ output = output[:, 0:target_len]
244
+ return output
245
+
246
+ @torch.no_grad()
247
+ def preprocess_audio(self, input_audios, threshold=0.8):
248
+ assert len(input_audios.shape) == 3, input_audios.shape
249
+ nchan = input_audios.shape[1]
250
+ input_audios = input_audios.reshape(input_audios.shape[0], -1)
251
+ norm_value = torch.ones_like(input_audios[:,0])
252
+ max_volume = input_audios.abs().max(dim=-1)[0]
253
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
254
+ return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
255
+
256
+ @torch.no_grad()
257
+ def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
258
+ codes = self.sound2code(sound)
259
+ # print(codes.shape)
260
+ # exit()
261
+ wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
262
+ # print(fname, wave.shape)
263
+ return wave
264
+
265
+ def file2code(self, fname):
266
+ try:
267
+ orig_samples, fs = torchaudio.load(fname)
268
+ except:
269
+ af = AudioFile(fname)
270
+ orig_samples = af.read()
271
+ fs = af.samplerate()
272
+ orig_samples = orig_samples[0]
273
+ if(fs!=self.sample_rate):
274
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
275
+ fs = self.sample_rate
276
+ if orig_samples.shape[0] == 1:
277
+ orig_samples = torch.cat([orig_samples, orig_samples], 0)
278
+ return self.sound2code(orig_samples)
279
+
280
+ def file2code_ds(self, fname, ds):
281
+ try:
282
+ orig_samples, fs = torchaudio.load(fname)
283
+ except:
284
+ af = AudioFile(fname)
285
+ orig_samples = af.read()
286
+ fs = af.samplerate()
287
+ orig_samples = orig_samples[0]
288
+ if(fs!=self.sample_rate):
289
+ orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
290
+ fs = self.sample_rate
291
+ if orig_samples.shape[0] == 1:
292
+ orig_samples = torch.cat([orig_samples, orig_samples], 0)
293
+ return self.sound2code_ds(orig_samples, ds)
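Note: `preprocess_audio` above only attenuates inputs whose peak exceeds a threshold and never amplifies quiet ones. A minimal sketch of that peak-limiting normalization, assuming a `(B, C, N)` batch:

```python
import torch

def peak_limit(batch, threshold=0.8):
    """Scale each item down so its peak is at most `threshold`; leave quieter items as-is."""
    peaks = batch.flatten(1).abs().max(dim=-1).values           # (B,)
    scale = torch.where(peaks > threshold, peaks / threshold, torch.ones_like(peaks))
    return batch / scale.view(-1, 1, 1)
```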
codeclm/tokenizer/Flow1dVAE/generate_septoken.py ADDED
@@ -0,0 +1,302 @@
1
+ import json
2
+ import torch
3
+ from tqdm import tqdm
4
+ from model_septoken import PromptCondAudioDiffusion
5
+ from diffusers import DDIMScheduler, DDPMScheduler
6
+ import torchaudio
7
+ import librosa
8
+ import os
9
+ import math
10
+ import numpy as np
11
+ # from tools.get_mulan import get_mulan
12
+ from tools.get_1dvae_large import get_model
13
+ import tools.torch_tools as torch_tools
14
+ from safetensors.torch import load_file
15
+ from third_party.demucs.models.pretrained import get_model_from_yaml
16
+ from filelock import FileLock
17
+ import kaldiio
18
+ # os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
19
+ class Separator:
20
+ def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
21
+ if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
22
+ self.device = torch.device(f"cuda:{gpu_id}")
23
+ else:
24
+ self.device = torch.device("cpu")
25
+ self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
26
+
27
+ def init_demucs_model(self, model_path, config_path):
28
+ model = get_model_from_yaml(config_path, model_path)
29
+ model.to(self.device)
30
+ model.eval()
31
+ return model
32
+
33
+ def load_audio(self, f):
34
+ a, fs = torchaudio.load(f)
35
+ if (fs != 48000):
36
+ a = torchaudio.functional.resample(a, fs, 48000)
37
+ # if a.shape[-1] >= 48000*10:
38
+ # a = a[..., :48000*10]
39
+ # else:
40
+ # a = torch.cat([a, a], -1)
41
+ # return a[:, 0:48000*10]
42
+ return a
43
+
44
+ def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"):
45
+ name, _ = os.path.splitext(os.path.split(audio_path)[-1])
46
+ output_paths = []
47
+ # lock_path = os.path.join(output_dir, f"{name}.lock")
48
+ # with FileLock(lock_path): # 加一个避免多卡访问时死锁
49
+ for stem in self.demucs_model.sources:
50
+ output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
51
+ if os.path.exists(output_path):
52
+ output_paths.append(output_path)
53
+ if len(output_paths) == 1: # 4
54
+ # drums_path, bass_path, other_path, vocal_path = output_paths
55
+ vocal_path = output_paths[0]
56
+ else:
57
+ lock_path = os.path.join(output_dir, f"{name}_separate.lock")
58
+ with FileLock(lock_path):
59
+ drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
60
+ full_audio = self.load_audio(audio_path)
61
+ vocal_audio = self.load_audio(vocal_path)
62
+ minlen = min(full_audio.shape[-1], vocal_audio.shape[-1])
63
+ # bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen]
64
+ bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path)
65
+ for path in [drums_path, bass_path, other_path, vocal_path]:
66
+ os.remove(path)
67
+ return full_audio, vocal_audio, bgm_audio
68
+
69
+ class Tango:
70
+ def __init__(self, \
71
+ model_path, \
72
+ vae_config,
73
+ vae_model,
74
+ layer_vocal=7,\
75
+ layer_bgm=3,\
76
+ device="cuda:0"):
77
+
78
+ self.sample_rate = 48000
79
+ scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
80
+ self.device = device
81
+
82
+ self.vae = get_model(vae_config, vae_model)
83
+ self.vae = self.vae.to(device)
84
+ self.vae=self.vae.eval()
85
+ self.layer_vocal=layer_vocal
86
+ self.layer_bgm=layer_bgm
87
+
88
+ self.MAX_DURATION = 360
89
+ main_config = {
90
+ "num_channels":32,
91
+ "unet_model_name":None,
92
+ "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
93
+ "snr_gamma":None,
94
+ }
95
+ self.model = PromptCondAudioDiffusion(**main_config).to(device)
96
+ if model_path.endswith(".safetensors"):
97
+ main_weights = load_file(model_path)
98
+ else:
99
+ main_weights = torch.load(model_path, map_location=device)
100
+ self.model.load_state_dict(main_weights, strict=False)
101
+ print ("Successfully loaded checkpoint from:", model_path)
102
+
103
+ self.model.eval()
104
+ self.model.init_device_dtype(torch.device(device), torch.float32)
105
+ print("scaling factor: ", self.model.normfeat.std)
106
+
107
+ # self.scheduler = DDIMScheduler.from_pretrained( \
108
+ # scheduler_name, subfolder="scheduler")
109
+ # self.scheduler = DDPMScheduler.from_pretrained( \
110
+ # scheduler_name, subfolder="scheduler")
111
+ print("Successfully loaded inference scheduler from {}".format(scheduler_name))
112
+
113
+
114
+ @torch.no_grad()
115
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
116
+ def sound2code(self, orig_vocal, orig_bgm, batch_size=8):
117
+ if(orig_vocal.ndim == 2):
118
+ audios_vocal = orig_vocal.unsqueeze(0).to(self.device)
119
+ elif(orig_vocal.ndim == 3):
120
+ audios_vocal = orig_vocal.to(self.device)
121
+ else:
122
+ assert orig_vocal.ndim in (2,3), orig_vocal.shape
123
+
124
+ if(orig_bgm.ndim == 2):
125
+ audios_bgm = orig_bgm.unsqueeze(0).to(self.device)
126
+ elif(orig_bgm.ndim == 3):
127
+ audios_bgm = orig_bgm.to(self.device)
128
+ else:
129
+ assert orig_bgm.ndim in (2,3), orig_bgm.shape
130
+
131
+
132
+ audios_vocal = self.preprocess_audio(audios_vocal)
133
+ audios_vocal = audios_vocal.squeeze(0)
134
+ audios_bgm = self.preprocess_audio(audios_bgm)
135
+ audios_bgm = audios_bgm.squeeze(0)
136
+ if audios_vocal.shape[-1] > audios_bgm.shape[-1]:
137
+ audios_vocal = audios_vocal[:,:audios_bgm.shape[-1]]
138
+ else:
139
+ audios_bgm = audios_bgm[:,:audios_vocal.shape[-1]]
140
+
141
+
142
+ orig_length = audios_vocal.shape[-1]
143
+ min_samples = int(40 * self.sample_rate)
144
+ # 25 tokens per second of audio (40 s -> 1000 tokens)
145
+ output_len = int(orig_length / float(self.sample_rate) * 25) + 1
146
+
147
+ while(audios_vocal.shape[-1] < min_samples):
148
+ audios_vocal = torch.cat([audios_vocal, audios_vocal], -1)
149
+ audios_bgm = torch.cat([audios_bgm, audios_bgm], -1)
150
+ int_max_len=audios_vocal.shape[-1]//min_samples+1
151
+ audios_vocal = torch.cat([audios_vocal, audios_vocal], -1)
152
+ audios_bgm = torch.cat([audios_bgm, audios_bgm], -1)
153
+ audios_vocal=audios_vocal[:,:int(int_max_len*(min_samples))]
154
+ audios_bgm=audios_bgm[:,:int(int_max_len*(min_samples))]
155
+ codes_vocal_list=[]
156
+ codes_bgm_list=[]
157
+
158
+
159
+
160
+ audio_vocal_input = audios_vocal.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
161
+ audio_bgm_input = audios_bgm.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
162
+
163
+ for audio_inx in range(0, audio_vocal_input.shape[0], batch_size):
164
+ [codes_vocal,codes_bgm], _, spk_embeds = self.model.fetch_codes_batch((audio_vocal_input[audio_inx:audio_inx+batch_size]), (audio_bgm_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer_vocal=self.layer_vocal,layer_bgm=self.layer_bgm)
165
+ codes_vocal_list.append(codes_vocal)
166
+ codes_bgm_list.append(codes_bgm)
167
+
168
+ codes_vocal = torch.cat(codes_vocal_list, 0).permute(1,0,2).reshape(1, -1)[None]
169
+ codes_bgm = torch.cat(codes_bgm_list, 0).permute(1,0,2).reshape(1, -1)[None]
170
+ codes_vocal=codes_vocal[:,:,:output_len]
171
+ codes_bgm=codes_bgm[:,:,:output_len]
172
+
173
+ return codes_vocal, codes_bgm
174
+
175
+ @torch.no_grad()
176
+ def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
177
+ codes_vocal,codes_bgm = codes
178
+ codes_vocal = codes_vocal.to(self.device)
179
+ codes_bgm = codes_bgm.to(self.device)
180
+
181
+ min_samples = duration * 25 # 40ms per frame
182
+ hop_samples = min_samples // 4 * 3
183
+ ovlp_samples = min_samples - hop_samples
184
+ hop_frames = hop_samples
185
+ ovlp_frames = ovlp_samples
186
+ first_latent = torch.randn(codes_vocal.shape[0], min_samples, 64).to(self.device)
187
+ first_latent_length = 0
188
+ first_latent_codes_length = 0
189
+
190
+
191
+ if(isinstance(prompt_vocal, torch.Tensor)):
192
+ # prepare prompt
193
+ prompt_vocal = prompt_vocal.to(self.device)
194
+ prompt_bgm = prompt_bgm.to(self.device)
195
+ if(prompt_vocal.ndim == 3):
196
+ assert prompt_vocal.shape[0] == 1, prompt_vocal.shape
197
+ prompt_vocal = prompt_vocal[0]
198
+ prompt_bgm = prompt_bgm[0]
199
+ elif(prompt_vocal.ndim == 1):
200
+ prompt_vocal = prompt_vocal.unsqueeze(0).repeat(2,1)
201
+ prompt_bgm = prompt_bgm.unsqueeze(0).repeat(2,1)
202
+ elif(prompt_vocal.ndim == 2):
203
+ if(prompt_vocal.shape[0] == 1):
204
+ prompt_vocal = prompt_vocal.repeat(2,1)
205
+ prompt_bgm = prompt_bgm.repeat(2,1)
206
+
207
+ if(prompt_vocal.shape[-1] < int(30 * self.sample_rate)):
208
+ # if less than 30s, just choose the first 10s
209
+ prompt_vocal = prompt_vocal[:,:int(10*self.sample_rate)] # limit max length to 10.24
210
+ prompt_bgm = prompt_bgm[:,:int(10*self.sample_rate)] # limit max length to 10.24
211
+ else:
212
+ # else start from 20 s onward, which is more likely to include a verse or chorus
213
+ prompt_vocal = prompt_vocal[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
214
+ prompt_bgm = prompt_bgm[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
215
+
216
+ true_latent = self.vae.encode_audio(prompt_vocal+prompt_bgm).permute(0,2,1)
217
+
218
+ first_latent[:,0:true_latent.shape[1],:] = true_latent
219
+ first_latent_length = true_latent.shape[1]
220
+ first_latent_codes = self.sound2code(prompt_vocal, prompt_bgm)
221
+ first_latent_codes_vocal = first_latent_codes[0]
222
+ first_latent_codes_bgm = first_latent_codes[1]
223
+ first_latent_codes_length = first_latent_codes_vocal.shape[-1]
224
+ codes_vocal = torch.cat([first_latent_codes_vocal, codes_vocal], -1)
225
+ codes_bgm = torch.cat([first_latent_codes_bgm, codes_bgm], -1)
226
+
227
+
228
+ codes_len= codes_vocal.shape[-1]
229
+ target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
230
+ # target_len = int(codes_len / 100 * 4 * self.sample_rate)
231
+ # code repeat
232
+ if(codes_len < min_samples):
233
+ while(codes_vocal.shape[-1] < min_samples):
234
+ codes_vocal = torch.cat([codes_vocal, codes_vocal], -1)
235
+ codes_bgm = torch.cat([codes_bgm, codes_bgm], -1)
236
+
237
+ codes_vocal = codes_vocal[:,:,0:min_samples]
238
+ codes_bgm = codes_bgm[:,:,0:min_samples]
239
+ codes_len = codes_vocal.shape[-1]
240
+ if((codes_len - ovlp_samples) % hop_samples > 0):
241
+ len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
242
+ while(codes_vocal.shape[-1] < len_codes):
243
+ codes_vocal = torch.cat([codes_vocal, codes_vocal], -1)
244
+ codes_bgm = torch.cat([codes_bgm, codes_bgm], -1)
245
+ codes_vocal = codes_vocal[:,:,0:len_codes]
246
+ codes_bgm = codes_bgm[:,:,0:len_codes]
247
+ latent_length = min_samples
248
+ latent_list = []
249
+ spk_embeds = torch.zeros([1, 32, 1, 32], device=codes_vocal.device)
250
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
251
+ for sinx in range(0, codes_vocal.shape[-1]-hop_samples, hop_samples):
252
+ codes_vocal_input=codes_vocal[:,:,sinx:sinx+min_samples]
253
+ codes_bgm_input=codes_bgm[:,:,sinx:sinx+min_samples]
254
+ if(sinx == 0):
255
+ incontext_length = first_latent_length
256
+ latents = self.model.inference_codes([codes_vocal_input,codes_bgm_input], spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
257
+ latent_list.append(latents)
258
+ else:
259
+ true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
260
+ len_add_to_1000 = min_samples - true_latent.shape[-2]
261
+ incontext_length = true_latent.shape[-2]
262
+ true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
263
+ latents = self.model.inference_codes([codes_vocal_input,codes_bgm_input], spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
264
+ latent_list.append(latents)
265
+
266
+ latent_list = [l.float() for l in latent_list]
267
+ latent_list[0] = latent_list[0][:,:,first_latent_length:]
268
+ min_samples = int(min_samples * self.sample_rate // 1000 * 40)
269
+ hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
270
+ ovlp_samples = min_samples - hop_samples
271
+ with torch.no_grad():
272
+ output = None
273
+ for i in range(len(latent_list)):
274
+ latent = latent_list[i]
275
+ cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
276
+
277
+ if output is None:
278
+ output = cur_output
279
+ else:
280
+ ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
281
+ ov_win = torch.cat([ov_win, 1 - ov_win], -1)
282
+ output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
283
+ output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
284
+ output = output[:, 0:target_len]
285
+ return output
286
+
287
+ @torch.no_grad()
288
+ def preprocess_audio(self, input_audios_vocal, threshold=0.8):
289
+ assert len(input_audios_vocal.shape) == 3, input_audios_vocal.shape
290
+ nchan = input_audios_vocal.shape[1]
291
+ input_audios_vocal = input_audios_vocal.reshape(input_audios_vocal.shape[0], -1)
292
+ norm_value = torch.ones_like(input_audios_vocal[:,0])
293
+ max_volume = input_audios_vocal.abs().max(dim=-1)[0]
294
+ norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
295
+ return input_audios_vocal.reshape(input_audios_vocal.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
296
+
297
+ @torch.no_grad()
298
+ def sound2sound(self, orig_vocal,orig_bgm, prompt_vocal=None,prompt_bgm=None, steps=50, disable_progress=False):
299
+ codes_vocal, codes_bgm = self.sound2code(orig_vocal,orig_bgm)
300
+ codes=[codes_vocal, codes_bgm]
301
+ wave = self.code2sound(codes, prompt_vocal,prompt_bgm, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
302
+ return wave
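A minimal usage sketch of the vocal/BGM round trip implemented above. This is hedged: the class name `VocalBgmTokenizer`, the constructor arguments, and the file paths are placeholders for illustration and are not defined in this diff.

import torchaudio

tok = VocalBgmTokenizer(model_path="model_septoken.safetensors", device="cuda")  # hypothetical name and args
vocal, sr = torchaudio.load("vocal.flac")   # stereo [2, T], assumed already at tok.sample_rate (hypothetical file)
bgm, _ = torchaudio.load("bgm.flac")        # stereo [2, T], hypothetical file

codes_vocal, codes_bgm = tok.sound2code(vocal, bgm)             # two token streams, 25 tokens per second each
wave = tok.code2sound((codes_vocal, codes_bgm), num_steps=20)   # decoded stereo mix, returned on CPU
torchaudio.save("reconstruction.flac", wave, tok.sample_rate)

The sound2sound method above is exactly this encode-then-decode pipeline in a single call.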
codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py ADDED
@@ -0,0 +1,1278 @@
1
+ from torch.utils.data import Dataset
2
+ from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List, Union
3
+ from beartype import beartype
4
+ from beartype.door import is_bearable
5
+ import random
6
+ import pandas as pd
7
+ import os
8
+ from torchaudio.functional import resample
9
+ import torch
10
+ import typing as tp
11
+ from pathlib import Path
12
+ import torchaudio as ta
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import json
16
+ import yaml
17
+ import torchaudio
18
+ import math
19
+ import re
20
+ from loguru import logger
21
+ import ffmpeg
22
+
23
+ class Read_and_PadCrop_Normalized_T(torch.nn.Module):
24
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
25
+
26
+ super().__init__()
27
+
28
+ self.n_samples = n_samples
29
+ self.sample_rate = sample_rate
30
+ self.randomize = randomize
31
+
32
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
33
+ if self.n_samples < 0: # a negative value means: do not clip
34
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
35
+ t_start = 0.
36
+ t_end = 1.0
37
+ offset = 0
38
+ else:
39
+ if(duration<(float(self.n_samples)/self.sample_rate+1)):
40
+ # print(duration,(float(self.n_samples)/self.sample_rate+1))
41
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
42
+ t_start = 0.
43
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
44
+ offset = 0
45
+ # print('c1:',chunk.shape)
46
+ else:
47
+ offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
48
+ t_start = offset / float(cur_sample_rate) / duration
49
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
50
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
51
+ # print('offset:',offset)
52
+ # print('c0:',chunk.shape)
53
+ # Pad with silence if necessary.
54
+ if(chunk.shape[0]>1):
55
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
56
+ else:
57
+ chunk = chunk[[0],:].float()
58
+ if(cur_sample_rate!=self.sample_rate):
59
+ # print('a:',cur_sample_rate,chunk.shape)
60
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
61
+ # print('b:',self.sample_rate,chunk.shape)
62
+
63
+ if self.n_samples > 0:
64
+ if chunk.shape[-1] < self.n_samples:
65
+ chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
66
+ else:
67
+ chunk = chunk[:,0:self.n_samples]
68
+ seconds_start = math.floor(offset / cur_sample_rate)
69
+ seconds_total = math.floor(duration)
70
+
71
+ return (
72
+ chunk,
73
+ t_start,
74
+ t_end,
75
+ seconds_start,
76
+ seconds_total
77
+ )
78
+
79
+ class Read_and_PadCrop_Normalized_T_Avoid_Watermark(torch.nn.Module):
80
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, w_start = 0, w_interval = 11.3):
81
+
82
+ super().__init__()
83
+
84
+ self.n_samples = n_samples
85
+ self.sample_rate = sample_rate
86
+ self.randomize = randomize
87
+
88
+ self.w_start = w_start
89
+ self.w_interval = w_interval
90
+
91
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
92
+ if self.n_samples < 0: # a negative value means: do not clip
93
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
94
+ t_start = 0.
95
+ t_end = 1.0
96
+ offset = 0
97
+ else:
98
+ if(duration<(float(self.n_samples)/self.sample_rate+1)):
99
+ # print(duration,(float(self.n_samples)/self.sample_rate+1))
100
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
101
+ t_start = 0.
102
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
103
+ offset = 0
104
+ # print('c1:',chunk.shape)
105
+ else:
106
+ n_offset_option = (duration - self.w_start) // self.w_interval
107
+ if n_offset_option <= 1:
108
+ offset = 0
109
+ else:
110
+ offset = int((random.randint(0,n_offset_option-1) * self.w_interval + self.w_start) * cur_sample_rate)
111
+ # offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
112
+ t_start = offset / float(cur_sample_rate) / duration
113
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
114
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
115
+ # print('offset:',offset)
116
+ # print('c0:',chunk.shape)
117
+ # Pad with silence if necessary.
118
+ if(chunk.shape[0]>1):
119
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
120
+ else:
121
+ chunk = chunk[[0],:].float()
122
+ if(cur_sample_rate!=self.sample_rate):
123
+ # print('a:',cur_sample_rate,chunk.shape)
124
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
125
+ # print('b:',self.sample_rate,chunk.shape)
126
+
127
+ if self.n_samples > 0:
128
+ if chunk.shape[-1] < self.n_samples:
129
+ chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
130
+ else:
131
+ chunk = chunk[:,0:self.n_samples]
132
+ seconds_start = math.floor(offset / cur_sample_rate)
133
+ seconds_total = math.floor(duration)
134
+
135
+ return (
136
+ chunk,
137
+ t_start,
138
+ t_end,
139
+ seconds_start,
140
+ seconds_total
141
+ )
142
+
143
+ USE_DUMMY_AUDIO = False # when testing the code, set this to True to skip reading real data and return generated silent audio instead
144
+ if USE_DUMMY_AUDIO:
145
+ logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
146
+
147
+ class SafeAudioReader:
148
+ """
149
+ This class is an adaptor for Read_and_PadCrop_Normalized_T that makes reading audio data safe.
150
+ """
151
+ def __init__(self,
152
+ duration: float, # length of the returned audio, in seconds
153
+ sample_rate: int, # sample rate of the returned audio; resampled if it differs from the source file
154
+ randomize: bool = True,
155
+ use_avoid_watermark_policy = False,
156
+ ):
157
+ self.n_samples = int(sample_rate * duration)
158
+ self.reader = (
159
+ Read_and_PadCrop_Normalized_T_Avoid_Watermark if use_avoid_watermark_policy \
160
+ else Read_and_PadCrop_Normalized_T
161
+ )(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
162
+
163
+ #NOTE: this is the core function -- every dataset reads its audio through this call!
164
+ def __call__(self,
165
+ filepath: os.PathLike, # path to the audio file
166
+ origin_sample_rate: Optional[int] = None, # actual sample rate read from the json metadata; probed from the file header if not given
167
+ origin_duration: float = None, # actual duration read from the json metadata; probed from the file header if not given
168
+ ) -> torch.Tensor:
169
+ if USE_DUMMY_AUDIO:
170
+ wav = torch.zeros(self.n_samples, dtype=torch.float32)
171
+ return wav
172
+ try:
173
+ if origin_sample_rate is None or origin_duration is None:
174
+ # audio_info = torchaudio.info(filepath)
175
+ # origin_sample_rate = audio_info.sample_rate
176
+ # origin_duration = audio_info.num_frames / origin_sample_rate
177
+ info = ffmpeg.probe(filepath)
178
+ origin_duration = float(info['format']['duration'])
179
+ origin_sample_rate = int(info['streams'][0]['sample_rate'])
180
+ wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
181
+ wav = wav.squeeze_(0)
182
+ except Exception as e:
183
+ logger.error(f"Error reading {filepath}: {e}")
184
+ wav = torch.zeros(self.n_samples, dtype=torch.float32)
185
+ return wav
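A small usage sketch of SafeAudioReader; the path and parameter values are illustrative assumptions.

reader = SafeAudioReader(duration=10.0, sample_rate=24000)   # 10-second clips resampled to 24 kHz
wav = reader("/data/music/song.mp3")                         # 1-D tensor of 240000 samples
# On a read/probe failure the call logs the error and returns silence of the same length,
# so a single corrupt file cannot crash a DataLoader worker.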
186
+
187
+
188
+ class PromptTemplate:
189
+ def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
190
+ self.template_text = template_text
191
+ self.tag_map = tag_map
192
+ self.lang = lang
193
+
194
+ @property
195
+ def tags(self):
196
+ return tuple(self.tag_map.keys())
197
+
198
+ def apply(self, **kwargs):
199
+ for tag in list(kwargs.keys()):
200
+ if kwargs[tag] == '':
201
+ kwargs.pop(tag)
202
+ for tag in self.tags:
203
+ if tag in kwargs:
204
+ kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
205
+ else:
206
+ kwargs[tag] = ''
207
+ prompt = self.template_text.format(**kwargs)
208
+
209
+ return self.beautify(prompt)
210
+
211
+ def beautify(self, text):
212
+ if self.lang == 'en':
213
+ return self._beautify_en(text)
214
+ elif self.lang == 'zh':
215
+ return self._beautify_zh(text)
216
+ else:
217
+ raise ValueError(f'Unknown language {self.lang}')
218
+
219
+ @staticmethod
220
+ def _beautify_en(text):
221
+ # no continuous commas without content between them
222
+ text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
223
+ # no continuous whitespace
224
+ text = re.sub(r'\s+', ' ', text)
225
+ # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
226
+ text = re.sub(r'\s+,', r',', text)
227
+ text = re.sub(r',\s+', r', ', text)
228
+ # no whitespace before the full stop
229
+ text = re.sub(r'\s+\.', r'.', text)
230
+ # strip whitespace, comma, and replace ',.'
231
+ text = text.strip(' ,')
232
+ text = text.replace(',.', '.')
233
+ return text
234
+
235
+ @staticmethod
236
+ def _beautify_zh(text):
237
+ # no continuous commas without content between them
238
+ text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
239
+ text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
240
+ # assume there should be NO whitespace in Chinese
241
+ text = re.sub(r'\s+', r'', text)
242
+ # strip whitespace, comma, and replace ',。'
243
+ text = text.strip(', 、')
244
+ text = text.replace(',。', '。')
245
+ return text
246
+
247
+ def __repr__(self):
248
+ return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
249
+
250
+ __str__ = __repr__
251
+
252
+ def parse_prompt_template(prompt_template_text, lang='en'):
253
+ span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
254
+ tag_pattern = re.compile(r'{.+?}', re.DOTALL)
255
+
256
+ template_text = prompt_template_text.strip()
257
+ span_texts = span_pattern.findall(prompt_template_text)
258
+ tag_map = {}
259
+ for span_text in span_texts:
260
+ tag = tag_pattern.findall(span_text)[0].strip('{}')
261
+ tag_map[tag] = span_text
262
+ template_text = template_text.replace(span_text, '{'+tag+'}')
263
+
264
+ return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
265
+
266
+ def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
267
+ with open(path, 'r') as f:
268
+ lines = f.readlines()
269
+ cnt = 0
270
+ pts = []
271
+ for line in lines:
272
+ pt = parse_prompt_template(line, lang=lang)
273
+ cnt += 1
274
+ if len(pt.tags) < num:
275
+ logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
276
+ pts.append(pt)
277
+
278
+ return pts
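To make the bracket/brace template syntax concrete, a hedged example of parse_prompt_template and PromptTemplate.apply; the template line is made up for illustration, not one of the project's real templates.

pt = parse_prompt_template("A track [in the {genre} style][ with a {mood} mood].")
pt.tags                              # ('genre', 'mood')
pt.apply(genre="jazz", mood="calm")  # 'A track in the jazz style with a calm mood.'
pt.apply(genre="jazz", mood="")      # an empty tag drops its whole span: 'A track in the jazz style.'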
279
+
280
+
281
+ def get_base_dir_file(key: os.PathLike):
282
+ base = os.path.basename(key)
283
+ dirname = os.path.basename(os.path.dirname(key))
284
+ return os.path.join(dirname, base)
285
+
286
+ def read_jsonlike(path: os.PathLike):
287
+ #json or jsonl
288
+ if str(path).endswith(".json"):
289
+ with open(path, 'r', encoding='utf8') as f:
290
+ data = json.load(f)
291
+ return data
292
+ elif str(path).endswith(".jsonl"):
293
+ with open(path, 'r', encoding='utf8') as f:
294
+ data = [json.loads(line) for line in f.readlines()]
295
+ return data
296
+ else:
297
+ raise ValueError("Unknown file format")
298
+
299
+ dist_prob_map = {
300
+ 1: (1.0,),
301
+ 2: (0.5, 0.5),
302
+ 3: (0.3, 0.4, 0.3),
303
+ 4: (0.2, 0.3, 0.3, 0.2),
304
+ 5: (0.2, 0.2, 0.3, 0.2, 0.1),
305
+ 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
306
+ 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
307
+ 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
308
+ 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
309
+ 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
310
+ }
311
+
312
+ '''
313
+ # alternative scheme biased toward shorter prompts
314
+ dist_prob_map = {
315
+ 1: (1.0,),
316
+ 2: (0.7, 0.3),
317
+ 3: (0.7, 0.2, 0.1),
318
+ 4: (0.6, 0.2, 0.1, 0.1),
319
+ 5: (0.6, 0.2, 0.1, 0.05, 0.05),
320
+ 6: (0.6, 0.15, 0.1, 0.05, 0.05, 0.05),
321
+ 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
322
+ 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
323
+ 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
324
+ 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
325
+ }
326
+ '''
327
+
328
+ # alternative scheme that always uses all tags
329
+ # dist_prob_map = {
330
+ # 1: (1.0,),
331
+ # 2: (0, 1.0),
332
+ # 3: (0, 0, 1.0),
333
+ # 4: (0, 0, 0, 1.0),
334
+ # 5: (0, 0, 0, 0, 1.0),
335
+ # 6: (0, 0, 0, 0, 0, 1.0),
336
+ # 7: (0, 0, 0, 0, 0, 0, 1.0),
337
+ # 8: (0, 0, 0, 0, 0, 0, 0, 1.0),
338
+ # 9: (0, 0, 0, 0, 0, 0, 0, 0, 1.0),
339
+ # 10: (0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0)
340
+ # }
341
+
342
+ dist_prob_map_low = {
343
+ 1: (1.0,),
344
+ 2: (0.8, 0.2),
345
+ 3: (0.8, 0.1, 0.1),
346
+ 4: (0.7, 0.1, 0.1, 0.1),
347
+ 5: (0.7, 0.1, 0.1, 0.05, 0.05),
348
+ 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
349
+ }
350
+
351
+ _bpm_range_rights = (
352
+ (40, '20-40'),
353
+ (60, '40-60'),
354
+ (66, '60-66'),
355
+ (76, '66-76'),
356
+ (108, '76-108'),
357
+ (120, '108-120'),
358
+ (168, '120-168'),
359
+ (176, '168-176'),
360
+ (200, '176-200')
361
+ )
362
+ _bpm_desc_map = {
363
+ '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
364
+ '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
365
+ '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
366
+ '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
367
+ '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
368
+ '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
369
+ '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
370
+ '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
371
+ '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
372
+ '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
373
+ }
374
+ _bpm_desc_map_zh = {
375
+ '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
376
+ '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
377
+ '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
378
+ '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
379
+ '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
380
+ '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
381
+ '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
382
+ '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
383
+ '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
384
+ '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
385
+ }
386
+ def get_bpm_range(bpm):
387
+ bpm = int(bpm)
388
+ for right, tag in _bpm_range_rights:
389
+ if bpm <= right:
390
+ return tag
391
+ return '>200'
392
+
393
+ def gen_bpm_descript(bpm, lang='en'):
394
+ bpm_range = get_bpm_range(bpm)
395
+ if lang == 'en':
396
+ return random.choice(_bpm_desc_map[bpm_range])
397
+ elif lang == 'zh':
398
+ return random.choice(_bpm_desc_map_zh[bpm_range])
399
+ else:
400
+ raise ValueError(f"Unknown language {lang}")
401
+
402
+ def read_translate(translate: Union[Dict[str, os.PathLike], os.PathLike, None]):
403
+ if translate is None:
404
+ return None
405
+ if isinstance(translate, str):
406
+ return read_jsonlike(translate)
407
+ return {k: read_jsonlike(path) for k, path in translate.items()}
408
+
409
+
410
+ def gen_plain_prompt(key_list, sep=', '):
411
+ if len(key_list) == 0:
412
+ return 'none'
413
+
414
+ key_list = [k.strip() for k in key_list]
415
+
416
+ if len(key_list) > 10:
417
+ random.shuffle(key_list)
418
+ key_list = key_list[:10]
419
+
420
+ probs = dist_prob_map[len(key_list)]
421
+
422
+ num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
423
+
424
+ random.shuffle(key_list)
425
+ tags = key_list[:num_tags]
426
+ tags_str = sep.join(tags)
427
+ return tags_str
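A quick sketch of the two tag-to-text helpers above; outputs vary because both sample randomly.

gen_bpm_descript(128)                              # a phrase from the '120-168' bucket, e.g. "brisk pace" or "Allegro"
gen_plain_prompt(["rock", "guitar", "energetic"])  # e.g. "guitar, rock" -- a random subset joined by ', '
gen_plain_prompt([])                               # 'none'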
428
+
429
+
430
+ class MagnaTagATuneDataset(Dataset):
431
+ def __init__(self):
432
+ pass
433
+
434
+
435
+ def tags_to_desc(tag_list, sep=',') -> str:
436
+ if not isinstance(tag_list, Sequence):
437
+ return str(tag_list)
438
+ if isinstance(tag_list, str):
439
+ return tag_list
440
+ if len(tag_list) <= 0:
441
+ return ''
442
+ elif len(tag_list) <= 5:
443
+ probs = dist_prob_map[len(tag_list)]
444
+ tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
445
+ random.shuffle(tag_list)
446
+ tag_list = tag_list[:tags_num]
447
+ return sep.join(tag_list)
448
+ else:
449
+ probs = dist_prob_map[5]
450
+ tags_num = random.choices(range(1, 6), probs)[0]
451
+ random.shuffle(tag_list)
452
+ tag_list = tag_list[:tags_num]
453
+ return sep.join(tag_list)
454
+
455
+ def get_sr_and_duration_info(item):
456
+ return item.get('sample_rate', None), item.get('duration', None)
457
+
458
+ class MtgJamendoDatasetFromJson(Dataset):
459
+ def __init__(self,
460
+ data_dir:str,
461
+ json_path:str,
462
+ duration:float=10,
463
+ sr:int = 0,
464
+ lang = 'en',
465
+ plain_rate = 0,
466
+ return_audio = True,
467
+ return_path = False,
468
+ prompt_template_path: os.PathLike = None,
469
+ tag_types = [],
470
+ translate:Optional[Dict[str, os.PathLike]] = None,
471
+ use_literal_none = True,
472
+ ):
473
+ self.audio_reader = SafeAudioReader(duration, sr)
474
+
475
+ self.data_dir = data_dir
476
+ self._load_metadata_json(json_path)
477
+ self.sr = sr
478
+ self.duration = duration
479
+ self.plain_rate = plain_rate
480
+ self.return_audio = return_audio
481
+ self.return_path = return_path
482
+ self.use_literal_none = use_literal_none
483
+ self.lang = lang
484
+
485
+ self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
486
+ if self.use_dynamic_prompt:
487
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
488
+ self.tag_types = tag_types
489
+
490
+ self.translate = read_translate(translate)
491
+
492
+ # these tags are considered weakly semantic; prompts containing only these tags are avoided
493
+ WEAK_TAG_LIST = ["title", "artist"]
494
+
495
+ def _load_metadata_json(self, json_path):
496
+ with open(json_path) as fp:
497
+ self.data = json.load(fp)
498
+
499
+ def convert_key_to_path(self, key):
500
+ return os.path.join(self.data_dir, get_base_dir_file(key))
501
+
502
+ def __len__(self):
503
+ return len(self.data)
504
+
505
+ def __getitem__(self, idx):
506
+ item = self.data[idx]
507
+ path = self.convert_key_to_path(item['key'])
508
+ description = self.generate_description(item)
509
+
510
+ if self.return_audio:
511
+ sr, duration = get_sr_and_duration_info(item)
512
+ audio = self.audio_reader(path, sr, duration)
513
+ else:
514
+ audio = None
515
+
516
+ if self.return_path:
517
+ return audio, description, path
518
+ return audio, description
519
+
520
+ def tags_to_desc(self, tag_list, tag_type) -> str:
521
+ if self.lang == 'en':
522
+ return tags_to_desc(tag_list)
523
+ elif self.lang == 'zh':
524
+ translator = self.translate[tag_type]
525
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
526
+ return tags_to_desc(translated_tag_list, sep='、')
527
+
528
+ def generate_description(self, item):
529
+ if random.random() > self.plain_rate:
530
+ # dynamically generate prompt from given prompt template
531
+ prompt_template = random.choice(self.prompt_templates)
532
+ description = self.generate_description_dynamic(item, prompt_template)
533
+ else:
534
+ # use plain prompt, i.e. tags sequence separated by comma
535
+ description = self.generate_description_plain(item)
536
+ return description
537
+
538
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
539
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
540
+ exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
541
+ exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
542
+
543
+ if len(exists_strong_tag) > 0:
544
+ probs = dist_prob_map[len(exists_strong_tag)]
545
+ tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
546
+ random.shuffle(exists_strong_tag)
547
+ tags = exists_strong_tag[:tags_num]
548
+ weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
549
+ weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
550
+ random.shuffle(exists_weak_tag)
551
+ weak_tags = exists_weak_tag[:weak_tags_num]
552
+ tags += weak_tags
553
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
554
+ prompt = prompt_template.apply(**tags_args)
555
+ else:
556
+ # no strong tags, use all weak tags instead
557
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
558
+ prompt = prompt_template.apply(**tags_args)
559
+
560
+ if self.use_literal_none and len(tags_args) == 0:
561
+ return 'none'
562
+
563
+ return prompt
564
+
565
+ def generate_description_plain(self, item):
566
+ keywords = []
567
+ for tag_t in self.tag_types:
568
+ this_key = item[tag_t]
569
+ if this_key is None:
570
+ continue
571
+ if isinstance(this_key, str):
572
+ this_key = [this_key]
573
+ if self.lang != 'en':
574
+ this_key = [self.get_translation(tag_t, k) for k in this_key]
575
+ keywords += this_key
576
+ return gen_plain_prompt(keywords, sep=self.keysep)
577
+
578
+ def get_translation(self, tag_t, k):
579
+ k = k.strip()
580
+ if k in self.translate[tag_t]:
581
+ return self.translate[tag_t][k]
582
+ else:
583
+ return k
584
+
585
+ @property
586
+ def keysep(self):
587
+ if self.lang == 'zh':
588
+ return ',' if random.random() > 0.5 else '、'
589
+ elif self.lang == 'en':
590
+ return ', '
591
+
592
+ class AudioStockDataset(Dataset):
593
+ def __init__(self,
594
+ metadata_path:str,
595
+ duration:float=10,
596
+ sr:int = 0,
597
+ plain_rate = 0,
598
+ return_path = False,
599
+ return_audio = True,
600
+ prompt_template_path: os.PathLike = None,
601
+ tag_types = [],
602
+ lang = 'en',
603
+ translate:Optional[Dict[str, os.PathLike]] = None,
604
+ use_literal_none = True,
605
+ ):
606
+ self.audio_reader = SafeAudioReader(duration, sr)
607
+
608
+ self._load_metadata(metadata_path)
609
+ self.sr = sr
610
+ self.duration = duration
611
+ self.plain_rate = plain_rate
612
+ self.return_path = return_path
613
+ self.return_audio = return_audio
614
+ self.use_literal_none = use_literal_none
615
+
616
+ self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
617
+ if self.use_dynamic_prompt:
618
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
619
+ self.tag_types = tag_types
620
+
621
+ self.lang = lang
622
+ self.translate = read_translate(translate)
623
+
624
+ def _load_metadata(self, metadata_path):
625
+ with open(metadata_path) as fp:
626
+ lines = fp.readlines()
627
+ self.data = []
628
+ for line in lines:
629
+ item = json.loads(line)
630
+ self.data.append(item)
631
+ self.is_info_recorded = bool('Tags' in self.data[0])
632
+
633
+ def __len__(self):
634
+ return len(self.data)
635
+
636
+ def __getitem__(self, idx):
637
+ path:str = self.data[idx]["path"]
638
+ json_path = path[:path.rfind('.')] + ".json"
639
+ if self.is_info_recorded:
640
+ item = self.data[idx]
641
+ else:
642
+ try:
643
+ with open(json_path) as fp:
644
+ item:dict = json.load(fp)
645
+ except Exception as e:
646
+ print(f"Error loading json file {json_path} :\n{e}")
647
+ item = {}
648
+ description = self.generate_description(item)
649
+ if self.return_audio:
650
+ sr, duration = get_sr_and_duration_info(item)
651
+ audio = self.audio_reader(path, sr, duration)
652
+ else:
653
+ audio = None
654
+ if self.return_path:
655
+ return audio, description, path
656
+ return audio, description
657
+
658
+ def generate_description(self, item):
659
+ if random.random() > self.plain_rate:
660
+ # dynamically generate prompt from given prompt template
661
+ prompt_template = random.choice(self.prompt_templates)
662
+ description = self.generate_description_dynamic(item, prompt_template)
663
+ else:
664
+ # use plain prompt, i.e. tags sequence separated by comma
665
+ description = self.generate_description_plain(item)
666
+ return description
667
+
668
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
669
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
670
+
671
+ if len(exists_tag) > 0:
672
+ probs = dist_prob_map[len(exists_tag)]
673
+ tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
674
+ random.shuffle(exists_tag)
675
+ tags = exists_tag[:tags_num]
676
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
677
+ tags_args = self.handle_BPM_tag(tags_args)
678
+ prompt = prompt_template.apply(**tags_args)
679
+ else:
680
+ return 'none'
681
+
682
+ if self.use_literal_none and len(tags_args) == 0:
683
+ return 'none'
684
+
685
+ return prompt
686
+
687
+ def get_translation(self, tag_t, k):
688
+ k = k.strip()
689
+ if k in self.translate[tag_t]:
690
+ return self.translate[tag_t][k]
691
+ else:
692
+ return k
693
+
694
+ def generate_description_plain(self, item):
695
+ keywords = []
696
+ for tag_t in self.tag_types:
697
+ if tag_t == 'BPMDescript':
698
+ bpm = item['BPM']
699
+ if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
700
+ continue
701
+ this_key = gen_bpm_descript(bpm.strip(), lang=self.lang)
702
+ elif tag_t == 'BPM':
703
+ bpm = item['BPM']
704
+ if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
705
+ continue
706
+ this_key = f"{bpm.strip()} bpm"
707
+ else:
708
+ this_key = item[tag_t]
709
+ if this_key is None:
710
+ continue
711
+ if isinstance(this_key, str):
712
+ this_key = [this_key]
713
+ if self.lang != 'en':
714
+ this_key = [self.get_translation(tag_t, k) for k in this_key]
715
+ if this_key is None:
716
+ continue
717
+ if isinstance(this_key, str):
718
+ this_key = [this_key]
719
+ keywords += this_key
720
+ return gen_plain_prompt(keywords, sep=self.keysep)
721
+
722
+ @property
723
+ def keysep(self):
724
+ if self.lang == 'zh':
725
+ return ',' if random.random() > 0.5 else '、'
726
+ elif self.lang == 'en':
727
+ return ', '
728
+
729
+ def tags_to_desc(self, tag_list, tag_type) -> str:
730
+ if self.lang == 'en':
731
+ return tags_to_desc(tag_list)
732
+ elif self.lang == 'zh':
733
+ if tag_type == 'BPM':
734
+ return tags_to_desc(tag_list, sep='、')
735
+ translator = self.translate[tag_type]
736
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
737
+ return tags_to_desc(translated_tag_list, sep='、')
738
+
739
+ def handle_BPM_tag(self, tags_args):
740
+ if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
741
+ bpm = tags_args["BPM"]
742
+ del tags_args["BPM"]
743
+ tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
744
+ for tag_type in tag_types_used:
745
+ tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
746
+ return tags_args
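A hedged instantiation sketch for AudioStockDataset; every path and tag key below is a placeholder, not a file or field that ships with this repo.

ds = AudioStockDataset(
    metadata_path="/data/audiostock/index.jsonl",          # one JSON object per line, each with a "path" field
    duration=10, sr=24000,
    plain_rate=0.5,                                        # roughly half plain tag lists, half templated prompts
    prompt_template_path="/data/prompts/templates_en.txt",
    tag_types=["Tags", "Genre", "BPM"],                    # hypothetical tag keys
    lang="en",
)
audio, text = ds[0]    # a 10 s mono tensor and its generated text prompt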
747
+
748
+ def mp3_path_to_id(mp3_path):
749
+ return int(
750
+ mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.')]
751
+ )
752
+
753
+ class TmeDataset(Dataset):
754
+ def __init__(self,
755
+ data_index:str,
756
+ music_info:str = None,
757
+ duration:float = 10,
758
+ sr:int = 0,
759
+ plain_rate = 0,
760
+ return_path = False,
761
+ return_audio = True,
762
+ return_ID = False,
763
+ prompt_format_path: os.PathLike = None,
764
+ tag_types = ['*'],
765
+ lang = 'zh',
766
+ translate: Optional[os.PathLike] = None,
767
+ prompt_dir: os.PathLike = None, # directory of pre-generated GPT prompts
768
+ ):
769
+ if plain_rate > 0:
770
+ print("Tme Dataset do not support plain rate > 0, use plain_rate = 0 instead.")
771
+ plain_rate = 0
772
+ self.audio_reader = SafeAudioReader(duration, sr)
773
+
774
+ self.sr = sr
775
+ self.duration = duration
776
+ self.plain_rate = plain_rate
777
+ self.return_path = return_path
778
+ self.return_audio = return_audio
779
+ self.return_ID = return_ID
780
+ self.lang = lang
781
+
782
+ self.use_ready_prompt = prompt_dir is not None
783
+
784
+ data_index = read_jsonlike(data_index)
785
+ self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
786
+ self.data_ids = list(self.data_index_dict.keys())
787
+
788
+ if not self.use_ready_prompt:
789
+ # load the music metadata file
790
+ music_info = read_jsonlike(music_info)
791
+ if 'music' in music_info:
792
+ music_info = music_info['music']
793
+ self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
794
+ self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
795
+ self.data_ids = list(self.data_index_dict.keys())
796
+
797
+ with open(prompt_format_path) as fp:
798
+ self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
799
+
800
+ # load tag types and split them into ordinary tag_types and critical key_tag_types
801
+ if '*' in tag_types:
802
+ self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
803
+ else:
804
+ self.tag_types = tag_types
805
+
806
+ self.key_tag_types = []
807
+ if 'tag' in self.tag_types:
808
+ self.tag_types.remove('tag')
809
+ self.key_tag_types = list(self.prompt_formats['tag'].keys())
810
+
811
+ # load the translation table
812
+ if translate is not None:
813
+ self.translator = read_jsonlike(translate)
814
+ else:
815
+ data_ids_set = set(self.data_ids)
816
+ self.prompts_dict = {}
817
+ for fname in os.listdir(prompt_dir):
818
+ items = read_jsonlike(os.path.join(prompt_dir, fname))
819
+ for item in items:
820
+ if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
821
+ continue
822
+ if item['ID'] not in self.prompts_dict:
823
+ self.prompts_dict[item['ID']] = []
824
+ self.prompts_dict[item['ID']].append(item['Text'])
825
+ self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
826
+ self.data_ids = list(self.data_index_dict.keys())
827
+
828
+ def tags_to_desc(self, tag_list) -> str:
829
+ if is_bearable(tag_list, int):
830
+ return str(tag_list)
831
+ if self.lang == 'zh':
832
+ return tags_to_desc(tag_list, sep=self.sep)
833
+ else:
834
+ translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
835
+ return tags_to_desc(translated_tag_list, sep=self.sep)
836
+
837
+ def gen_desc_of_tag(self, formats, tags):
838
+ fmt = random.choice(formats)
839
+ return fmt.format(self.tags_to_desc(tags))
840
+
841
+ @staticmethod
842
+ def check_valid(value):
843
+ if isinstance(value, int) or isinstance(value, float):
844
+ return value > 0
845
+ if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
846
+ return True
847
+ return False
848
+
849
+ @staticmethod
850
+ def remove_repeat(data):
851
+ # if the album name is identical to the song name, keep only the latter
852
+ album_name = data.get('专辑名', None)
853
+ if album_name is not None and album_name == data.get('歌曲名', None):
854
+ del data['专辑名']
855
+ return data
856
+
857
+ @property
858
+ def comma(self):
859
+ if self.lang == 'zh':
860
+ return ','
861
+ elif self.lang == 'en':
862
+ return ', '
863
+
864
+ @property
865
+ def sep(self):
866
+ if self.lang == 'zh':
867
+ return '、'
868
+ elif self.lang == 'en':
869
+ return ', '
870
+
871
+
872
+ def generate_description(self, item):
873
+ if random.random() > self.plain_rate:
874
+ # dynamically generate prompt from given prompt template
875
+ description = self.generate_description_dynamic(item)
876
+ else:
877
+ # use plain prompt, i.e. tags sequence separated by comma
878
+ description = self.generate_description_plain(item)
879
+ return description
880
+
881
+ def generate_description_dynamic(self, data):
882
+ data = self.remove_repeat(data)
883
+
884
+ weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] # weakly semantic tags; their appearance rate is reduced
885
+
886
+ key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] # key tags; at least one of these must appear
887
+
888
+ prompts = []
889
+ if len(weak_tags) > 0:
890
+ probs = dist_prob_map_low[len(weak_tags)]
891
+ if len(key_tags) > 0:
892
+ tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
893
+ else:
894
+ tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
895
+ random.shuffle(weak_tags)
896
+ tags = weak_tags[:tags_num]
897
+ for tag_type in tags:
898
+ tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
899
+ prompts.append(tag_desc)
900
+
901
+ if len(key_tags) > 0:
902
+ probs = dist_prob_map[len(key_tags)]
903
+ tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
904
+ random.shuffle(key_tags)
905
+ tags = key_tags[:tags_num]
906
+ for tag_type in tags:
907
+ tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
908
+ prompts.append(tag_desc)
909
+
910
+ random.shuffle(prompts)
911
+ return self.comma.join(prompts)
912
+
913
+ def generate_description_plain(self, item):
914
+ keywords = item['tag']
915
+ if self.lang != 'en':
916
+ keywords = [self.translator[k.strip()] for k in keywords]
917
+ return gen_plain_prompt(keywords, sep=self.keysep)
918
+
919
+ @property
920
+ def keysep(self):
921
+ if self.lang == 'zh':
922
+ return ',' if random.random() > 0.5 else '、'
923
+ elif self.lang == 'en':
924
+ return ', '
925
+
926
+ def is_valid_prompt_text(self, text):
927
+ for bad in ('抱歉','sorry', 'Sorry'):
928
+ if bad in text:
929
+ return False
930
+ return True
931
+
932
+ def get_ready_prompt(self, path):
933
+ sid = mp3_path_to_id(path)
934
+ return random.choice(self.prompts_dict[sid])
935
+
936
+ def __len__(self):
937
+ return len(self.data_ids)
938
+
939
+ def __getitem__(self, idx):
940
+ data_id = self.data_ids[idx]
941
+ item = self.data_index_dict[data_id]
942
+ path = item['path']
943
+ if not self.use_ready_prompt:
944
+ info = self.music_info_dict[data_id]
945
+ description = self.generate_description(info)
946
+ else:
947
+ description = self.get_ready_prompt(path)
948
+ if self.return_audio:
949
+ sr, duration = get_sr_and_duration_info(item)
950
+ audio = self.audio_reader(path, sr, duration)
951
+ else:
952
+ audio = None
953
+ if self.return_path:
954
+ if self.return_ID:
955
+ return audio, description, path, info['歌曲ID']
956
+ return audio, description, path
957
+ if self.return_ID:
958
+ return audio, description, info['歌曲ID']
959
+ return audio, description
960
+
961
+
962
+ class Pond5Dataset(Dataset):
963
+ MAX_PROMPT_LEN = 200
964
+ def __init__(self,
965
+ metadata_path:str,
966
+ index_path:str,
967
+ duration:float=10,
968
+ sr:int = 0,
969
+ plain_rate = 0,
970
+ return_path = False,
971
+ return_audio = True,
972
+ lang = 'en',
973
+ translate:Optional[Dict[str, os.PathLike]] = None,
974
+ use_literal_none = True,
975
+ use_avoid_watermark_policy = None,
976
+ ):
977
+
978
+ if use_avoid_watermark_policy is None:
979
+ raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type")
980
+ self.use_avoid_watermark_policy = use_avoid_watermark_policy
981
+ self.audio_reader = SafeAudioReader(duration, sr, use_avoid_watermark_policy=use_avoid_watermark_policy)
982
+
983
+ self._load_metadata(metadata_path, index_path)
984
+ self.sr = sr
985
+ self.duration = duration
986
+ self.plain_rate = plain_rate
987
+ self.return_path = return_path
988
+ self.return_audio = return_audio
989
+ self.use_literal_none = use_literal_none
990
+
991
+ self.lang = lang
992
+ self.translate = read_translate(translate)
993
+
994
+ def _load_metadata(self, metadata_path, index_path):
995
+ data_index = read_jsonlike(index_path)
996
+ data_ids = set([item['id'] for item in data_index])
997
+
998
+ with open(metadata_path) as fp:
999
+ lines = fp.readlines()
1000
+
1001
+ append_ids = set()
1002
+
1003
+ self.data = []
1004
+ for line in lines:
1005
+ item = json.loads(line)
1006
+ if item['id'] in data_ids and item['id'] not in append_ids:
1007
+ self.data.append(item)
1008
+ append_ids.add(item['id'])
1009
+
1010
+ def __len__(self):
1011
+ return len(self.data)
1012
+
1013
+ def __getitem__(self, idx):
1014
+ item = self.data[idx]
1015
+ path:str = item["path"]
1016
+ description = self.generate_description(item)
1017
+ if self.return_audio:
1018
+ sr, duration = get_sr_and_duration_info(item)
1019
+ audio = self.audio_reader(path, sr, duration)
1020
+ else:
1021
+ audio = None
1022
+ if self.return_path:
1023
+ return audio, description, path
1024
+ return audio, description
1025
+
1026
+ @property
1027
+ def keysep(self):
1028
+ if self.lang == 'zh':
1029
+ return ',' if random.random() > 0.5 else '、'
1030
+ elif self.lang == 'en':
1031
+ return ', '
1032
+
1033
+ def generate_description(self, item):
1034
+ if random.random() > self.plain_rate:
1035
+ # dynamically generate prompt from given prompt template
1036
+ description = self.generate_description_dynamic(item)
1037
+ else:
1038
+ # use plain prompt, i.e. tags sequence separated by comma
1039
+ description = self.generate_description_plain(item)
1040
+ return description
1041
+
1042
+ def get_translation(self, k):
1043
+ k = k.strip()
1044
+ if k in self.translate:
1045
+ return self.translate[k]
1046
+ else:
1047
+ return k
1048
+
1049
+ def generate_description_plain(self, item):
1050
+ keywords = item['keywords']
1051
+ if self.lang != 'en':
1052
+ keywords = [self.get_translation(k) for k in keywords]
1053
+ return gen_plain_prompt(keywords, sep=self.keysep)
1054
+
1055
+ def generate_description_dynamic(self,item):
1056
+ desc = item.get('desc', 'none')
1057
+ if desc is None:
1058
+ desc = 'none'
1059
+ desc = desc.strip()
1060
+ if len(desc) > self.MAX_PROMPT_LEN:
1061
+ shorter_desc = desc[:self.MAX_PROMPT_LEN]
1062
+ # find last stop
1063
+ stop_idx = shorter_desc.rfind('.')
1064
+ if stop_idx == -1:
1065
+ stop_idx = shorter_desc.rfind('!')
1066
+ if stop_idx == -1:
1067
+ stop_idx = shorter_desc.rfind(',')
1068
+ if stop_idx == -1:
1069
+ stop_idx = self.MAX_PROMPT_LEN - 1
1070
+ desc = desc[:stop_idx+1]
1071
+ return desc
1072
+
1073
+ class SoundDataset(Dataset):
1074
+ def __init__(self,
1075
+ metadata_index: str,
1076
+ duration:float = 10,
1077
+ min_non_silent_duration:float = 3,
1078
+ sr:int = 0,
1079
+ return_path = False,
1080
+ return_audio = True,
1081
+ ):
1082
+ self.data = read_jsonlike(metadata_index)
1083
+ self.sr = sr
1084
+ self.reader = SafeAudioReader(duration, sr)
1085
+ self.duration = duration
1086
+ self.min_non_silent_duration = min_non_silent_duration
1087
+ self.return_audio = return_audio
1088
+ self.return_path = return_path
1089
+
1090
+ def __getitem__(self, index):
1091
+ item = self.data[index]
1092
+ if self.return_audio:
1093
+ origin_duration = item['duration']
1094
+ if origin_duration < self.min_non_silent_duration:
1095
+ audio = self.read_and_repeat_and_pad(item)
1096
+ else:
1097
+ audio = self.reader(item['path'], item['sample_rate'], origin_duration)
1098
+ else:
1099
+ audio = None
1100
+ desc = item['caption']
1101
+ if self.return_path:
1102
+ return audio, desc, item['path']
1103
+ else:
1104
+ return audio, desc
1105
+
1106
+ def __len__(self):
1107
+ return len(self.data)
1108
+
1109
+ def read_and_repeat_and_pad(self, item):
1110
+ path = item['path']
1111
+ try:
1112
+ # read
1113
+ clip, sr = torchaudio.load(path)
1114
+ if len(clip.shape) > 1:
1115
+ clip = torch.mean(clip, dim=0, keepdim=True)
1116
+ clip = resample(clip, sr, self.sr)
1117
+ #repeat
1118
+ n_repeats = math.ceil(self.min_non_silent_duration/item['duration'])
1119
+ clip = torch.repeat_interleave(clip, n_repeats, dim=0).reshape(-1)
1120
+ #pad
1121
+ n_samples = int(self.duration * self.sr)
1122
+ if clip.shape[0] >= n_samples:
1123
+ audio = clip[:n_samples]
1124
+ else:
1125
+ audio = torch.zeros(int(self.duration * self.sr), dtype=clip.dtype)
1126
+ start_pos = np.random.randint(0, max(0,(n_samples - clip.shape[0])))
1127
+ audio[start_pos:start_pos+clip.shape[0]] = clip
1128
+ return audio
1129
+
1130
+ except Exception as e:
1131
+ logger.error(f"Error reading {path}: {e}")
1132
+ wav = torch.zeros(int(self.duration * self.sr), dtype=torch.float32)
1133
+ return wav
1134
+
1135
+ class CombinedDataset(Dataset):
1136
+ @beartype
1137
+ def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
1138
+ self.datasets = datasets
1139
+ self.datasets_index = []
1140
+
1141
+ for i,dataset in enumerate(datasets):
1142
+ if dataset is None:
1143
+ continue
1144
+ for dup in range(ratios[i]):
1145
+ for j in range(len(dataset)):
1146
+ self.datasets_index.append((i,j))
1147
+
1148
+ def __len__(self):
1149
+ return len(self.datasets_index)
1150
+
1151
+ def __getitem__(self, idx):
1152
+ index = self.datasets_index[idx]
1153
+ i,j = index
1154
+ return self.datasets[i][j]
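A short sketch of how CombinedDataset oversamples its members through integer ratios; the member datasets are placeholders.

combined = CombinedDataset(datasets=[music_ds, sound_ds], ratios=[2, 1])
# every index of music_ds appears twice in datasets_index and every index of sound_ds once,
# so len(combined) == 2 * len(music_ds) + len(sound_ds) and music items are drawn twice as often per epoch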
1155
+
1156
+ class CombinedDataset_random(Dataset):
1157
+ @beartype
1158
+ def __init__(self, num_examples:int, datasets: Sequence[Dataset], ratios: Sequence[int]):
1159
+ self.datasets = datasets
1160
+ self.datasets_index = []
1161
+
1162
+ for i,dataset in enumerate(datasets):
1163
+ if dataset is None:
1164
+ continue
1165
+ for dup in range(ratios[i]):
1166
+ for j in range(len(dataset)):
1167
+ self.datasets_index.append((i,j))
1168
+
1169
+ if num_examples > 0:
1170
+ self.random_choose = True
1171
+ self.dataset_len = num_examples
1172
+ else:
1173
+ self.random_choose = False
1174
+ self.dataset_len = len(self.datasets_index)
1175
+
1176
+ def __len__(self):
1177
+ return self.dataset_len
1178
+
1179
+ def __getitem__(self, idx):
1180
+ first_try = True
1181
+ try_cnt = 0
1182
+ while True:
1183
+ try:
1184
+ if(self.random_choose or not first_try):
1185
+ index2 = []
1186
+ index2.append(np.random.randint(0,len(self.datasets)))
1187
+ index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
1188
+ else:
1189
+ index2 = self.datasets_index[idx]
1190
+ first_try = False
1191
+ out = list(self.datasets[index2[0]][index2[1]])
1192
+ return out
1193
+ except:
1194
+ print("Error loadding ", index2)
1195
+ try_cnt += 1
1196
+ if(try_cnt>10):
1197
+ raise ValueError()
1198
+
1199
+ class SoundMixedDataset(Dataset):
1200
+ @staticmethod
1201
+ def music_desc(desc):
1202
+ return f'Music:<{desc}>'
1203
+ @staticmethod
1204
+ def sound_desc(desc):
1205
+ return f'Effect:<{desc}>'
1206
+
1207
+ def __init__(self,
1208
+ music_dataset: Dataset,
1209
+ sound_dataset: Dataset,
1210
+ mixed_ratios: Tuple[float, float, float] = (0.3, 0.3, 0.4) # ratio of music-only : sound-only : music+sound mixture
1211
+ ) -> None:
1212
+ self.music_dataset = music_dataset
1213
+ self.sound_dataset = sound_dataset
1214
+ music_r, sound_r, mix_r = [r/sum(mixed_ratios) for r in mixed_ratios] # normalize to proportions in [0, 1]
1215
+ # left endpoints of the three probability intervals
1216
+ self.music_anchor = 0
1217
+ self.sound_anchor = music_r
1218
+ self.mix_anchor = music_r + sound_r
1219
+
1220
+ def __len__(self):
1221
+ return len(self.music_dataset)
1222
+
1223
+ def get_random_sound_data(self):
1224
+ idx = random.randint(0, len(self.sound_dataset)-1)
1225
+ return self.sound_dataset[idx]
1226
+
1227
+ def __getitem__(self, idx):
1228
+ p = random.random()
1229
+ if p >= self.mix_anchor:
1230
+ music, m_desc = self.music_dataset[idx]
1231
+ sound, s_desc = self.get_random_sound_data()
1232
+ audio = music + sound
1233
+ if(audio.abs().max()>1.0):
1234
+ music = music / audio.abs().max() * 0.95
1235
+ audio = audio / audio.abs().max() * 0.95
1236
+ desc = self.music_desc(m_desc) + self.sound_desc(s_desc)
1237
+ return audio[None,:], music[None,:], desc
1238
+ elif p >= self.sound_anchor:
1239
+ audio, desc = self.get_random_sound_data()
1240
+ return audio[None,:], torch.zeros_like(audio[None,:]), self.sound_desc(desc)
1241
+ else:
1242
+ audio, desc = self.music_dataset[idx]
1243
+ return audio[None,:], audio[None,:], self.music_desc(desc)
1244
+
1245
+
1246
+ class DecoTagDataset(Dataset):
1247
+ '''Wraps an ordinary dataset into one suitable for tag-decoupled learning.'''
1248
+
1249
+ TAG_TYPES = ('genre', 'mood', 'instrument')
1250
+
1251
+ def __init__(self, dataset_class: type, tag_map: Dict[str, str], *args, **kwargs):
1252
+ self.datasets = []
1253
+ for i, tag_t in enumerate(self.TAG_TYPES):
1254
+ kwargs['tag_types'] = [tag_map[tag_t]]
1255
+ kwargs['return_audio'] = (i == 0) # only the 0th dataset returns audio plus text; the others return text only
1256
+ self.datasets.append(dataset_class(*args, **kwargs))
1257
+
1258
+ def __len__(self):
1259
+ return len(self.datasets[0])
1260
+
1261
+ def __getitem__(self, idx):
1262
+ audio, text = self.datasets[0][idx]
1263
+ texts = (text, self.datasets[1][idx][1], self.datasets[2][idx][1])
1264
+ return audio, texts
1265
+
1266
+
1267
+ class DecoTagWrapper:
1268
+ '''A thin wrapper that makes it easy to switch tag-decoupled learning on or off.'''
1269
+ def __init__(self, dataset_class: Dataset, deco_tag_types: List[str] = list(), switch_on: bool = False):
1270
+ self.dataset_class = dataset_class
1271
+ self.tag_map = dict(zip(DecoTagDataset.TAG_TYPES, deco_tag_types))
1272
+ self.switch_on = switch_on
1273
+
1274
+ def __call__(self, *args, **kwargs):
1275
+ if self.switch_on:
1276
+ return DecoTagDataset(self.dataset_class, self.tag_map, *args, **kwargs)
1277
+ else:
1278
+ return self.dataset_class(*args, **kwargs)
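Finally, a sketch of the wrapper in use; the tag names passed as deco_tag_types are assumptions about the metadata schema, not keys defined in this commit.

factory = DecoTagWrapper(AudioStockDataset,
                         deco_tag_types=["Genre", "Mood", "Instrument"],   # hypothetical tag keys
                         switch_on=True)
deco_ds = factory(metadata_path="/data/audiostock/index.jsonl", duration=10, sr=24000,
                  prompt_template_path="/data/prompts/templates_en.txt")
audio, (genre_text, mood_text, instrument_text) = deco_ds[0]   # one audio clip, three decoupled prompts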
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py ADDED
@@ -0,0 +1,372 @@
1
+ import re
2
+ import sys
3
+ import json
4
+ from typing import List, Union
5
+
6
+ from torch.utils.data import Dataset
7
+ import torchaudio
8
+ from torchaudio.functional import resample
9
+ import torch
10
+ import numpy as np
11
+
12
+ from torch.nn.utils.rnn import pad_sequence
13
+
14
+ PARAGRAPH_GAP = 6
15
+ MIN_MUSIC_LEN = 3
16
+
17
+ def check_lryics(lyric):
18
+ _FILTER_STRING = [
19
+ '作词', '作曲', '编曲', '【', '策划',
20
+ '录音', '混音', '母带', ':', '制作',
21
+ '版权', '校对', '演奏', '制作', '伴奏'
22
+ ]
23
+ for item in _FILTER_STRING:
24
+ if item in lyric:
25
+ return True
26
+
27
+ return False
28
+
29
+
30
+
31
+ def process_lyrics(lines):
32
+ lyric_part = []
33
+ timestamp_part = []
34
+
35
+ timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
36
+
37
+ for i, line in enumerate(lines):
38
+
39
+ # drop specific credit information from the first few lines
40
+ if i<10 and check_lryics(line):
41
+ continue
42
+
43
+ # check that the line contains a valid timestamp and lyric content
44
+ if timestamp_pattern.match(line):
45
+ timestamp_end = line.rfind(']')
46
+ lyrics = line[timestamp_end + 1:].strip()
47
+ timestamps = line[:timestamp_end + 1]
48
+
49
+ if ':' in lyrics:
50
+ if len(lyrics.split(":")[0]) <=5:
51
+ lyrics = "".join(lyrics.split(":")[1:])
52
+ # if lyrics: # make sure the lyric part is not empty
53
+ # lyric_part.append(lyrics)
54
+ # timestamp_part.append(timestamps)
55
+ # print(processed_lyrics)
56
+ return timestamp_part, lyric_part
57
+
58
+ def get_timestamps(timestamp_part):
59
+
60
+ # convert to seconds
61
+
62
+ timestamps = []
63
+
64
+ for line in timestamp_part:
65
+ match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
66
+ if match:
67
+ minutes = int(match.group(1))
68
+ seconds = float(match.group(2))
69
+ millis = float(match.group(3)) if match.group(3) else 0
70
+ total_seconds = minutes * 60 + seconds + millis
71
+ timestamps.append(total_seconds)
72
+
73
+
74
+ return timestamps
75
+
76
+ def process_lyrics_lrc(lyrics):
77
+ timestamp_part, lyric_part = process_lyrics(lyrics)
78
+ # print(timestamp_part)
79
+ # print(lyric_part)
80
+ timestamps = get_timestamps(timestamp_part)
81
+ # print(timestamps)
82
+ if len(timestamps) == 0:
83
+ # print(f'{lyric_path}')
84
+ return []
85
+
86
+ slice_start = timestamps[0]
87
+ slice_start_idx = 0
88
+
89
+ output_list = []
90
+ for i in range(1, len(timestamps)):
91
+ # split once the accumulated time exceeds 30 s; if the whole segment is shorter than 30 s, the sentence is dropped
92
+ if timestamps[i] - slice_start > 30:
93
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
94
+
95
+ slice_start = timestamps[i]
96
+ slice_start_idx = i
97
+
98
+ return output_list
99
+
100
+
101
+
102
+ def process_lyrics_yrc(lyrics):
103
+
104
+ timestamps, lyric_part = extract_lrc(lyrics)
105
+
106
+ # timestamp_part, lyric_part = process_lyrics(lyrics)
107
+ # import pdb; pdb.set_trace()
108
+ # print(timestamp_part)
109
+ # print(lyric_part)
110
+ # timestamps = get_timestamps(timestamp_part)
111
+ # print(timestamps)
112
+ if len(timestamps) == 0:
113
+ # print(f'{lyric_path}')
114
+ return []
115
+
116
+ slice_start = timestamps[0]
117
+ slice_start_idx = 0
118
+
119
+ output_list = []
120
+ for i in range(1, len(timestamps)):
121
+ # split once the accumulated time exceeds 30 s
122
+ if timestamps[i] - slice_start > 30:
123
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
124
+
125
+ slice_start = timestamps[i]
126
+ slice_start_idx = i
127
+ # import pdb; pdb.set_trace()
128
+ return output_list
129
+
130
+ def extract_lrc(lyrics):
131
+ timestamp_part, lyric_part = [], []
132
+
133
+ for i, text in enumerate(lyrics):
134
+ # extract the content inside square brackets
135
+ bracket_content = re.search(r'\[(.*?)\]', text).group(1)
136
+ bracket_content = bracket_content.split(',')
137
+ # extract the content inside parentheses
138
+ parentheses_content = re.findall(r'\((.*?)\)', text)
139
+ # extract the remaining content
140
+ other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
141
+
142
+ # how should this data be handled?
143
+ if i<10 and check_lryics(other_content):
144
+ continue
145
+ timestamp_part.append(float(bracket_content[0])/1000)
146
+ lyric_part.append(other_content)
147
+ return timestamp_part, lyric_part
148
+
149
+
150
+
151
+ class WYYSongDataset(Dataset):
152
+ def __init__(self,
153
+ metadata_path: Union[str, List[str]],
154
+ sr:int = 0,
155
+ use_lang = ['en', 'zh-cn'],
156
+ num_examples = -1,
157
+ max_dur = 20,
158
+ min_dur=0,
159
+ add_music=False,
160
+ pad_to_max= True,
161
+ ):
162
+
163
+ self.sr = sr
164
+ self.use_lang = use_lang
165
+ self.data = []
166
+ if type(metadata_path) == str:
167
+ metadata_path = [metadata_path]
168
+ for _meta in metadata_path:
169
+ self._load_metadata(_meta)
170
+ self.max_dur = max_dur
171
+ self.min_dur = min_dur
172
+ self.pad_to_max = pad_to_max
173
+ self.add_music = add_music
174
+
175
+ # buffer
176
+ self.lyric_buffer = {}
177
+
178
+ if(num_examples<=0):
179
+ self.dataset_len = len(self.data)
180
+ self.random_slc = False
181
+ else:
182
+ self.dataset_len = num_examples
183
+ self.random_slc = True
184
+
185
+
186
+ # read the jsonl metadata file
187
+ def _load_metadata(self, metadata_path):
188
+ with open(metadata_path) as fp:
189
+ lines = fp.readlines()
190
+ for line in lines:
191
+ item = json.loads(line)
192
+ if '伴奏' not in item['path']:
193
+ # if "lang_type" in item and item['lang_type'] == 'en':
194
+ if "lang_type" in item:
195
+ self.data.append(item)
196
+
197
+
198
+ def __len__(self):
199
+ return self.dataset_len
200
+
201
+
202
+ def __getitem__(self, idx):
203
+ try_cnt = 0
204
+ while True:
205
+ if(self.random_slc):
206
+ idx = np.random.randint(0, len(self.data))
207
+ yrc_lyrics = []
208
+ lrc_lyrics = []
209
+ try:
210
+ info = self.data[idx]
211
+
212
+ # audio path
213
+ path = info["path"]
214
+ lang_type = info["lang_type"]
215
+ lyrics = info['lyrics'] # chinese
216
+ # lyrics = info['lyrics_phone']
217
+
218
+ # randomly pick one lyric paragraph
219
+
220
+ parsed_lyrics = []
221
+ # st_idx = np.random.randint(0, len(lyrics))
222
+ for ly_id in range(len(lyrics)):
223
+ lyric = lyrics[ly_id].strip()
224
+ st, et, lyric = self.parse_lyric(lyric)
225
+
226
+ if et - st >= self.max_dur:
227
+ continue # TODO: extend with leading/trailing [MUSIC] margins
228
+
229
+ if parsed_lyrics != []:
230
+ if st - parsed_lyrics[-1][1] >= PARAGRAPH_GAP: # large gap
231
+ parsed_lyrics.append((parsed_lyrics[-1][1], st, '[GAP]'))
232
+ elif self.add_music and st - parsed_lyrics[-1][1] >= MIN_MUSIC_LEN:
233
+ parsed_lyrics.append((parsed_lyrics[-1][1], st, '[MUSIC]'))
234
+
235
+ lyric = lyric.replace("\xa0", " ")
236
+ lyric = " ".join(lyric.split())
237
+ parsed_lyrics.append((st, et, lyric))
238
+
239
+ assert parsed_lyrics != []
240
+ # if parsed_lyrics[-1][1] - parsed_lyrics[0][0] > self.max_dur:
241
+ # print(f"{parsed_lyrics[0][0]}-{parsed_lyrics[-1][1]} {parsed_lyrics}", file=open('tmp.txt', 'a'))
242
+
243
+ parsed_lyrics = [(0, parsed_lyrics[0][0], '[GAP]')] + parsed_lyrics
244
+
245
+ possible_starts = [e for e,i in enumerate(parsed_lyrics) if i[2]=='[GAP]']
246
+ st_idx = np.random.choice(possible_starts)
247
+
248
+ paraphrase = []
249
+ for i in parsed_lyrics[st_idx+1:]:
250
+ if i[2] == '[GAP]':
251
+ break
252
+ paraphrase.append(i)
253
+ # print(paraphrase, lyrics)
254
+
255
+ while paraphrase[-1][1] - paraphrase[0][0] > self.max_dur:
256
+ if np.random.rand() > 0.2:
257
+ paraphrase.pop(-1) # 大概率从后面截断
258
+ else:
259
+ paraphrase.pop(0) # with low probability, truncate from the front
260
+
261
+ st, et, lyric = paraphrase[0][0], paraphrase[-1][1], ', '.join([i[2] for i in paraphrase]) # [SEP]
262
+ # print(st, et, lyric)
263
+ # import pdb; pdb.set_trace()
264
+ assert self.min_dur < et - st < self.max_dur, f"{st}-{et} {lyric}"
265
+ # print(et-st, lyric)
266
+ # import pdb; pdb.set_trace()
267
+
268
+ if info["lang_type"] == 'en':
269
+ # print(len(lyric.split())/(et-st))
270
+ char_num = sum([len(lrc[-1].split()) for lrc in paraphrase])
271
+ assert 6 > char_num / (et-st) > 1
272
+ else:
273
+ # print(len(lyric.split())/(et-st))
274
+ char_num = sum([len(lrc[-1]) for lrc in paraphrase])
275
+ assert 6 > char_num / (et-st) > 1
276
+
277
+ # load the audio file
278
+ cur_sample_rate = torchaudio.info(path).sample_rate
279
+ offset = int(cur_sample_rate*st)
280
+ num_frames = int(cur_sample_rate * (et -st))
281
+ chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
282
+ # chunk = torch.zeros(1, 48000*15)
283
+ if abs(chunk.shape[-1] - num_frames) > num_frames * 0.05: # audio file length does not match the lyrics
284
+ print(f"fail to load {path} from {st} to {et} !")
285
+ raise FileNotFoundError
286
+ # randomly pick one channel
287
+ if(chunk.shape[0]>1):
288
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
289
+ else:
290
+ chunk = chunk[[0],:].float()
291
+
292
+ if(cur_sample_rate!=self.sr):
293
+ # print('a:',cur_sample_rate,chunk.shape)
294
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
295
+
296
+ if self.pad_to_max:
297
+ chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)
298
+
299
+ # print(self.sz_cnt)
300
+ return chunk, lyric, [st, et], path, lang_type
301
+ except (AssertionError, FileNotFoundError, RuntimeError) as e: # other errors are not acceptable
302
+ # print("Error loadding ", info["path"])
303
+ try_cnt += 1
304
+ idx = np.random.randint(0, len(self.data))
305
+ if(try_cnt>100):
306
+ raise e
307
+
308
+ def parse_lyric(self, lyric):
309
+ pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
310
+ match = re.search(pattern, lyric)
311
+
312
+ start_time = float(match.group(1))
313
+ end_time = float(match.group(2))
314
+ content = match.group(3)
315
+ return start_time, end_time, content
316
+
317
+ def pad_2d_tensor(self, x, max_len, pad_id):
318
+ # get the shape of the input tensor
319
+ batch_size, seq_len = x.size()
320
+ max_len = max(max_len, seq_len)
321
+ # compute the padding length needed
322
+ pad_len = max_len - seq_len
323
+
324
+ # if padding is needed
325
+ if pad_len > 0:
326
+ # create the padding tensor
327
+ pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
328
+
329
+ # concatenate the input tensor and the padding tensor along the second (column) dimension
330
+ padded_tensor = torch.cat([x, pad_tensor], dim=1)
331
+ else:
332
+ # if no padding is needed, return the input tensor unchanged
333
+ padded_tensor = x
334
+
335
+ return padded_tensor
336
+
337
+ def collect_data(data_list):
338
+ audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
339
+ lyrics = [data[1] for data in data_list]
340
+ st_et = [data[2] for data in data_list]
341
+ paths = [data[3] for data in data_list]
342
+ lang_types = [data[4] for data in data_list]
343
+ return audios, lyrics, st_et
344
+ # return audios, lyrics, st_et
345
+
346
+
347
+ def build_dataset(train_jsonl_list, val_jsonl_list, min_dur=0, max_dur=20, add_music=False):
348
+ print(min_dur,max_dur)
349
+ print(train_jsonl_list)
350
+ # ["exp/wyy3_20240418_v2f.jsonl",
351
+ # "exp/tme_lyric_baokuan.jsonl"]
352
+ train_dataset = WYYSongDataset(
353
+ metadata_path = train_jsonl_list,
354
+ sr = 48000,
355
+ use_lang = ['zh-cn', 'en'],
356
+ num_examples = 10*10000,
357
+ min_dur=min_dur,
358
+ max_dur=max_dur,
359
+ add_music=add_music
360
+ )
361
+
362
+ valid_dataset = WYYSongDataset(
363
+ metadata_path = val_jsonl_list,
364
+ sr = 48000,
365
+ use_lang = ['zh-cn', 'en'],
366
+ num_examples = 500,
367
+ min_dur=min_dur,
368
+ max_dur=max_dur,
369
+ add_music=add_music
370
+ )
371
+ print(train_jsonl_list, "\t total_song = ", len(train_dataset.data))
372
+ return train_dataset, valid_dataset
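A small illustration of the per-line lyric format that WYYSongDataset.parse_lyric expects, with an invented timestamp and text:

import re

line = "[12.50:18.75]hello from the chorus"
match = re.search(r'\[(\d+\.\d+):(\d+\.\d+)\](.*)', line)
st, et, content = float(match.group(1)), float(match.group(2)), match.group(3)
# st == 12.5, et == 18.75, content == 'hello from the chorus'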
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py ADDED
@@ -0,0 +1,830 @@
1
+ from torch.utils.data import Dataset
2
+ from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
3
+ from beartype import beartype
4
+ from beartype.door import is_bearable
5
+ import random
6
+ import pandas as pd
7
+ import os
8
+ from torchaudio.functional import resample
9
+ import torch
10
+ import typing as tp
11
+ from pathlib import Path
12
+ import torchaudio as ta
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import json
16
+ import yaml
17
+ import torchaudio
18
+ import math
19
+ import re
20
+ from loguru import logger
21
+
22
+ class Read_and_PadCrop_Normalized_T(torch.nn.Module):
23
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
24
+
25
+ super().__init__()
26
+
27
+ self.n_samples = n_samples
28
+ self.sample_rate = sample_rate
29
+ self.randomize = randomize
30
+
31
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
32
+ if(duration<(float(self.n_samples)/self.sample_rate+1)):
33
+ # print(duration,(float(self.n_samples)/self.sample_rate+1))
34
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
35
+ t_start = 0.
36
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
37
+ offset = 0
38
+ # print('c1:',chunk.shape)
39
+ else:
40
+ offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
41
+ t_start = offset / float(cur_sample_rate) / duration
42
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
43
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
44
+ # print('offset:',offset)
45
+ # print('c0:',chunk.shape)
46
+ # Pad with silence if necessary.
47
+ if(chunk.shape[0]>1):
48
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
49
+ else:
50
+ chunk = chunk[[0],:].float()
51
+ if(cur_sample_rate!=self.sample_rate):
52
+ # print('a:',cur_sample_rate,chunk.shape)
53
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
54
+ # print('b:',self.sample_rate,chunk.shape)
55
+ if chunk.shape[-1] < self.n_samples:
56
+ chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
57
+ else:
58
+ chunk = chunk[:,0:self.n_samples]
59
+ seconds_start = math.floor(offset / cur_sample_rate)
60
+ seconds_total = math.floor(duration)
61
+
62
+ return (
63
+ chunk,
64
+ t_start,
65
+ t_end,
66
+ seconds_start,
67
+ seconds_total
68
+ )
69
+
70
+
71
+ USE_DUMMY_AUDIO = False # when testing the code, this can be set to True so that no real data is read and generated silent audio is used instead
72
+ if USE_DUMMY_AUDIO:
73
+ logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
74
+
75
+ class SafeAudioReader:
76
+ """
77
+ This class is an adaptor to Read_and_PadCrop_Normalized_T, making it safe to read audio data.
78
+ """
79
+ def __init__(self,
80
+ duration: float, # length of the returned audio
81
+ sample_rate: int, # sample rate of the returned audio; if it differs from the source sample rate, the audio is resampled
82
+ randomize: bool = True
83
+ ):
84
+ self.n_samples = int(sample_rate * max(duration, 0))
85
+ self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
86
+
87
+ # NOTE: this is the core function; every dataset reads audio through this call!
88
+ def __call__(self,
89
+ filepath: os.PathLike, # audio path
90
+ origin_sample_rate: Optional[int] = None, # actual sample rate read from the json file; if not given, it is read from the file header
91
+ origin_duration: float = None, # actual duration read from the json file; if not given, it is read from the file header
92
+ ) -> torch.Tensor:
93
+ if USE_DUMMY_AUDIO:
94
+ wav = torch.zeros(self.n_samples, dtype=torch.float32)
95
+ return wav
96
+ try:
97
+ if origin_sample_rate is None or origin_duration is None:
98
+ audio_info = torchaudio.info(filepath)
99
+ origin_sample_rate = audio_info.sample_rate
100
+ origin_duration = audio_info.num_frames / origin_sample_rate
101
+ wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
102
+ except Exception as e:
103
+ logger.error(f"Error reading {filepath}: {e}")
104
+ wav = torch.zeros(self.n_samples, dtype=torch.float32)
105
+ return wav
106
+
107
+
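A hedged usage sketch of SafeAudioReader as defined above; the file path is hypothetical. On success it returns a fixed-length mono chunk, and on any read error it falls back to silence of the same length:

reader = SafeAudioReader(duration=10, sample_rate=48000)
wav = reader('/path/to/track.flac')        # hypothetical path
assert wav.shape[-1] == 10 * 48000         # always duration * sample_rate samples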
108
+ class PromptTemplate:
109
+ def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
110
+ self.template_text = template_text
111
+ self.tag_map = tag_map
112
+ self.lang = lang
113
+
114
+ @property
115
+ def tags(self):
116
+ return tuple(self.tag_map.keys())
117
+
118
+ def apply(self, **kwargs):
119
+ for tag in list(kwargs.keys()):
120
+ if kwargs[tag] == '':
121
+ kwargs.pop(tag)
122
+ for tag in self.tags:
123
+ if tag in kwargs:
124
+ kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
125
+ else:
126
+ kwargs[tag] = ''
127
+ prompt = self.template_text.format(**kwargs)
128
+
129
+ return self.beautify(prompt)
130
+
131
+ def beautify(self, text):
132
+ if self.lang == 'en':
133
+ return self._beautify_en(text)
134
+ elif self.lang == 'zh':
135
+ return self._beautify_zh(text)
136
+ else:
137
+ raise ValueError(f'Unknown language {self.lang}')
138
+
139
+ @staticmethod
140
+ def _beautify_en(text):
141
+ # no continuous commas without content between them
142
+ text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
143
+ # no continuous whitespace
144
+ text = re.sub(r'\s+', ' ', text)
145
+ # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
146
+ text = re.sub(r'\s+,', r',', text)
147
+ text = re.sub(r',\s+', r', ', text)
148
+ # no whitespace before the full stop
149
+ text = re.sub(r'\s+\.', r'.', text)
150
+ # strip whitespace, comma, and replace ',.'
151
+ text = text.strip(' ,')
152
+ text = text.replace(',.', '.')
153
+ return text
154
+
155
+ @staticmethod
156
+ def _beautify_zh(text):
157
+ # no continuous commas without content between them
158
+ text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
159
+ text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
160
+ # assume there should be NO whitespace in Chinese
161
+ text = re.sub(r'\s+', r'', text)
162
+ # strip whitespace, comma, and replace ',。'
163
+ text = text.strip(', 、')
164
+ text = text.replace(',。', '。')
165
+ return text
166
+
167
+ def __repr__(self):
168
+ return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
169
+
170
+ __str__ = __repr__
171
+
172
+ def parse_prompt_template(prompt_template_text, lang='en'):
173
+ span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
174
+ tag_pattern = re.compile(r'{.+?}', re.DOTALL)
175
+
176
+ template_text = prompt_template_text.strip()
177
+ span_texts = span_pattern.findall(prompt_template_text)
178
+ tag_map = {}
179
+ for span_text in span_texts:
180
+ tag = tag_pattern.findall(span_text)[0].strip('{}')
181
+ tag_map[tag] = span_text
182
+ template_text = template_text.replace(span_text, '{'+tag+'}')
183
+
184
+ return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
185
+
186
+ def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
187
+ with open(path, 'r') as f:
188
+ lines = f.readlines()
189
+ cnt = 0
190
+ pts = []
191
+ for line in lines:
192
+ pt = parse_prompt_template(line, lang=lang)
193
+ cnt += 1
194
+ if len(pt.tags) < num:
195
+ logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
196
+ pts.append(pt)
197
+
198
+ return pts
199
+
200
+
201
+ def get_base_dir_file(key: os.PathLike):
202
+ base = os.path.basename(key)
203
+ dirname = os.path.basename(os.path.dirname(key))
204
+ return os.path.join(dirname, base)
205
+
206
+ def read_jsonlike(path: os.PathLike):
207
+ #json or jsonl
208
+ if str(path).endswith(".json"):
209
+ with open(path, 'r', encoding='utf8') as f:
210
+ data = json.load(f)
211
+ return data
212
+ elif str(path).endswith(".jsonl"):
213
+ with open(path, 'r', encoding='utf8') as f:
214
+ data = [json.loads(line) for line in f.readlines()]
215
+ return data
216
+ else:
217
+ raise ValueError("Unknown file format")
218
+
219
+ dist_prob_map = {
220
+ 1: (1.0,),
221
+ 2: (0.5, 0.5),
222
+ 3: (0.3, 0.4, 0.3),
223
+ 4: (0.2, 0.3, 0.3, 0.2),
224
+ 5: (0.2, 0.2, 0.3, 0.2, 0.1),
225
+ 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
226
+ 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
227
+ 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
228
+ 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
229
+ 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
230
+ }
231
+
232
+ dist_prob_map_low = {
233
+ 1: (1.0,),
234
+ 2: (0.8, 0.2),
235
+ 3: (0.8, 0.1, 0.1),
236
+ 4: (0.7, 0.1, 0.1, 0.1),
237
+ 5: (0.7, 0.1, 0.1, 0.05, 0.05),
238
+ 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
239
+ }
240
+
241
+ _bpm_range_rights = (
242
+ (40, '20-40'),
243
+ (60, '40-60'),
244
+ (66, '60-66'),
245
+ (76, '66-76'),
246
+ (108, '76-108'),
247
+ (120, '108-120'),
248
+ (168, '120-168'),
249
+ (176, '168-176'),
250
+ (200, '176-200')
251
+ )
252
+ _bpm_desc_map = {
253
+ '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
254
+ '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
255
+ '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
256
+ '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
257
+ '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
258
+ '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
259
+ '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
260
+ '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
261
+ '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
262
+ '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
263
+ }
264
+ _bpm_desc_map_zh = {
265
+ '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
266
+ '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
267
+ '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
268
+ '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
269
+ '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
270
+ '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
271
+ '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
272
+ '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
273
+ '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
274
+ '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
275
+ }
276
+ def get_bpm_range(bpm):
277
+ bpm = int(bpm)
278
+ for right, tag in _bpm_range_rights:
279
+ if bpm <= right:
280
+ return tag
281
+ return '>200'
282
+
283
+ def gen_bpm_descript(bpm, lang='en'):
284
+ bpm_range = get_bpm_range(bpm)
285
+ if lang == 'en':
286
+ return random.choice(_bpm_desc_map[bpm_range])
287
+ elif lang == 'zh':
288
+ return random.choice(_bpm_desc_map_zh[bpm_range])
289
+ else:
290
+ raise ValueError(f"Unknown language {lang}")
291
+
292
+ def read_translate(translate: Optional[Dict[str, os.PathLike]]):
293
+ if translate is None:
294
+ return None
295
+ return {k: read_jsonlike(path) for k, path in translate.items()}
296
+
297
+
298
+ class MagnaTagATuneDataset(Dataset):
299
+ def __init__(self):
300
+ pass
301
+
302
+
303
+ def tags_to_desc(tag_list, sep=',') -> str:
304
+ if not isinstance(tag_list, Sequence):
305
+ return str(tag_list)
306
+ if isinstance(tag_list, str):
307
+ return tag_list
308
+ if len(tag_list) <= 0:
309
+ return ''
310
+ elif len(tag_list) <= 5:
311
+ probs = dist_prob_map[len(tag_list)]
312
+ tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
313
+ random.shuffle(tag_list)
314
+ tag_list = tag_list[:tags_num]
315
+ return sep.join(tag_list)
316
+ else:
317
+ probs = dist_prob_map[5]
318
+ tags_num = random.choices(range(1, 6), probs)[0]
319
+ random.shuffle(tag_list)
320
+ tag_list = tag_list[:tags_num]
321
+ return sep.join(tag_list)
322
+
323
+ def get_sr_and_duration_info(item):
324
+ return item.get('sample_rate', None), item.get('duration', None)
325
+
326
+ class MtgJamendoDatasetFromJson(Dataset):
327
+ def __init__(self,
328
+ data_dir:str,
329
+ json_path:str,
330
+ duration:float=10,
331
+ sr:int = 0,
332
+ *,
333
+ lang = 'en',
334
+ return_path = False,
335
+ prompt_template_path: os.PathLike = None,
336
+ tag_types = [],
337
+ translate:Optional[Dict[str, os.PathLike]] = None,
338
+ ):
339
+ self.audio_reader = SafeAudioReader(duration, sr)
340
+
341
+ self.data_dir = data_dir
342
+ self._load_metadata_json(json_path)
343
+ self.sr = sr
344
+ self.duration = duration
345
+ self.return_path = return_path
346
+ self.lang = lang
347
+
348
+ self.use_dynamic_prompt = prompt_template_path is not None
349
+ if self.use_dynamic_prompt:
350
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
351
+ self.tag_types = tag_types
352
+
353
+ self.translate = read_translate(translate)
354
+ if not self.use_dynamic_prompt and self.lang != 'en':
355
+ raise NotImplementedError
356
+
357
+ # these tags are considered weakly semantic; prompts containing only these tags are avoided
358
+ WEAK_TAG_LIST = ["title", "artist"]
359
+
360
+ def _load_metadata_json(self, json_path):
361
+ with open(json_path) as fp:
362
+ self.data = json.load(fp)
363
+
364
+ def convert_key_to_path(self, key):
365
+ return os.path.join(self.data_dir, get_base_dir_file(key))
366
+
367
+ def __len__(self):
368
+ return len(self.data)
369
+
370
+ def __getitem__(self, idx):
371
+ item = self.data[idx]
372
+ path = self.convert_key_to_path(item['key'])
373
+ description = self.generate_description(item)
374
+
375
+ sr, duration = get_sr_and_duration_info(item)
376
+ audio = self.audio_reader(path, sr, duration)
377
+
378
+ if self.return_path:
379
+ return audio, description, path
380
+ return audio, description
381
+
382
+ def tags_to_desc(self, tag_list, tag_type) -> str:
383
+ if self.lang == 'en':
384
+ return tags_to_desc(tag_list)
385
+ elif self.lang == 'zh':
386
+ translator = self.translate[tag_type]
387
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
388
+ return tags_to_desc(translated_tag_list, sep='、')
389
+
390
+ def generate_description(self, item):
391
+ if self.use_dynamic_prompt:
392
+ # dynamically generate prompt from given prompt template
393
+ prompt_template = random.choice(self.prompt_templates)
394
+ description = self.generate_description_dynamic(item, prompt_template)
395
+
396
+ else:
397
+ # use ordinary static prompt instead
398
+ description = self.generate_description_ordinary(item)
399
+ return description
400
+
401
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
402
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
403
+ exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
404
+ exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
405
+
406
+ if len(exists_strong_tag) > 0:
407
+ probs = dist_prob_map[len(exists_strong_tag)]
408
+ tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
409
+ random.shuffle(exists_strong_tag)
410
+ tags = exists_strong_tag[:tags_num]
411
+ weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
412
+ weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
413
+ random.shuffle(exists_weak_tag)
414
+ weak_tags = exists_weak_tag[:weak_tags_num]
415
+ tags += weak_tags
416
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
417
+ prompt = prompt_template.apply(**tags_args)
418
+ else:
419
+ # no strong tags, use all weak tags instead
420
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
421
+ prompt = prompt_template.apply(**tags_args)
422
+
423
+ return prompt
424
+
425
+ def generate_description_ordinary(self, data, thresh = 0.3):
426
+ # Initialize the description with title and artist
427
+ description = f'{data["title"] + " is " if random.random() > thresh else ""}a piece of music by {data["artist"]}'
428
+
429
+ # Add genre if available
430
+ if data["genre"] and random.random() > thresh:
431
+ genres = ', '.join(data["genre"])
432
+ description += f', belonging to the {genres} genres'
433
+
434
+ # Add moods if available
435
+ if data["moods"] and random.random() > thresh:
436
+ moods = ', '.join(data["moods"])
437
+ description += f'. This track conveys a {moods} mood'
438
+
439
+ # Add instruments if available
440
+ if data["instrument"] and random.random() > thresh:
441
+ instruments = ', '.join(data["instrument"])
442
+ description += f', and primarily features the following instruments: {instruments}'
443
+
444
+ # Add a period to end the description
445
+ description += '.'
446
+
447
+ return description
448
+
449
+ class AudioStockDataset(Dataset):
450
+ def __init__(self,
451
+ metadata_path:str,
452
+ duration:float=10,
453
+ sr:int = 0,
454
+ return_path = False,
455
+ return_audio = True,
456
+ prompt_template_path: os.PathLike = None,
457
+ tag_types = [],
458
+ lang = 'en',
459
+ translate:Optional[Dict[str, os.PathLike]] = None
460
+ ):
461
+ self.audio_reader = SafeAudioReader(duration, sr)
462
+
463
+ self._load_metadata(metadata_path)
464
+ self.sr = sr
465
+ self.duration = duration
466
+ self.return_path = return_path
467
+ self.return_audio = return_audio
468
+
469
+ self.use_dynamic_prompt = prompt_template_path is not None
470
+ if self.use_dynamic_prompt:
471
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
472
+ self.tag_types = tag_types
473
+
474
+ self.lang = lang
475
+ self.translate = read_translate(translate)
476
+
477
+ def _load_metadata(self, metadata_path):
478
+ with open(metadata_path) as fp:
479
+ lines = fp.readlines()
480
+ self.data = []
481
+ for line in lines:
482
+ item = json.loads(line)
483
+ self.data.append(item)
484
+ self.is_info_recorded = bool('Tags' in self.data[0])
485
+
486
+ def __len__(self):
487
+ return len(self.data)
488
+
489
+ def __getitem__(self, idx):
490
+ path:str = self.data[idx]["path"]
491
+ json_path = path[:path.rfind('.')] + ".json"
492
+ if self.is_info_recorded:
493
+ item = self.data[idx]
494
+ else:
495
+ try:
496
+ with open(json_path) as fp:
497
+ item:dict = json.load(fp)
498
+ except Exception as e:
499
+ print(f"Error loading json file {json_path} :\n{e}")
500
+ item = {}
501
+ description = self.generate_description(item)
502
+ if self.return_audio:
503
+ sr, duration = get_sr_and_duration_info(item)
504
+ audio = self.audio_reader(path, sr, duration)
505
+ else:
506
+ audio = None
507
+ if self.return_path:
508
+ return audio, description, path
509
+ return audio, description
510
+
511
+ def generate_description(self, item):
512
+ if self.use_dynamic_prompt:
513
+ # dynamically generate prompt from given prompt template
514
+ prompt_template = random.choice(self.prompt_templates)
515
+ description = self.generate_description_dynamic(item, prompt_template)
516
+ else:
517
+ # use ordinary static prompt instead
518
+ description = self.generate_description_ordinary(item)
519
+ return description
520
+
521
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
522
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
523
+
524
+ if len(exists_tag) > 0:
525
+ probs = dist_prob_map[len(exists_tag)]
526
+ tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
527
+ random.shuffle(exists_tag)
528
+ tags = exists_tag[:tags_num]
529
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
530
+ tags_args = self.handle_BPM_tag(tags_args)
531
+ prompt = prompt_template.apply(**tags_args)
532
+ else:
533
+ # no strong tags, use all weak tags instead
534
+ prompt = prompt_template.apply()
535
+
536
+ return prompt
537
+
538
+ def tags_to_desc(self, tag_list, tag_type) -> str:
539
+ if self.lang == 'en':
540
+ return tags_to_desc(tag_list)
541
+ elif self.lang == 'zh':
542
+ if tag_type == 'BPM':
543
+ return tags_to_desc(tag_list, sep='、')
544
+ translator = self.translate[tag_type]
545
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
546
+ return tags_to_desc(translated_tag_list, sep='、')
547
+
548
+ def handle_BPM_tag(self, tags_args):
549
+ if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
550
+ bpm = tags_args["BPM"]
551
+ del tags_args["BPM"]
552
+ tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
553
+ for tag_type in tag_types_used:
554
+ tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
555
+ return tags_args
556
+
557
+ def generate_description_ordinary(self, data, thresh = 0.3):
558
+ if self.lang != 'en':
559
+ raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
560
+ description = f'a piece of music by {data["Artist"]}'
561
+
562
+ # Add genre if available
563
+ if data["Genre"] and random.random() > thresh:
564
+ genres = ', '.join(data["Genre"])
565
+ description += f', belonging to the {genres} genres'
566
+
567
+ # Add moods if available
568
+ if data["Tags"] and random.random() > thresh:
569
+ tags = ', '.join(data["Tags"])
570
+ description += f'. This track contains the tags:{tags}'
571
+
572
+ # Add moods if available
573
+ if data["Mood"] and random.random() > thresh:
574
+ moods = ', '.join(data["Mood"])
575
+ description += f'. This track conveys a {moods} mood.'
576
+
577
+ # Add instruments if available
578
+ if data["Instrument"] and random.random() > thresh:
579
+ instruments = ', '.join(data["Instrument"])
580
+ description += f'. and primarily features the following instruments: {instruments}'
581
+
582
+ # Add a period to end the description
583
+ description += '.'
584
+
585
+ return description
586
+
587
+ def mp3_path_to_id(mp3_path):
588
+ return int(
589
+ mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
590
+ )
591
+
592
+ class TmeDataset(Dataset):
593
+ def __init__(self,
594
+ data_index:str,
595
+ music_info:str = None,
596
+ duration:float = 10,
597
+ sr:int = 0,
598
+ return_path = False,
599
+ return_audio = True,
600
+ prompt_format_path: os.PathLike = None,
601
+ tag_types = ['*'],
602
+ lang = 'zh',
603
+ translate: Optional[os.PathLike] = None,
604
+ prompt_dir: os.PathLike = None,
605
+ ):
606
+ self.audio_reader = SafeAudioReader(duration, sr)
607
+
608
+ self.sr = sr
609
+ self.duration = duration
610
+ self.return_path = return_path
611
+ self.return_audio = return_audio
612
+ self.lang = lang
613
+
614
+ self.use_ready_prompt = prompt_dir is not None
615
+
616
+ data_index = read_jsonlike(data_index)
617
+ self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
618
+ self.data_ids = list(self.data_index_dict.keys())
619
+
620
+ if not self.use_ready_prompt:
621
+ # read the music info file
622
+ music_info = read_jsonlike(music_info)
623
+ if 'music' in music_info:
624
+ music_info = music_info['music']
625
+ self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
626
+ self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
627
+ self.data_ids = list(self.data_index_dict.keys())
628
+
629
+ with open(prompt_format_path) as fp:
630
+ self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
631
+
632
+ # load the tag types and split them into ordinary tag_types and key_tag_types
633
+ if '*' in tag_types:
634
+ self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
635
+ else:
636
+ self.tag_types = tag_types
637
+
638
+ self.key_tag_types = []
639
+ if 'tag' in self.tag_types:
640
+ self.tag_types.remove('tag')
641
+ self.key_tag_types = list(self.prompt_formats['tag'].keys())
642
+
643
+ # load the translation table
644
+ if translate is not None:
645
+ self.translator = read_jsonlike(translate)
646
+ else:
647
+ data_ids_set = set(self.data_ids)
648
+ self.prompts_dict = {}
649
+ for fname in os.listdir(prompt_dir):
650
+ items = read_jsonlike(os.path.join(prompt_dir, fname))
651
+ for item in items:
652
+ if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
653
+ continue
654
+ if item['ID'] not in self.prompts_dict:
655
+ self.prompts_dict[item['ID']] = []
656
+ self.prompts_dict[item['ID']].append(item['Text'])
657
+ self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
658
+ self.data_ids = list(self.data_index_dict.keys())
659
+
660
+ def tags_to_desc(self, tag_list) -> str:
661
+ if is_bearable(tag_list, int):
662
+ return str(tag_list)
663
+ if self.lang == 'zh':
664
+ return tags_to_desc(tag_list, sep=self.sep)
665
+ else:
666
+ translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
667
+ return tags_to_desc(translated_tag_list, sep=self.sep)
668
+
669
+ def gen_desc_of_tag(self, formats, tags):
670
+ fmt = random.choice(formats)
671
+ return fmt.format(self.tags_to_desc(tags))
672
+
673
+ @staticmethod
674
+ def check_valid(value):
675
+ if isinstance(value, int) or isinstance(value, float):
676
+ return value > 0
677
+ if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
678
+ return True
679
+ return False
680
+
681
+ @staticmethod
682
+ def remove_repeat(data):
683
+ # if the album name is the same as the song name, keep only the latter
684
+ album_name = data.get('专辑名', None)
685
+ if album_name is not None and album_name == data.get('歌曲名', None):
686
+ del data['专辑名']
687
+ return data
688
+
689
+ @property
690
+ def comma(self):
691
+ if self.lang == 'zh':
692
+ return ','
693
+ elif self.lang == 'en':
694
+ return ', '
695
+
696
+ @property
697
+ def sep(self):
698
+ if self.lang == 'zh':
699
+ return '、'
700
+ elif self.lang == 'en':
701
+ return ', '
702
+
703
+ def generate_description(self, data):
704
+ data = self.remove_repeat(data)
705
+ weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] # weakly semantic tags; their sampling proportion is kept low
706
+
707
+ key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] # key tags; at least one of these must appear
708
+
709
+ prompts = []
710
+ if len(weak_tags) > 0:
711
+ probs = dist_prob_map_low[len(weak_tags)]
712
+ if len(key_tags) > 0:
713
+ tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
714
+ else:
715
+ tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
716
+ random.shuffle(weak_tags)
717
+ tags = weak_tags[:tags_num]
718
+ for tag_type in tags:
719
+ tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
720
+ prompts.append(tag_desc)
721
+
722
+ if len(key_tags) > 0:
723
+ probs = dist_prob_map[len(key_tags)]
724
+ tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
725
+ random.shuffle(key_tags)
726
+ tags = key_tags[:tags_num]
727
+ for tag_type in tags:
728
+ tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
729
+ prompts.append(tag_desc)
730
+
731
+ random.shuffle(prompts)
732
+ return self.comma.join(prompts)
733
+
734
+ def is_valid_prompt_text(self, text):
735
+ for bad in ('抱歉','sorry', 'Sorry'):
736
+ if bad in text:
737
+ return False
738
+ return True
739
+
740
+ def get_ready_prompt(self, path):
741
+ sid = mp3_path_to_id(path)
742
+ return random.choice(self.prompts_dict[sid])
743
+
744
+ def __len__(self):
745
+ return len(self.data_ids)
746
+
747
+ def __getitem__(self, idx):
748
+ data_id = self.data_ids[idx]
749
+ item = self.data_index_dict[data_id]
750
+ path = item['path']
751
+ if not self.use_ready_prompt:
752
+ info = self.music_info_dict[data_id]
753
+ description = self.generate_description(info)
754
+ else:
755
+ description = self.get_ready_prompt(path)
756
+ if self.return_audio:
757
+ sr, duration = get_sr_and_duration_info(item)
758
+ audio = self.audio_reader(path, sr, duration)
759
+ else:
760
+ audio = None
761
+ if self.return_path:
762
+ return audio, description, path
763
+ return audio, description
764
+
765
+ class CombinedDataset(Dataset):
766
+ @beartype
767
+ def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
768
+ self.datasets = datasets
769
+ self.datasets_index = []
770
+
771
+ for i,dataset in enumerate(datasets):
772
+ if dataset is None:
773
+ continue
774
+ for dup in range(ratios[i]):
775
+ for j in range(len(dataset)):
776
+ self.datasets_index.append((i,j))
777
+
778
+ def __len__(self):
779
+ return len(self.datasets_index)
780
+
781
+ def __getitem__(self, idx):
782
+ index = self.datasets_index[idx]
783
+ i,j = index
784
+ return self.datasets[i][j]
785
+
786
+ class CombinedDataset_random(Dataset):
787
+ @beartype
788
+ def __init__(self,
789
+ num_examples:int,
790
+ datasets: Sequence[Dataset], ratios: Sequence[int]
791
+ ):
792
+ self.datasets = datasets
793
+ self.datasets_index = []
794
+
795
+ for i,dataset in enumerate(datasets):
796
+ if dataset is None:
797
+ continue
798
+ for dup in range(ratios[i]):
799
+ for j in range(len(dataset)):
800
+ self.datasets_index.append((i,j))
801
+ if num_examples > 0:
802
+ self.random_choose = True
803
+ self.dataset_len = num_examples
804
+ else:
805
+ self.random_choose = False
806
+ self.dataset_len = len(self.datasets_index)
807
+
808
+ def __len__(self):
809
+ return self.dataset_len
810
+
811
+ def __getitem__(self, idx):
812
+ first_try = True
813
+ try_cnt = 0
814
+ while True:
815
+ try:
816
+ if(self.random_choose or not first_try):
817
+ index2 = []
818
+ index2.append(np.random.randint(0,len(self.datasets)))
819
+ index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
820
+ else:
821
+ index2 = self.datasets_index[idx]
822
+ first_try = False
823
+ out = self.datasets[index2[0]][index2[1]]
824
+ if(len(out[0].shape)==1):out[0]=out[0][None,:]
825
+ return out
826
+ except:
827
+ print("Error loading ", index2)
828
+ try_cnt += 1
829
+ if(try_cnt>10):
830
+ raise ValueError()
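A minimal sketch of how the prompt templates in this file are parsed and filled; the template text and tag values are invented. Bracketed spans that contain a {tag} are dropped whenever that tag is missing or empty:

template = parse_prompt_template('A [{genre}] track [with a {mood} mood] [featuring {instrument}].')
template.tags                                   # ('genre', 'mood', 'instrument')
template.apply(genre='jazz', mood='relaxed')    # 'A jazz track with a relaxed mood.'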
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py ADDED
@@ -0,0 +1,994 @@
1
+ from torch.utils.data import Dataset
2
+ from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
3
+ from beartype import beartype
4
+ from beartype.door import is_bearable
5
+ import random
6
+ import pandas as pd
7
+ import os
8
+ from torchaudio.functional import resample
9
+ import torch
10
+ import typing as tp
11
+ from pathlib import Path
12
+ import torchaudio as ta
13
+ import torch.nn.functional as F
14
+ import numpy as np
15
+ import json
16
+ import yaml
17
+ import torchaudio
18
+ import math
19
+ import re
20
+ from loguru import logger
21
+
22
+ def gen_plain_prompt(key_list, sep=', '):
23
+ if len(key_list) == 0:
24
+ return 'none'
25
+
26
+ key_list = [k.strip() for k in key_list]
27
+
28
+ if len(key_list) > 10:
29
+ random.shuffle(key_list)
30
+ key_list = key_list[:10]
31
+
32
+ probs = dist_prob_map[len(key_list)]
33
+
34
+ num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
35
+
36
+ random.shuffle(key_list)
37
+ tags = key_list[:num_tags]
38
+ tags_str = sep.join(tags)
39
+ return tags_str
40
+
41
+ class Read_and_PadCrop_Normalized_T(torch.nn.Module):
42
+
43
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
44
+
45
+ super().__init__()
46
+
47
+ self.n_samples = n_samples
48
+ self.sample_rate = sample_rate
49
+ self.randomize = randomize
50
+ self.prob = {"is_start":0.2, "is_end":0.9}
51
+ self.shift_secs = 5
52
+
53
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
54
+ if(duration<(float(self.n_samples)/self.sample_rate+1)):
55
+ raise ValueError(duration,float(self.n_samples),self.sample_rate)
56
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
57
+ t_start = 0.
58
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
59
+ offset = 0
60
+ is_start = True
61
+ is_end = True
62
+ else:
63
+ prob = random.uniform(0,1)
64
+ if(prob<self.prob['is_start']):
65
+ is_start = True
66
+ is_end = False
67
+ offset = 0
68
+ elif(prob>self.prob['is_end']):
69
+ is_start = False
70
+ is_end = True
71
+ offset = int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)
72
+ else:
73
+ is_start = False
74
+ is_end = False
75
+ offset = np.random.randint(self.shift_secs*cur_sample_rate, \
76
+ int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)-self.shift_secs*cur_sample_rate)
77
+ t_start = offset / float(cur_sample_rate) / duration
78
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
79
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
80
+ if(chunk.shape[0]>1):
81
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
82
+ else:
83
+ chunk = chunk[[0],:].float()
84
+ if(cur_sample_rate!=self.sample_rate):
85
+ # print('a:',cur_sample_rate,chunk.shape)
86
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
87
+ # print('b:',self.sample_rate,chunk.shape)
88
+ if chunk.shape[-1] != self.n_samples:
89
+ raise ValueError(chunk.shape, self.n_samples, offset, int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
90
+ # if chunk.shape[-1] < self.n_samples:
91
+ # chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
92
+ # else:
93
+ # chunk = chunk[:,0:self.n_samples]
94
+ seconds_start = math.floor(offset / cur_sample_rate)
95
+ seconds_total = math.floor(duration)
96
+
97
+ # # In this dataset, we do not introduce zeros
98
+ # if(is_start):
99
+ # chunk = torch.cat([torch.zeros(1, self.shift_secs*self.sample_rate), chunk],1)[:,0:self.n_samples]
100
+ # elif(is_end):
101
+ # chunk = torch.cat([chunk, torch.zeros(1, self.shift_secs*self.sample_rate)],1)[:,self.shift_secs*self.sample_rate:]
102
+
103
+ return (
104
+ chunk,
105
+ t_start,
106
+ t_end,
107
+ seconds_start,
108
+ seconds_total,
109
+ is_start,
110
+ is_end,
111
+ )
112
+
113
+
114
+ USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替
115
+ if USE_DUMMY_AUDIO:
116
+ logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
117
+
118
+ class SafeAudioReader:
119
+ """
120
+ This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
121
+ """
122
+ def __init__(self,
123
+ duration: float, # 返回音频长度
124
+ sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample
125
+ randomize: bool = True
126
+ ):
127
+ self.n_samples = int(sample_rate * max(duration, 0))
128
+ self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
129
+
130
+ #NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数!
131
+ def __call__(self,
132
+ filepath: os.PathLike, # 音频路径
133
+ origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取
134
+ origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取
135
+ ) -> torch.Tensor:
136
+ if USE_DUMMY_AUDIO:
137
+ wav = torch.zeros(self.n_samples, dtype=torch.float32)
138
+ return wav
139
+ try:
140
+ # if origin_sample_rate is None or origin_duration is None:
141
+ # audio_info = torchaudio.info(filepath)
142
+ # origin_sample_rate = audio_info.sample_rate
143
+ # origin_duration = audio_info.num_frames / origin_sample_rate
144
+ audio_info = torchaudio.info(filepath)
145
+ origin_sample_rate = audio_info.sample_rate
146
+ origin_duration = audio_info.num_frames / origin_sample_rate
147
+ wav, *ignored, is_start, is_end = self.reader(filepath, origin_duration, origin_sample_rate)
148
+ except Exception as e:
149
+ logger.error(f"Error reading {filepath}: {e}")
150
+ raise FileNotFoundError(filepath)
151
+ return wav, is_start, is_end
152
+
153
+
154
+ class PromptTemplate:
155
+ def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
156
+ self.template_text = template_text
157
+ self.tag_map = tag_map
158
+ self.lang = lang
159
+
160
+ @property
161
+ def tags(self):
162
+ return tuple(self.tag_map.keys())
163
+
164
+ def apply(self, **kwargs):
165
+ for tag in list(kwargs.keys()):
166
+ if kwargs[tag] == '':
167
+ kwargs.pop(tag)
168
+ for tag in self.tags:
169
+ if tag in kwargs:
170
+ kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
171
+ else:
172
+ kwargs[tag] = ''
173
+ prompt = self.template_text.format(**kwargs)
174
+
175
+ return self.beautify(prompt)
176
+
177
+ def beautify(self, text):
178
+ if self.lang == 'en':
179
+ return self._beautify_en(text)
180
+ elif self.lang == 'zh':
181
+ return self._beautify_zh(text)
182
+ else:
183
+ raise ValueError(f'Unknown language {self.lang}')
184
+
185
+ @staticmethod
186
+ def _beautify_en(text):
187
+ # no continuous commas without content between them
188
+ text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
189
+ # no continuous whitespace
190
+ text = re.sub(r'\s+', ' ', text)
191
+ # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
192
+ text = re.sub(r'\s+,', r',', text)
193
+ text = re.sub(r',\s+', r', ', text)
194
+ # no whitespace before the full stop
195
+ text = re.sub(r'\s+\.', r'.', text)
196
+ # strip whitespace, comma, and replace ',.'
197
+ text = text.strip(' ,')
198
+ text = text.replace(',.', '.')
199
+ return text
200
+
201
+ @staticmethod
202
+ def _beautify_zh(text):
203
+ # no continuous commas without content between them
204
+ text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
205
+ text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
206
+ # assume there should be NO whitespace in Chinese
207
+ text = re.sub(r'\s+', r'', text)
208
+ # strip whitespace, comma, and replace ',。'
209
+ text = text.strip(', 、')
210
+ text = text.replace(',。', '。')
211
+ return text
212
+
213
+ def __repr__(self):
214
+ return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
215
+
216
+ __str__ = __repr__
217
+
218
+ def parse_prompt_template(prompt_template_text, lang='en'):
219
+ span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
220
+ tag_pattern = re.compile(r'{.+?}', re.DOTALL)
221
+
222
+ template_text = prompt_template_text.strip()
223
+ span_texts = span_pattern.findall(prompt_template_text)
224
+ tag_map = {}
225
+ for span_text in span_texts:
226
+ tag = tag_pattern.findall(span_text)[0].strip('{}')
227
+ tag_map[tag] = span_text
228
+ template_text = template_text.replace(span_text, '{'+tag+'}')
229
+
230
+ return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
231
+
232
+ def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
233
+ with open(path, 'r') as f:
234
+ lines = f.readlines()
235
+ cnt = 0
236
+ pts = []
237
+ for line in lines:
238
+ pt = parse_prompt_template(line, lang=lang)
239
+ cnt += 1
240
+ if len(pt.tags) < num:
241
+ logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
242
+ pts.append(pt)
243
+
244
+ return pts
245
+
246
+
247
+ def get_base_dir_file(key: os.PathLike):
248
+ base = os.path.basename(key)
249
+ dirname = os.path.basename(os.path.dirname(key))
250
+ return os.path.join(dirname, base)
251
+
252
+ def read_jsonlike(path: os.PathLike):
253
+ #json or jsonl
254
+ if str(path).endswith(".json"):
255
+ with open(path, 'r', encoding='utf8') as f:
256
+ data = json.load(f)
257
+ return data
258
+ elif str(path).endswith(".jsonl"):
259
+ with open(path, 'r', encoding='utf8') as f:
260
+ data = [json.loads(line) for line in f.readlines()]
261
+ return data
262
+ else:
263
+ raise ValueError("Unknown file format")
264
+
265
+ dist_prob_map = {
266
+ 1: (1.0,),
267
+ 2: (0.5, 0.5),
268
+ 3: (0.3, 0.4, 0.3),
269
+ 4: (0.2, 0.3, 0.3, 0.2),
270
+ 5: (0.2, 0.2, 0.3, 0.2, 0.1),
271
+ 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
272
+ 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
273
+ 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
274
+ 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
275
+ 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
276
+ }
277
+
278
+ dist_prob_map_low = {
279
+ 1: (1.0,),
280
+ 2: (0.8, 0.2),
281
+ 3: (0.8, 0.1, 0.1),
282
+ 4: (0.7, 0.1, 0.1, 0.1),
283
+ 5: (0.7, 0.1, 0.1, 0.05, 0.05),
284
+ 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
285
+ }
286
+
287
+ _bpm_range_rights = (
288
+ (40, '20-40'),
289
+ (60, '40-60'),
290
+ (66, '60-66'),
291
+ (76, '66-76'),
292
+ (108, '76-108'),
293
+ (120, '108-120'),
294
+ (168, '120-168'),
295
+ (176, '168-176'),
296
+ (200, '176-200')
297
+ )
298
+ _bpm_desc_map = {
299
+ '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
300
+ '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
301
+ '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
302
+ '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
303
+ '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
304
+ '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
305
+ '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
306
+ '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
307
+ '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
308
+ '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
309
+ }
310
+ _bpm_desc_map_zh = {
311
+ '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
312
+ '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
313
+ '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
314
+ '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
315
+ '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
316
+ '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
317
+ '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
318
+ '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
319
+ '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
320
+ '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
321
+ }
322
+ def get_bpm_range(bpm):
323
+ bpm = int(bpm)
324
+ for right, tag in _bpm_range_rights:
325
+ if bpm <= right:
326
+ return tag
327
+ return '>200'
328
+
329
+ def gen_bpm_descript(bpm, lang='en'):
330
+ bpm_range = get_bpm_range(bpm)
331
+ if lang == 'en':
332
+ return random.choice(_bpm_desc_map[bpm_range])
333
+ elif lang == 'zh':
334
+ return random.choice(_bpm_desc_map_zh[bpm_range])
335
+ else:
336
+ raise ValueError(f"Unknown language {lang}")
337
+
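A small usage sketch of the BPM helpers above, assuming `get_bpm_range` and `gen_bpm_descript` are in scope (e.g. this module has been imported or pasted into a session); the printed phrasings are random picks from the bucket tuples:

```python
import random

random.seed(0)
# get_bpm_range walks _bpm_range_rights and returns the first bucket whose right
# edge is >= the value, falling back to '>200'.
assert get_bpm_range(30) == '20-40'
assert get_bpm_range(128) == '120-168'
assert get_bpm_range(240) == '>200'

# gen_bpm_descript then samples one of the six phrasings for that bucket.
print(gen_bpm_descript(128, lang='en'))  # e.g. "brisk pace" or "Allegro"
print(gen_bpm_descript(128, lang='zh'))  # e.g. "轻快的速度"
```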
338
+ def read_translate(translate: Optional[Dict[str, os.PathLike]]):
339
+ if translate is None:
340
+ return None
341
+ if isinstance(translate, str):
342
+ return read_jsonlike(translate)
343
+ return {k: read_jsonlike(path) for k, path in translate.items()}
344
+
345
+
346
+ class MagnaTagATuneDataset(Dataset):
347
+ def __init__(self):
348
+ pass
349
+
350
+
351
+ def tags_to_desc(tag_list, sep=',') -> str:
352
+ if not isinstance(tag_list, Sequence):
353
+ return str(tag_list)
354
+ if isinstance(tag_list, str):
355
+ return tag_list
356
+ if len(tag_list) <= 0:
357
+ return ''
358
+ elif len(tag_list) <= 5:
359
+ probs = dist_prob_map[len(tag_list)]
360
+ tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
361
+ random.shuffle(tag_list)
362
+ tag_list = tag_list[:tags_num]
363
+ return sep.join(tag_list)
364
+ else:
365
+ probs = dist_prob_map[5]
366
+ tags_num = random.choices(range(1, 6), probs)[0]
367
+ random.shuffle(tag_list)
368
+ tag_list = tag_list[:tags_num]
369
+ return sep.join(tag_list)
370
+
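Roughly how the module-level `tags_to_desc` behaves, assuming it and `dist_prob_map` are in scope. Note that it shuffles the list it is given, so copies are passed here:

```python
import random

random.seed(42)
tags = ["rock", "guitar", "energetic", "live"]

# Keeps a random number of tags (weighted by dist_prob_map), shuffled, joined by sep.
print(tags_to_desc(list(tags)))            # e.g. "live,rock"
print(tags_to_desc(list(tags), sep='、'))  # separator used for lang='zh'
print(tags_to_desc("already a string"))    # strings are returned unchanged
```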
371
+ def get_sr_and_duration_info(item):
372
+ return item.get('sample_rate', None), item.get('duration', None)
373
+
374
+ class MtgJamendoDatasetFromJson(Dataset):
375
+ def __init__(self,
376
+ data_dir:str,
377
+ json_path:str,
378
+ duration:float=10,
379
+ sr:int = 0,
380
+ *,
381
+ lang = 'en',
382
+ return_path = False,
383
+ prompt_template_path: os.PathLike = None,
384
+ tag_types = [],
385
+ translate:Optional[Dict[str, os.PathLike]] = None,
386
+ ):
387
+ self.audio_reader = SafeAudioReader(duration, sr)
388
+
389
+ self.data_dir = data_dir
390
+ self._load_metadata_json(json_path)
391
+ self.sr = sr
392
+ self.duration = duration
393
+ self.return_path = return_path
394
+ self.lang = lang
395
+
396
+ self.use_dynamic_prompt = prompt_template_path is not None
397
+ if self.use_dynamic_prompt:
398
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
399
+ self.tag_types = tag_types
400
+
401
+ self.translate = read_translate(translate)
402
+ if not self.use_dynamic_prompt and self.lang != 'en':
403
+ raise NotImplementedError
404
+
405
+ # These tags are treated as weakly semantic; prompts containing only these tags are avoided
406
+ WEAK_TAG_LIST = ["title", "artist"]
407
+
408
+ def _load_metadata_json(self, json_path):
409
+ with open(json_path) as fp:
410
+ self.data = json.load(fp)
411
+
412
+ def convert_key_to_path(self, key):
413
+ return os.path.join(self.data_dir, get_base_dir_file(key))
414
+
415
+ def __len__(self):
416
+ return len(self.data)
417
+
418
+ def __getitem__(self, idx):
419
+ item = self.data[idx]
420
+ path = self.convert_key_to_path(item['key'])
421
+ description = self.generate_description(item)
422
+
423
+ sr, duration = get_sr_and_duration_info(item)
424
+ audio, is_start, is_end = self.audio_reader(path, sr, duration)
425
+
426
+ if self.return_path:
427
+ return audio, description, path
428
+ return audio, description, is_start, is_end
429
+
430
+ def tags_to_desc(self, tag_list, tag_type) -> str:
431
+ if self.lang == 'en':
432
+ return tags_to_desc(tag_list)
433
+ elif self.lang == 'zh':
434
+ translator = self.translate[tag_type]
435
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
436
+ return tags_to_desc(translated_tag_list, sep='、')
437
+
438
+ def generate_description(self, item):
439
+ if self.use_dynamic_prompt:
440
+ # dynamically generate prompt from given prompt template
441
+ prompt_template = random.choice(self.prompt_templates)
442
+ description = self.generate_description_dynamic(item, prompt_template)
443
+
444
+ else:
445
+ # use ordinary static prompt instead
446
+ description = self.generate_description_ordinary(item)
447
+ return description
448
+
449
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
450
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
451
+ exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
452
+ exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
453
+
454
+ if len(exists_strong_tag) > 0:
455
+ probs = dist_prob_map[len(exists_strong_tag)]
456
+ tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
457
+ random.shuffle(exists_strong_tag)
458
+ tags = exists_strong_tag[:tags_num]
459
+ weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
460
+ weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
461
+ random.shuffle(exists_weak_tag)
462
+ weak_tags = exists_weak_tag[:weak_tags_num]
463
+ tags += weak_tags
464
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
465
+ prompt = prompt_template.apply(**tags_args)
466
+ else:
467
+ # no strong tags, use all weak tags instead
468
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
469
+ prompt = prompt_template.apply(**tags_args)
470
+
471
+ return prompt
472
+
473
+ def generate_description_ordinary(self, data, thresh = 0.3):
474
+ # Initialize the description with title and artist
475
+ description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}'
476
+
477
+ # Add genre if available
478
+ if data["genre"] and random.random() > thresh:
479
+ genres = ', '.join(data["genre"])
480
+ description += f', belonging to the {genres} genres'
481
+
482
+ # Add moods if available
483
+ if data["moods"] and random.random() > thresh:
484
+ moods = ', '.join(data["moods"])
485
+ description += f'. This track conveys a {moods} mood'
486
+
487
+ # Add instruments if available
488
+ if data["instrument"] and random.random() > thresh:
489
+ instruments = ', '.join(data["instrument"])
490
+ description += f', and primarily features the following instruments: {instruments}'
491
+
492
+ # Add a period to end the description
493
+ description += '.'
494
+
495
+ return description
496
+
497
+ class AudioStockDataset(Dataset):
498
+ def __init__(self,
499
+ metadata_path:str,
500
+ duration:float=10,
501
+ sr:int = 0,
502
+ return_path = False,
503
+ return_audio = True,
504
+ prompt_template_path: os.PathLike = None,
505
+ tag_types = [],
506
+ lang = 'en',
507
+ translate:Optional[Dict[str, os.PathLike]] = None
508
+ ):
509
+ self.audio_reader = SafeAudioReader(duration, sr)
510
+
511
+ self.duration = duration
512
+ self._load_metadata(metadata_path)
513
+ self.sr = sr
514
+ self.return_path = return_path
515
+ self.return_audio = return_audio
516
+
517
+ self.use_dynamic_prompt = prompt_template_path is not None
518
+ if self.use_dynamic_prompt:
519
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
520
+ self.tag_types = tag_types
521
+
522
+ self.lang = lang
523
+ self.translate = read_translate(translate)
524
+
525
+ def _load_metadata(self, metadata_path):
526
+ with open(metadata_path) as fp:
527
+ lines = fp.readlines()
528
+ self.data = []
529
+ for line in lines:
530
+ item = json.loads(line)
531
+ if(item['duration']>self.duration+10):
532
+ self.data.append(item)
533
+ self.is_info_recorded = bool('Tags' in self.data[0])
534
+
535
+ def __len__(self):
536
+ return len(self.data)
537
+
538
+ def __getitem__(self, idx):
539
+ path:str = self.data[idx]["path"]
540
+ json_path = path[:path.rfind('.')] + ".json"
541
+ if self.is_info_recorded:
542
+ item = self.data[idx]
543
+ else:
544
+ try:
545
+ with open(json_path) as fp:
546
+ item:dict = json.load(fp)
547
+ except Exception as e:
548
+ print(f"Error loading json file {json_path} :\n{e}")
549
+ item = {}
550
+ description = self.generate_description(item)
551
+ if self.return_audio:
552
+ sr, duration = get_sr_and_duration_info(item)
553
+ audio, is_start, is_end = self.audio_reader(path, sr, duration)
554
+ else:
555
+ audio, is_start, is_end = None, False, False  # keep the flags defined when audio loading is skipped
556
+ if self.return_path:
557
+ return audio, description, path, is_start, is_end
558
+ else:
559
+ return audio, description, is_start, is_end
560
+
561
+ def generate_description(self, item):
562
+ if self.use_dynamic_prompt:
563
+ # dynamically generate prompt from given prompt template
564
+ prompt_template = random.choice(self.prompt_templates)
565
+ description = self.generate_description_dynamic(item, prompt_template)
566
+ else:
567
+ # use ordinary static prompt instead
568
+ description = self.generate_description_ordinary(item)
569
+ return description
570
+
571
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
572
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
573
+
574
+ if len(exists_tag) > 0:
575
+ probs = dist_prob_map[len(exists_tag)]
576
+ tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
577
+ random.shuffle(exists_tag)
578
+ tags = exists_tag[:tags_num]
579
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
580
+ tags_args = self.handle_BPM_tag(tags_args)
581
+ prompt = prompt_template.apply(**tags_args)
582
+ else:
583
+ # no strong tags, use all weak tags instead
584
+ prompt = prompt_template.apply()
585
+
586
+ return prompt
587
+
588
+ def tags_to_desc(self, tag_list, tag_type) -> str:
589
+ if self.lang == 'en':
590
+ return tags_to_desc(tag_list)
591
+ elif self.lang == 'zh':
592
+ if tag_type == 'BPM':
593
+ return tags_to_desc(tag_list, sep='、')
594
+ translator = self.translate[tag_type]
595
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
596
+ return tags_to_desc(translated_tag_list, sep='、')
597
+
598
+ def handle_BPM_tag(self, tags_args):
599
+ if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
600
+ bpm = tags_args["BPM"]
601
+ del tags_args["BPM"]
602
+ tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
603
+ for tag_type in tag_types_used:
604
+ tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
605
+ return tags_args
606
+
607
+ def generate_description_ordinary(self, data, thresh = 0.3):
608
+ if self.lang != 'en':
609
+ raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
610
+ description = f'a piece of music by {data["Artist"]}'
611
+
612
+ # Add genre if available
613
+ if data["Genre"] and random.random() > thresh:
614
+ genres = ', '.join(data["Genre"])
615
+ description += f', belonging to the {genres} genres'
616
+
617
+ # Add tags if available
618
+ if data["Tags"] and random.random() > thresh:
619
+ tags = ', '.join(data["Tags"])
620
+ description += f'. This track contains the tags: {tags}'
621
+
622
+ # Add moods if available
623
+ if data["Mood"] and random.random() > thresh:
624
+ moods = ', '.join(data["Mood"])
625
+ description += f'. This track conveys a {moods} mood'
626
+
627
+ # Add instruments if available
628
+ if data["Instrument"] and random.random() > thresh:
629
+ instruments = ', '.join(data["Instrument"])
630
+ description += f'. and primarily features the following instruments: {instruments}'
631
+
632
+ # Add a period to end the description
633
+ description += '.'
634
+
635
+ return description
636
+
637
+ def mp3_path_to_id(mp3_path):
638
+ return int(
639
+ mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
640
+ )
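`mp3_path_to_id` assumes paths of the form `<dir>/<numeric id>.mp3`; a quick check, assuming the helper above is in scope (a non-numeric stem would raise `ValueError`):

```python
# mp3_path_to_id takes the stem between the last '/' and '.mp3' and parses it as an int.
assert mp3_path_to_id('/data/tme/123456.mp3') == 123456
```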
641
+
642
+ class TmeDataset(Dataset):
643
+ def __init__(self,
644
+ data_index:str,
645
+ music_info:str = None,
646
+ duration:float = 10,
647
+ sr:int = 0,
648
+ return_path = False,
649
+ return_audio = True,
650
+ prompt_format_path: os.PathLike = None,
651
+ tag_types = ['*'],
652
+ lang = 'zh',
653
+ translate: Optional[os.PathLike] = None,
654
+ prompt_dir: os.PathLike = None,
655
+ ):
656
+ self.audio_reader = SafeAudioReader(duration, sr)
657
+
658
+ self.sr = sr
659
+ self.duration = duration
660
+ self.return_path = return_path
661
+ self.return_audio = return_audio
662
+ self.lang = lang
663
+
664
+ self.use_ready_prompt = prompt_dir is not None
665
+
666
+ data_index = read_jsonlike(data_index)
667
+ data_index = [d for d in data_index if d['duration']>self.duration+10]
668
+ self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
669
+ self.data_ids = list(self.data_index_dict.keys())
670
+
671
+ if not self.use_ready_prompt:
672
+ # Read the music info file
673
+ music_info = read_jsonlike(music_info)
674
+ if 'music' in music_info:
675
+ music_info = music_info['music']
676
+ self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
677
+ self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
678
+ self.data_ids = list(self.data_index_dict.keys())
679
+
680
+ with open(prompt_format_path) as fp:
681
+ self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
682
+
683
+ # Load the tag types and split them into ordinary tag_types and key key_tag_types
684
+ if '*' in tag_types:
685
+ self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
686
+ else:
687
+ self.tag_types = tag_types
688
+
689
+ self.key_tag_types = []
690
+ if 'tag' in self.tag_types:
691
+ self.tag_types.remove('tag')
692
+ self.key_tag_types = list(self.prompt_formats['tag'].keys())
693
+
694
+ # Load the translation table
695
+ if translate is not None:
696
+ self.translator = read_jsonlike(translate)
697
+ else:
698
+ data_ids_set = set(self.data_ids)
699
+ self.prompts_dict = {}
700
+ for fname in os.listdir(prompt_dir):
701
+ items = read_jsonlike(os.path.join(prompt_dir, fname))
702
+ for item in items:
703
+ if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
704
+ continue
705
+ if item['ID'] not in self.prompts_dict:
706
+ self.prompts_dict[item['ID']] = []
707
+ self.prompts_dict[item['ID']].append(item['Text'])
708
+ self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
709
+ self.data_ids = list(self.data_index_dict.keys())
710
+
711
+ def tags_to_desc(self, tag_list) -> str:
712
+ if is_bearable(tag_list, int):
713
+ return str(tag_list)
714
+ if self.lang == 'zh':
715
+ return tags_to_desc(tag_list, sep=self.sep)
716
+ else:
717
+ translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
718
+ return tags_to_desc(translated_tag_list, sep=self.sep)
719
+
720
+ def gen_desc_of_tag(self, formats, tags):
721
+ fmt = random.choice(formats)
722
+ return fmt.format(self.tags_to_desc(tags))
723
+
724
+ @staticmethod
725
+ def check_valid(value):
726
+ if isinstance(value, int) or isinstance(value, float):
727
+ return value > 0
728
+ if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
729
+ return True
730
+ return False
731
+
732
+ @staticmethod
733
+ def remove_repeat(data):
734
+ # If the album name equals the song name, keep only the latter
735
+ album_name = data.get('专辑名', None)
736
+ if album_name is not None and album_name == data.get('歌曲名', None):
737
+ del data['专辑名']
738
+ return data
739
+
740
+ @property
741
+ def comma(self):
742
+ if self.lang == 'zh':
743
+ return ','
744
+ elif self.lang == 'en':
745
+ return ', '
746
+
747
+ @property
748
+ def sep(self):
749
+ if self.lang == 'zh':
750
+ return '、'
751
+ elif self.lang == 'en':
752
+ return ', '
753
+
754
+ def generate_description(self, data):
755
+ data = self.remove_repeat(data)
756
+ weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] # weakly semantic tags; their appearance rate is reduced
757
+
758
+ key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] # key tags; at least one of these must appear
759
+
760
+ prompts = []
761
+ if len(weak_tags) > 0:
762
+ probs = dist_prob_map_low[len(weak_tags)]
763
+ if len(key_tags) > 0:
764
+ tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
765
+ else:
766
+ tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
767
+ random.shuffle(weak_tags)
768
+ tags = weak_tags[:tags_num]
769
+ for tag_type in tags:
770
+ tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
771
+ prompts.append(tag_desc)
772
+
773
+ if len(key_tags) > 0:
774
+ probs = dist_prob_map[len(key_tags)]
775
+ tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
776
+ random.shuffle(key_tags)
777
+ tags = key_tags[:tags_num]
778
+ for tag_type in tags:
779
+ tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
780
+ prompts.append(tag_desc)
781
+
782
+ random.shuffle(prompts)
783
+ return self.comma.join(prompts)
784
+
785
+ def is_valid_prompt_text(self, text):
786
+ for bad in ('抱歉','sorry', 'Sorry'):
787
+ if bad in text:
788
+ return False
789
+ return True
790
+
791
+ def get_ready_prompt(self, path):
792
+ sid = mp3_path_to_id(path)
793
+ return random.choice(self.prompts_dict[sid])
794
+
795
+ def __len__(self):
796
+ return len(self.data_ids)
797
+
798
+ def __getitem__(self, idx):
799
+ data_id = self.data_ids[idx]
800
+ item = self.data_index_dict[data_id]
801
+ path = item['path']
802
+ if not self.use_ready_prompt:
803
+ info = self.music_info_dict[data_id]
804
+ description = self.generate_description(info)
805
+ else:
806
+ description = self.get_ready_prompt(path)
807
+ if self.return_audio:
808
+ sr, duration = get_sr_and_duration_info(item)
809
+ audio, is_start, is_end = self.audio_reader(path, sr, duration)
810
+ else:
811
+ audio, is_start, is_end = None, False, False  # keep the flags defined when audio loading is skipped
812
+ if self.return_path:
813
+ return audio, description, path, is_start, is_end
814
+ else:
815
+ return audio, description, is_start, is_end
816
+
817
+ class Pond5Dataset(Dataset):
818
+ MAX_PROMPT_LEN = 200
819
+ def __init__(self,
820
+ metadata_path:str,
821
+ index_path:str,
822
+ duration:float=10,
823
+ sr:int = 0,
824
+ plain_rate = 0,
825
+ return_path = False,
826
+ return_audio = True,
827
+ lang = 'en',
828
+ translate:Optional[Dict[str, os.PathLike]] = None,
829
+ use_literal_none = True,
830
+ use_avoid_watermark_policy = None,
831
+ ):
832
+
833
+ if use_avoid_watermark_policy is None:
834
+ raise ValueError("`use_avoid_watermark_policy` is an important param; you must explicitly specify it as a bool")
835
+ self.use_avoid_watermark_policy = use_avoid_watermark_policy
836
+ assert self.use_avoid_watermark_policy is False
837
+ self.audio_reader = SafeAudioReader(duration, sr)
838
+
839
+ self.duration = duration
840
+ self._load_metadata(metadata_path, index_path)
841
+ self.sr = sr
842
+ self.plain_rate = plain_rate
843
+ self.return_path = return_path
844
+ self.return_audio = return_audio
845
+ self.use_literal_none = use_literal_none
846
+
847
+ self.lang = lang
848
+ self.translate = read_translate(translate)
849
+
850
+ def _load_metadata(self, metadata_path, index_path):
851
+ data_index = read_jsonlike(index_path)
852
+ data_ids = set([item['id'] for item in data_index])
853
+
854
+ with open(metadata_path) as fp:
855
+ lines = fp.readlines()
856
+
857
+ append_ids = set()
858
+
859
+ self.data = []
860
+ for line in lines:
861
+ item = json.loads(line)
862
+ if item['id'] in data_ids and item['id'] not in append_ids and item["details"]["duration"] is not None and item["details"]["duration"]>self.duration+10:
863
+ self.data.append(item)
864
+ append_ids.add(item['id'])
865
+
866
+ def __len__(self):
867
+ return len(self.data)
868
+
869
+ def __getitem__(self, idx):
870
+ item = self.data[idx]
871
+ path:str = item["path"]
872
+ description = self.generate_description(item)
873
+ if self.return_audio:
874
+ sr, duration = get_sr_and_duration_info(item)
875
+ audio, is_start, is_end = self.audio_reader(path, sr, duration)
876
+ else:
877
+ audio, is_start, is_end = None, False, False  # keep the flags defined when audio loading is skipped
878
+ if self.return_path:
879
+ return audio, description, path
880
+ return audio, description, is_start, is_end
881
+
882
+ @property
883
+ def keysep(self):
884
+ if self.lang == 'zh':
885
+ return ',' if random.random() > 0.5 else '、'
886
+ elif self.lang == 'en':
887
+ return ', '
888
+
889
+ def generate_description(self, item):
890
+ if random.random() > self.plain_rate:
891
+ # dynamically generate prompt from given prompt template
892
+ description = self.generate_description_dynamic(item)
893
+ else:
894
+ # use plain prompt, i.e. tags sequence separated by comma
895
+ description = self.generate_description_plain(item)
896
+ return description
897
+
898
+ def get_translation(self, k):
899
+ k = k.strip()
900
+ if k in self.translate:
901
+ return self.translate[k]
902
+ else:
903
+ return k
904
+
905
+ def generate_description_plain(self, item):
906
+ keywords = item['keywords']
907
+ if self.lang != 'en':
908
+ keywords = [self.get_translation(k) for k in keywords]
909
+ return gen_plain_prompt(keywords, sep=self.keysep)
910
+
911
+ def generate_description_dynamic(self,item):
912
+ desc = item.get('desc', 'none')
913
+ if desc is None:
914
+ desc = 'none'
915
+ desc = desc.strip()
916
+ if len(desc) > self.MAX_PROMPT_LEN:
917
+ shorter_desc = desc[:self.MAX_PROMPT_LEN]
918
+ # find last stop
919
+ stop_idx = shorter_desc.rfind('.')
920
+ if stop_idx == -1:
921
+ stop_idx = shorter_desc.rfind('!')
922
+ if stop_idx == -1:
923
+ stop_idx = shorter_desc.rfind(',')
924
+ if stop_idx == -1:
925
+ stop_idx = self.MAX_PROMPT_LEN - 1
926
+ desc = desc[:stop_idx+1]
927
+ return desc
928
+
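The truncation step above can be read in isolation: cut the description at `MAX_PROMPT_LEN`, then back up to the last sentence-ish boundary. A standalone sketch that mirrors it (the name `truncate_at_sentence` is illustrative, not part of this file):

```python
MAX_PROMPT_LEN = 200  # same cap as Pond5Dataset.MAX_PROMPT_LEN

def truncate_at_sentence(desc: str, max_len: int = MAX_PROMPT_LEN) -> str:
    """Mirror of the logic above: cut at the last '.', '!' or ',' before max_len."""
    if len(desc) <= max_len:
        return desc
    shorter = desc[:max_len]
    stop_idx = shorter.rfind('.')
    if stop_idx == -1:
        stop_idx = shorter.rfind('!')
    if stop_idx == -1:
        stop_idx = shorter.rfind(',')
    if stop_idx == -1:
        stop_idx = max_len - 1
    return desc[:stop_idx + 1]

long_desc = "An upbeat pop track with bright synths. " * 10
short = truncate_at_sentence(long_desc)
print(len(short), short[-1])  # <= 200, ends on '.'
```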
929
+ class CombinedDataset(Dataset):
930
+ @beartype
931
+ def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
932
+ self.datasets = datasets
933
+ self.datasets_index = []
934
+
935
+ for i,dataset in enumerate(datasets):
936
+ if dataset is None:
937
+ continue
938
+ for dup in range(ratios[i]):
939
+ for j in range(len(dataset)):
940
+ self.datasets_index.append((i,j))
941
+
942
+ def __len__(self):
943
+ return len(self.datasets_index)
944
+
945
+ def __getitem__(self, idx):
946
+ index = self.datasets_index[idx]
947
+ i,j = index
948
+ return self.datasets[i][j]
949
+
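How `CombinedDataset` oversamples via `ratios`, assuming the class above is in scope; `ListDataset` is a throwaway stub defined only for this illustration:

```python
from torch.utils.data import Dataset

class ListDataset(Dataset):
    """Tiny stand-in dataset used only for this illustration."""
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

a = ListDataset(['a0', 'a1'])
b = ListDataset(['b0', 'b1', 'b2'])

# ratios=[2, 1]: every index of `a` is listed twice, every index of `b` once,
# so `a` is oversampled relative to its size.
combined = CombinedDataset(datasets=[a, b], ratios=[2, 1])
print(len(combined))                                  # 2*2 + 1*3 = 7
print([combined[i] for i in range(len(combined))])    # a-items appear twice each
```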
950
+ class CombinedDataset_random(Dataset):
951
+ @beartype
952
+ def __init__(self,
953
+ num_examples:int,
954
+ datasets: Sequence[Dataset], ratios: Sequence[int]
955
+ ):
956
+ self.datasets = datasets
957
+ self.datasets_index = []
958
+
959
+ for i,dataset in enumerate(datasets):
960
+ if dataset is None:
961
+ continue
962
+ for dup in range(ratios[i]):
963
+ for j in range(len(dataset)):
964
+ self.datasets_index.append((i,j))
965
+ if num_examples > 0:
966
+ self.random_choose = True
967
+ self.dataset_len = num_examples
968
+ else:
969
+ self.random_choose = False
970
+ self.dataset_len = len(self.datasets_index)
971
+
972
+ def __len__(self):
973
+ return self.dataset_len
974
+
975
+ def __getitem__(self, idx):
976
+ first_try = True
977
+ try_cnt = 0
978
+ while True:
979
+ try:
980
+ if(self.random_choose or not first_try):
981
+ index2 = []
982
+ index2.append(np.random.randint(0,len(self.datasets)))
983
+ index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
984
+ else:
985
+ index2 = self.datasets_index[idx]
986
+ first_try = False
987
+ out = self.datasets[index2[0]][index2[1]]
988
+ if(len(out[0].shape)==1):out[0]=out[0][None,:]
989
+ return out
990
+ except:
991
+ print("Error loadding ", index2)
992
+ try_cnt += 1
993
+ if(try_cnt>10):
994
+ raise FileNotFoundError()
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py ADDED
@@ -0,0 +1,313 @@
1
+ import re
2
+ import sys
3
+ import json
4
+
5
+ from torch.utils.data import Dataset
6
+ import torchaudio
7
+ from torchaudio.functional import resample
8
+ import torch
9
+ import numpy as np
10
+
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+
14
+
15
+ def check_lryics(lyric):
16
+ _FILTER_STRING = [
17
+ '作词', '作曲', '编曲', '【', '策划',
18
+ '录音', '混音', '母带', ':', '制作',
19
+ '版权', '校对', '演奏', '制作', '伴奏'
20
+ ]
21
+ for item in _FILTER_STRING:
22
+ if item in lyric:
23
+ return True
24
+
25
+ return False
26
+
27
+
28
+
29
+ def process_lyrics(lines):
30
+ lyric_part = []
31
+ timestamp_part = []
32
+
33
+ timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
34
+
35
+ for i, line in enumerate(lines):
36
+
37
+ # Drop the credit/metadata lines among the first few lines
38
+ if i<10 and check_lryics(line):
39
+ continue
40
+
41
+ # Check that the line contains a valid timestamp and lyric content
42
+ if timestamp_pattern.match(line):
43
+ timestamp_end = line.rfind(']')
44
+ lyrics = line[timestamp_end + 1:].strip()
45
+ timestamps = line[:timestamp_end + 1]
46
+
47
+ if ':' in lyrics:
48
+ if len(lyrics.split(":")[0]) <=5:
49
+ lyrics = "".join(lyrics.split(":")[1:])
50
+ # if lyrics:  # make sure the lyric part is not empty
51
+ # lyric_part.append(lyrics)
52
+ # timestamp_part.append(timestamps)
53
+ # print(processed_lyrics)
54
+ return timestamp_part, lyric_part
55
+
56
+ def get_timestamps(timestamp_part):
57
+
58
+ # Convert to seconds
59
+
60
+ timestamps = []
61
+
62
+ for line in timestamp_part:
63
+ match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
64
+ if match:
65
+ minutes = int(match.group(1))
66
+ seconds = float(match.group(2))
67
+ millis = float(match.group(3)) if match.group(3) else 0
68
+ total_seconds = minutes * 60 + seconds + millis
69
+ timestamps.append(total_seconds)
70
+
71
+
72
+ return timestamps
73
+
74
+ def process_lyrics_lrc(lyrics):
75
+ timestamp_part, lyric_part = process_lyrics(lyrics)
76
+ # print(timestamp_part)
77
+ # print(lyric_part)
78
+ timestamps = get_timestamps(timestamp_part)
79
+ # print(timestamps)
80
+ if len(timestamps) == 0:
81
+ # print(f'{lyric_path}')
82
+ return []
83
+
84
+ slice_start = timestamps[0]
85
+ slice_start_idx = 0
86
+
87
+ output_list = []
88
+ for i in range(1, len(timestamps)):
89
+ # Split once the accumulated time exceeds 30 s; if the whole song is shorter than 30 s, the segment is dropped
90
+ if timestamps[i] - slice_start > 30:
91
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
92
+
93
+ slice_start = timestamps[i]
94
+ slice_start_idx = i
95
+
96
+ return output_list
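The 30-second slicing rule used by `process_lyrics_lrc` / `process_lyrics_yrc`, shown standalone on made-up timestamps; note that lines after the last cut (like `line5` here) are dropped, which matches the comment above:

```python
# Illustrative timestamps (seconds) and lyric lines.
timestamps = [0.0, 12.0, 25.0, 33.0, 47.0, 70.0]
lines = ['line0', 'line1', 'line2', 'line3', 'line4', 'line5']

output_list = []
slice_start, slice_start_idx = timestamps[0], 0
for i in range(1, len(timestamps)):
    if timestamps[i] - slice_start > 30:
        output_list.append(f'[{slice_start}:{timestamps[i]}]' + ", ".join(lines[slice_start_idx:i]))
        slice_start, slice_start_idx = timestamps[i], i

print(output_list)
# ['[0.0:33.0]line0, line1, line2', '[33.0:70.0]line3, line4']  -- 'line5' is dropped
```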
97
+
98
+
99
+
100
+ def process_lyrics_yrc(lyrics):
101
+
102
+ timestamps, lyric_part = extract_lrc(lyrics)
103
+
104
+ # timestamp_part, lyric_part = process_lyrics(lyrics)
105
+ # import pdb; pdb.set_trace()
106
+ # print(timestamp_part)
107
+ # print(lyric_part)
108
+ # timestamps = get_timestamps(timestamp_part)
109
+ # print(timestamps)
110
+ if len(timestamps) == 0:
111
+ # print(f'{lyric_path}')
112
+ return []
113
+
114
+ slice_start = timestamps[0]
115
+ slice_start_idx = 0
116
+
117
+ output_list = []
118
+ for i in range(1, len(timestamps)):
119
+ # Split once the accumulated time exceeds 30 s
120
+ if timestamps[i] - slice_start > 30:
121
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
122
+
123
+ slice_start = timestamps[i]
124
+ slice_start_idx = i
125
+ # import pdb; pdb.set_trace()
126
+ return output_list
127
+
128
+ def extract_lrc(lyrics):
129
+ timestamp_part, lyric_part = [], []
130
+
131
+ for i, text in enumerate(lyrics):
132
+ # Extract the content inside square brackets
133
+ bracket_content = re.search(r'\[(.*?)\]', text).group(1)
134
+ bracket_content = bracket_content.split(',')
135
+ # Extract the content inside parentheses
136
+ parentheses_content = re.findall(r'\((.*?)\)', text)
137
+ # Extract the remaining content
138
+ other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
139
+
140
+ # How should this data be handled?
141
+ # import pdb; pdb.set_trace()
142
+ if i<10 and check_lryics(other_content):
143
+ continue
144
+
145
+ # import pdb; pdb.set_trace()
146
+ timestamp_part.append(float(bracket_content[0])/1000)
147
+ lyric_part.append(other_content)
148
+ # import pdb; pdb.set_trace()
149
+ return timestamp_part, lyric_part
150
+
151
+
152
+
153
+ class WYYSongDataset(Dataset):
154
+ def __init__(self,
155
+ metadata_path:str,
156
+ sr:int = 0,
157
+ use_lang = ['en', 'zh-cn'],
158
+ num_examples = -1,
159
+ ):
160
+
161
+ self.sr = sr
162
+ self.use_lang = use_lang
163
+ self._load_metadata(metadata_path)
164
+
165
+ # buffer
166
+ self.lyric_buffer = {}
167
+
168
+ if(num_examples<=0):
169
+ self.dataset_len = len(self.data)
170
+ self.random_slc = False
171
+ else:
172
+ self.dataset_len = num_examples
173
+ self.random_slc = True
174
+
175
+ # Read the jsonl file
176
+ def _load_metadata(self, metadata_path):
177
+ with open(metadata_path) as fp:
178
+ lines = fp.readlines()
179
+ self.data = []
180
+ for line in lines:
181
+ item = json.loads(line)
182
+ # if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None:
183
+ if 'lyrics' in item and 'lang_info' in item:
184
+ if len(item['lyrics']) > 0:
185
+ for lang in self.use_lang:
186
+ if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
187
+ # if '伴奏' not in item['path'] and "cloud" in item['path']:
188
+ if '伴奏' not in item['path']:
189
+ self.data.append(item)
190
+
191
+
192
+ def __len__(self):
193
+ return self.dataset_len
194
+
195
+
196
+ def __getitem__(self, idx):
197
+ try_cnt = 0
198
+ while True:
199
+ if(self.random_slc):
200
+ idx = np.random.randint(0, len(self.data))
201
+ yrc_lyrics = []
202
+ lrc_lyrics = []
203
+ try:
204
+ info = self.data[idx]
205
+
206
+ # audio path
207
+ path:str = info["path"]
208
+
209
+ # Read the lyric segments
210
+ if 'lyrics' not in info:
211
+ if idx not in self.lyric_buffer:
212
+ # Word-level aligned lyrics
213
+ if info['yrc-lyric'] is not None:
214
+ with open(info['yrc-lyric']) as f_in:
215
+ yrc_lyric = json.load(f_in)
216
+ yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])
217
+
218
+ # Sentence-level aligned lyrics
219
+ if info['lrc-lyric'] is not None:
220
+ with open(info['lrc-lyric']) as f_in:
221
+ lrc_lyric = json.load(f_in)
222
+ lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])
223
+
224
+ # Prefer the word-level aligned lyrics
225
+ if len(yrc_lyrics) > 0:
226
+ lyrics = yrc_lyrics
227
+ else:
228
+ lyrics = lrc_lyrics
229
+ self.lyric_buffer[idx] = lyrics
230
+
231
+ # TODO: filter each lyric segment by length and drop songs that are too long or too short
232
+ else:
233
+ lyrics = self.lyric_buffer[idx]
234
+ else:
235
+ lyrics = info['lyrics']
236
+
237
+ # Randomly pick one lyric segment
238
+ ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
239
+ # ly_id = 0
240
+
241
+ lyric = lyrics[ly_id]
242
+
243
+
244
+
245
+ st, et, lyric = self.parse_lyric(lyric)
246
+
247
+ assert et - st < 40
248
+
249
+ # Text filtering
250
+
251
+ lyric = re.sub(r'【.*?】', '', lyric)
252
+ if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
253
+ assert 200 > len(lyric.replace(" ", "")) > 30
254
+ # strip a short "name:" prefix from the selected lyric segment
+ if ':' in lyric:
255
+ if len(lyric.split(":")[0]) <=5:
256
+ lyric = "".join(lyric.split(":")[1:])
257
+
258
+ if ':' in lyric:
259
+ if len(lyric.split(":")[0]) <=5:
260
+ lyric = "".join(lyric.split(":")[1:])
261
+
262
+ if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
263
+ assert 200 > len(lyric.split()) > 20
264
+
265
+ # strip a short "name:" prefix from the selected lyric segment
+ if ':' in lyric:
266
+ if len(lyric.split(":")[0].split()) <=3:
267
+ lyric = "".join(lyric.split(":")[1:])
268
+
269
+ if ':' in lyric:
270
+ if len(lyric.split(":")[0].split()) <=3:
271
+ lyric = "".join(lyric.split(":")[1:])
272
+
273
+
274
+
275
+ # Read the audio file
276
+ cur_sample_rate = torchaudio.info(path).sample_rate
277
+ offset = int(cur_sample_rate*st)
278
+ num_frames = int(cur_sample_rate * (et -st))
279
+ chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
280
+
281
+ # Randomly pick one channel
282
+ if(chunk.shape[0]>1):
283
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
284
+ else:
285
+ chunk = chunk[[0],:].float()
286
+
287
+ if(cur_sample_rate!=self.sr):
288
+ # print('a:',cur_sample_rate,chunk.shape)
289
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
290
+
291
+ return chunk, lyric, [st, et], path
292
+ except:
293
+ print("Error loadding ", info["path"])
294
+ try_cnt += 1
295
+ idx = np.random.randint(0, len(self.data))
296
+ if(try_cnt>10):
297
+ raise FileNotFoundError()
298
+
299
+ def parse_lyric(self, lyric):
300
+ pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
301
+ match = re.search(pattern, lyric)
302
+
303
+ start_time = float(match.group(1))
304
+ end_time = float(match.group(2))
305
+ content = match.group(3)
306
+ return start_time, end_time, content
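`parse_lyric` expects segments shaped like `[<start>:<end>]<text>` with float seconds on both sides (plain integers such as `[12:41]` would not match). A standalone check of the same regex:

```python
import re

# Same pattern as WYYSongDataset.parse_lyric.
pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
m = re.search(pattern, '[12.5:41.0]some lyric text')
start, end, text = float(m.group(1)), float(m.group(2)), m.group(3)
print(start, end, text)  # 12.5 41.0 some lyric text
```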
307
+
308
+ def collect_song(data_list):
309
+ audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
310
+ lyrics = [data[1] for data in data_list]
311
+ st_et = [data[2] for data in data_list]
312
+ paths = [data[3] for data in data_list]
313
+ return audios, lyrics, st_et
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py ADDED
@@ -0,0 +1,313 @@
1
+ import re
2
+ import sys
3
+ import json
4
+
5
+ from torch.utils.data import Dataset
6
+ import torchaudio
7
+ from torchaudio.functional import resample
8
+ import torch
9
+ import numpy as np
10
+
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+
14
+
15
+ def check_lryics(lyric):
16
+ _FILTER_STRING = [
17
+ '作词', '作曲', '编曲', '【', '策划',
18
+ '录音', '混音', '母带', ':', '制作',
19
+ '版权', '校对', '演奏', '制作', '伴奏'
20
+ ]
21
+ for item in _FILTER_STRING:
22
+ if item in lyric:
23
+ return True
24
+
25
+ return False
26
+
27
+
28
+
29
+ def process_lyrics(lines):
30
+ lyric_part = []
31
+ timestamp_part = []
32
+
33
+ timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
34
+
35
+ for i, line in enumerate(lines):
36
+
37
+ # Drop the credit/metadata lines among the first few lines
38
+ if i<10 and check_lryics(line):
39
+ continue
40
+
41
+ # Check that the line contains a valid timestamp and lyric content
42
+ if timestamp_pattern.match(line):
43
+ timestamp_end = line.rfind(']')
44
+ lyrics = line[timestamp_end + 1:].strip()
45
+ timestamps = line[:timestamp_end + 1]
46
+
47
+ if ':' in lyrics:
48
+ if len(lyrics.split(":")[0]) <=5:
49
+ lyrics = "".join(lyrics.split(":")[1:])
50
+ # if lyrics:  # make sure the lyric part is not empty
51
+ # lyric_part.append(lyrics)
52
+ # timestamp_part.append(timestamps)
53
+ # print(processed_lyrics)
54
+ return timestamp_part, lyric_part
55
+
56
+ def get_timestamps(timestamp_part):
57
+
58
+ # Convert to seconds
59
+
60
+ timestamps = []
61
+
62
+ for line in timestamp_part:
63
+ match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
64
+ if match:
65
+ minutes = int(match.group(1))
66
+ seconds = float(match.group(2))
67
+ millis = float(match.group(3)) if match.group(3) else 0
68
+ total_seconds = minutes * 60 + seconds + millis
69
+ timestamps.append(total_seconds)
70
+
71
+
72
+ return timestamps
73
+
74
+ def process_lyrics_lrc(lyrics):
75
+ timestamp_part, lyric_part = process_lyrics(lyrics)
76
+ # print(timestamp_part)
77
+ # print(lyric_part)
78
+ timestamps = get_timestamps(timestamp_part)
79
+ # print(timestamps)
80
+ if len(timestamps) == 0:
81
+ # print(f'{lyric_path}')
82
+ return []
83
+
84
+ slice_start = timestamps[0]
85
+ slice_start_idx = 0
86
+
87
+ output_list = []
88
+ for i in range(1, len(timestamps)):
89
+ # Split once the accumulated time exceeds 30 s; if the whole song is shorter than 30 s, the segment is dropped
90
+ if timestamps[i] - slice_start > 30:
91
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
92
+
93
+ slice_start = timestamps[i]
94
+ slice_start_idx = i
95
+
96
+ return output_list
97
+
98
+
99
+
100
+ def process_lyrics_yrc(lyrics):
101
+
102
+ timestamps, lyric_part = extract_lrc(lyrics)
103
+
104
+ # timestamp_part, lyric_part = process_lyrics(lyrics)
105
+ # import pdb; pdb.set_trace()
106
+ # print(timestamp_part)
107
+ # print(lyric_part)
108
+ # timestamps = get_timestamps(timestamp_part)
109
+ # print(timestamps)
110
+ if len(timestamps) == 0:
111
+ # print(f'{lyric_path}')
112
+ return []
113
+
114
+ slice_start = timestamps[0]
115
+ slice_start_idx = 0
116
+
117
+ output_list = []
118
+ for i in range(1, len(timestamps)):
119
+ # Split once the accumulated time exceeds 30 s
120
+ if timestamps[i] - slice_start > 30:
121
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
122
+
123
+ slice_start = timestamps[i]
124
+ slice_start_idx = i
125
+ # import pdb; pdb.set_trace()
126
+ return output_list
127
+
128
+ def extract_lrc(lyrics):
129
+ timestamp_part, lyric_part = [], []
130
+
131
+ for i, text in enumerate(lyrics):
132
+ # Extract the content inside square brackets
133
+ bracket_content = re.search(r'\[(.*?)\]', text).group(1)
134
+ bracket_content = bracket_content.split(',')
135
+ # Extract the content inside parentheses
136
+ parentheses_content = re.findall(r'\((.*?)\)', text)
137
+ # Extract the remaining content
138
+ other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
139
+
140
+ # How should this data be handled?
141
+ # import pdb; pdb.set_trace()
142
+ if i<10 and check_lryics(other_content):
143
+ continue
144
+
145
+ # import pdb; pdb.set_trace()
146
+ timestamp_part.append(float(bracket_content[0])/1000)
147
+ lyric_part.append(other_content)
148
+ # import pdb; pdb.set_trace()
149
+ return timestamp_part, lyric_part
150
+
151
+
152
+
153
+ class WYYSongDataset(Dataset):
154
+ def __init__(self,
155
+ metadata_path:str,
156
+ sr:int = 0,
157
+ use_lang = ['en', 'zh-cn'],
158
+ num_examples = -1,
159
+ ):
160
+
161
+ self.sr = sr
162
+ self.use_lang = use_lang
163
+ self._load_metadata(metadata_path)
164
+
165
+ # buffer
166
+ self.lyric_buffer = {}
167
+
168
+ if(num_examples<=0):
169
+ self.dataset_len = len(self.data)
170
+ self.random_slc = False
171
+ else:
172
+ self.dataset_len = num_examples
173
+ self.random_slc = True
174
+
175
+ # Read the jsonl file
176
+ def _load_metadata(self, metadata_path):
177
+ with open(metadata_path) as fp:
178
+ lines = fp.readlines()
179
+ self.data = []
180
+ for line in lines:
181
+ item = json.loads(line)
182
+ # if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None:
183
+ if 'lyrics' in item and 'lang_info' in item:
184
+ if len(item['lyrics']) > 0:
185
+ for lang in self.use_lang:
186
+ if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
187
+ # if '伴奏' not in item['path'] and "cloud" in item['path']:
188
+ if '伴奏' not in item['path']:
189
+ self.data.append(item)
190
+
191
+
192
+ def __len__(self):
193
+ return self.dataset_len
194
+
195
+
196
+ def __getitem__(self, idx):
197
+ try_cnt = 0
198
+ while True:
199
+ if(self.random_slc):
200
+ idx = np.random.randint(0, len(self.data))
201
+ yrc_lyrics = []
202
+ lrc_lyrics = []
203
+ try:
204
+ info = self.data[idx]
205
+
206
+ # audio path
207
+ path:str = info["path"]
208
+
209
+ # Read the lyric segments
210
+ if 'lyrics' not in info:
211
+ if idx not in self.lyric_buffer:
212
+ # Word-level aligned lyrics
213
+ if info['yrc-lyric'] is not None:
214
+ with open(info['yrc-lyric']) as f_in:
215
+ yrc_lyric = json.load(f_in)
216
+ yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])
217
+
218
+ # Sentence-level aligned lyrics
219
+ if info['lrc-lyric'] is not None:
220
+ with open(info['lrc-lyric']) as f_in:
221
+ lrc_lyric = json.load(f_in)
222
+ lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])
223
+
224
+ # Prefer the word-level aligned lyrics
225
+ if len(yrc_lyrics) > 0:
226
+ lyrics = yrc_lyrics
227
+ else:
228
+ lyrics = lrc_lyrics
229
+ self.lyric_buffer[idx] = lyrics
230
+
231
+ # TODO: filter each lyric segment by length and drop songs that are too long or too short
232
+ else:
233
+ lyrics = self.lyric_buffer[idx]
234
+ else:
235
+ lyrics = info['lyrics']
236
+
237
+ # Randomly pick one lyric segment
238
+ ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
239
+ # ly_id = 0
240
+
241
+ lyric = lyrics[ly_id]
242
+
243
+
244
+
245
+ st, et, lyric = self.parse_lyric(lyric)
246
+
247
+ assert et - st < 20
248
+
249
+ # Text filtering
250
+
251
+ lyric = re.sub(r'【.*?】', '', lyric)
252
+ if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
253
+ assert 100 > len(lyric.replace(" ", "")) > 5
254
+ # strip a short "name:" prefix from the selected lyric segment
+ if ':' in lyric:
255
+ if len(lyric.split(":")[0]) <=5:
256
+ lyric = "".join(lyric.split(":")[1:])
257
+
258
+ if ':' in lyric:
259
+ if len(lyric.split(":")[0]) <=5:
260
+ lyric = "".join(lyric.split(":")[1:])
261
+
262
+ if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
263
+ assert 100 > len(lyric.split()) > 5
264
+
265
+ # strip a short "name:" prefix from the selected lyric segment
+ if ':' in lyric:
266
+ if len(lyric.split(":")[0].split()) <=3:
267
+ lyric = "".join(lyric.split(":")[1:])
268
+
269
+ if ':' in lyric:
270
+ if len(lyric.split(":")[0].split()) <=3:
271
+ lyric = "".join(lyric.split(":")[1:])
272
+
273
+
274
+
275
+ # Read the audio file
276
+ cur_sample_rate = torchaudio.info(path).sample_rate
277
+ offset = int(cur_sample_rate*st)
278
+ num_frames = int(cur_sample_rate * (et -st))
279
+ chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
280
+
281
+ # Randomly pick one channel
282
+ if(chunk.shape[0]>1):
283
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
284
+ else:
285
+ chunk = chunk[[0],:].float()
286
+
287
+ if(cur_sample_rate!=self.sr):
288
+ # print('a:',cur_sample_rate,chunk.shape)
289
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
290
+
291
+ return chunk, lyric, [st, et], path
292
+ except:
293
+ print("Error loadding ", info["path"])
294
+ try_cnt += 1
295
+ idx = np.random.randint(0, len(self.data))
296
+ if(try_cnt>10):
297
+ raise FileNotFoundError()
298
+
299
+ def parse_lyric(self, lyric):
300
+ pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
301
+ match = re.search(pattern, lyric)
302
+
303
+ start_time = float(match.group(1))
304
+ end_time = float(match.group(2))
305
+ content = match.group(3)
306
+ return start_time, end_time, content
307
+
308
+ def collect_song(data_list):
309
+ audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
310
+ lyrics = [data[1] for data in data_list]
311
+ st_et = [data[2] for data in data_list]
312
+ paths = [data[3] for data in data_list]
313
+ return audios, lyrics, st_et
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py ADDED
@@ -0,0 +1,313 @@
1
+ import re
2
+ import sys
3
+ import json
4
+
5
+ from torch.utils.data import Dataset
6
+ import torchaudio
7
+ from torchaudio.functional import resample
8
+ import torch
9
+ import numpy as np
10
+
11
+ from torch.nn.utils.rnn import pad_sequence
12
+
13
+
14
+
15
+ def check_lryics(lyric):
16
+ _FILTER_STRING = [
17
+ '作词', '作曲', '编曲', '【', '策划',
18
+ '录音', '混音', '母带', ':', '制作',
19
+ '版权', '校对', '演奏', '制作', '伴奏'
20
+ ]
21
+ for item in _FILTER_STRING:
22
+ if item in lyric:
23
+ return True
24
+
25
+ return False
26
+
27
+
28
+
29
+ def process_lyrics(lines):
30
+ lyric_part = []
31
+ timestamp_part = []
32
+
33
+ timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
34
+
35
+ for i, line in enumerate(lines):
36
+
37
+ # Drop the credit/metadata lines among the first few lines
38
+ if i<10 and check_lryics(line):
39
+ continue
40
+
41
+ # Check that the line contains a valid timestamp and lyric content
42
+ if timestamp_pattern.match(line):
43
+ timestamp_end = line.rfind(']')
44
+ lyrics = line[timestamp_end + 1:].strip()
45
+ timestamps = line[:timestamp_end + 1]
46
+
47
+ if ':' in lyrics:
48
+ if len(lyrics.split(":")[0]) <=5:
49
+ lyrics = "".join(lyrics.split(":")[1:])
50
+ # if lyrics:  # make sure the lyric part is not empty
51
+ # lyric_part.append(lyrics)
52
+ # timestamp_part.append(timestamps)
53
+ # print(processed_lyrics)
54
+ return timestamp_part, lyric_part
55
+
56
+ def get_timestamps(timestamp_part):
57
+
58
+ # Convert to seconds
59
+
60
+ timestamps = []
61
+
62
+ for line in timestamp_part:
63
+ match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
64
+ if match:
65
+ minutes = int(match.group(1))
66
+ seconds = float(match.group(2))
67
+ millis = float(match.group(3)) if match.group(3) else 0
68
+ total_seconds = minutes * 60 + seconds + millis
69
+ timestamps.append(total_seconds)
70
+
71
+
72
+ return timestamps
73
+
74
+ def process_lyrics_lrc(lyrics):
75
+ timestamp_part, lyric_part = process_lyrics(lyrics)
76
+ # print(timestamp_part)
77
+ # print(lyric_part)
78
+ timestamps = get_timestamps(timestamp_part)
79
+ # print(timestamps)
80
+ if len(timestamps) == 0:
81
+ # print(f'{lyric_path}')
82
+ return []
83
+
84
+ slice_start = timestamps[0]
85
+ slice_start_idx = 0
86
+
87
+ output_list = []
88
+ for i in range(1, len(timestamps)):
89
+ # Split once the accumulated time exceeds 30 s; if the whole song is shorter than 30 s, the segment is dropped
90
+ if timestamps[i] - slice_start > 30:
91
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
92
+
93
+ slice_start = timestamps[i]
94
+ slice_start_idx = i
95
+
96
+ return output_list
97
+
98
+
99
+
100
+ def process_lyrics_yrc(lyrics):
101
+
102
+ timestamps, lyric_part = extract_lrc(lyrics)
103
+
104
+ # timestamp_part, lyric_part = process_lyrics(lyrics)
105
+ # import pdb; pdb.set_trace()
106
+ # print(timestamp_part)
107
+ # print(lyric_part)
108
+ # timestamps = get_timestamps(timestamp_part)
109
+ # print(timestamps)
110
+ if len(timestamps) == 0:
111
+ # print(f'{lyric_path}')
112
+ return []
113
+
114
+ slice_start = timestamps[0]
115
+ slice_start_idx = 0
116
+
117
+ output_list = []
118
+ for i in range(1, len(timestamps)):
119
+ # Split once the accumulated time exceeds 30 s
120
+ if timestamps[i] - slice_start > 30:
121
+ output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
122
+
123
+ slice_start = timestamps[i]
124
+ slice_start_idx = i
125
+ # import pdb; pdb.set_trace()
126
+ return output_list
127
+
128
+ def extract_lrc(lyrics):
129
+ timestamp_part, lyric_part = [], []
130
+
131
+ for i, text in enumerate(lyrics):
132
+ # Extract the content inside square brackets
133
+ bracket_content = re.search(r'\[(.*?)\]', text).group(1)
134
+ bracket_content = bracket_content.split(',')
135
+ # Extract the content inside parentheses
136
+ parentheses_content = re.findall(r'\((.*?)\)', text)
137
+ # Extract the remaining content
138
+ other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
139
+
140
+ # How should this data be handled?
141
+ if i<10 and check_lryics(other_content):
142
+ continue
143
+ timestamp_part.append(float(bracket_content[0])/1000)
144
+ lyric_part.append(other_content)
145
+ return timestamp_part, lyric_part
146
+
147
+
148
+
149
+ class WYYSongDataset(Dataset):
150
+ def __init__(self,
151
+ metadata_path:str,
152
+ sr:int = 0,
153
+ use_lang = ['en', 'zh-cn'],
154
+ num_examples = -1,
155
+ max_dur = 20,
156
+ pad_to_max= True,
157
+ ):
158
+
159
+ self.sr = sr
160
+ self.use_lang = use_lang
161
+ self._load_metadata(metadata_path)
162
+ self.max_dur = max_dur
163
+ self.pad_to_max = pad_to_max
164
+
165
+ # buffer
166
+ self.lyric_buffer = {}
167
+
168
+ if(num_examples<=0):
169
+ self.dataset_len = len(self.data)
170
+ self.random_slc = False
171
+ else:
172
+ self.dataset_len = num_examples
173
+ self.random_slc = True
174
+
175
+ # Read the jsonl file
176
+ def _load_metadata(self, metadata_path):
177
+ with open(metadata_path) as fp:
178
+ lines = fp.readlines()
179
+ self.data = []
180
+ for line in lines:
181
+ item = json.loads(line)
182
+ if '伴奏' not in item['path']:
183
+ # if "lang_type" in item and item['lang_type'] == 'en':
184
+ if "lang_type" in item:
185
+ self.data.append(item)
186
+
187
+
188
+ def __len__(self):
189
+ return self.dataset_len
190
+
191
+
192
+ def __getitem__(self, idx):
193
+ try_cnt = 0
194
+ while True:
195
+ if(self.random_slc):
196
+ idx = np.random.randint(0, len(self.data))
197
+ yrc_lyrics = []
198
+ lrc_lyrics = []
199
+ try:
200
+ info = self.data[idx]
201
+
202
+ # audio path
203
+ path = info["path"]
204
+ lang_type = info["lang_type"]
205
+ if info["lang_type"] == 'en':
206
+ lyrics = info['lyrics']
207
+ else:
208
+ lyrics = info['lyrics_phone']
209
+
210
+ # Randomly pick one lyric segment
211
+ ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
212
+ lyric = lyrics[ly_id].strip()
213
+
214
+ st, et, lyric = self.parse_lyric(lyric)
215
+ lyric = lyric.replace("\xa0", " ")
216
+
217
+ lyric = " ".join(lyric.split())
218
+
219
+ assert et - st < self.max_dur
220
+
221
+
222
+ if info["lang_type"] == 'en':
223
+ # print(len(lyric.split())/(et-st))
224
+ assert 6 > len(lyric.split())/(et-st) > 1
225
+ else:
226
+ # print(len(lyric.split())/(et-st))
227
+ lyric = lyric.replace("-", "")
228
+ assert 6 > len(lyric.split())/(et-st) > 1
229
+
230
+
231
+ # Read the audio file
232
+ cur_sample_rate = torchaudio.info(path).sample_rate
233
+ offset = int(cur_sample_rate*st)
234
+ num_frames = int(cur_sample_rate * (et -st))
235
+ chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
236
+ # chunk = torch.zeros(1, 48000*15)
237
+
238
+ # Randomly pick one channel
239
+ if(chunk.shape[0]>1):
240
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
241
+ else:
242
+ chunk = chunk[[0],:].float()
243
+
244
+ if(cur_sample_rate!=self.sr):
245
+ # print('a:',cur_sample_rate,chunk.shape)
246
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
247
+
248
+ if self.pad_to_max:
249
+ chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)
250
+
251
+ return chunk, lyric, et-st, path, lang_type
252
+ except:
253
+ # print("Error loadding ", info["path"])
254
+ try_cnt += 1
255
+ idx = np.random.randint(0, len(self.data))
256
+ if(try_cnt>20):
257
+ raise FileNotFoundError()
258
+
259
+ def parse_lyric(self, lyric):
260
+ pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
261
+ match = re.search(pattern, lyric)
262
+
263
+ start_time = float(match.group(1))
264
+ end_time = float(match.group(2))
265
+ content = match.group(3)
266
+ return start_time, end_time, content
267
+
268
+ def pad_2d_tensor(self, x, max_len, pad_id):
269
+ # Get the shape of the input tensor
270
+ batch_size, seq_len = x.size()
271
+ max_len = max(max_len, seq_len)
272
+ # Compute the required padding length
273
+ pad_len = max_len - seq_len
274
+
275
+ # If padding is needed
276
+ if pad_len > 0:
277
+ # Create the padding tensor
278
+ pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
279
+
280
+ # 沿第二个维度(列)连接输入 tensor 和填充 tensor
281
+ padded_tensor = torch.cat([x, pad_tensor], dim=1)
282
+ else:
283
+ # 如果不需要填充,直接返回输入 tensor
284
+ padded_tensor = x
285
+
286
+ return padded_tensor
287
+
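`pad_2d_tensor` right-pads a `(channels, samples)` tensor to `max_dur * sr` samples with `pad_id`; the equivalent operation, written standalone:

```python
import torch

x = torch.ones(1, 5)        # (channels, samples)
pad_id, max_len = 0, 8

# Right-pad along the time axis, as pad_2d_tensor does when max_len > seq_len.
pad = torch.full((x.size(0), max_len - x.size(1)), pad_id, dtype=x.dtype)
padded = torch.cat([x, pad], dim=1)
print(padded.shape)         # torch.Size([1, 8]); last three samples are zeros
```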
288
+ def collect_data(data_list):
289
+ audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
290
+ lyrics = [data[1] for data in data_list]
291
+ st_et = [data[2] for data in data_list]
292
+ paths = [data[3] for data in data_list]
293
+ lang_types = [data[4] for data in data_list]
294
+ return audios, lyrics, st_et, lang_types
295
+ # return audios, lyrics, st_et
296
+
297
+
298
+ def build_dataset():
299
+ train_dataset = WYYSongDataset(
300
+ metadata_path = "train.jsonl",
301
+ sr = 48000,
302
+ use_lang = ['zh-cn', 'en'],
303
+ num_examples = 10*10000
304
+ )
305
+
306
+ valid_dataset = WYYSongDataset(
307
+ metadata_path = "valid.jsonl",
308
+ sr = 48000,
309
+ use_lang = ['zh-cn', 'en'],
310
+ num_examples = 500
311
+ )
312
+
313
+ return train_dataset, valid_dataset
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py ADDED
@@ -0,0 +1,461 @@
1
+ from torch.utils.data import Dataset
2
+ from beartype.typing import Sequence, Callable, Optional, Dict, List
3
+ from beartype.door import is_bearable
4
+ import random
5
+ import os
6
+ from torchaudio.functional import resample
7
+ import torch
8
+ import typing as tp
9
+ from pathlib import Path
10
+ import torchaudio as ta
11
+ import torch.nn.functional as F
12
+ import soundfile
13
+ import numpy as np
14
+ import json
15
+ import yaml
16
+ import random
17
+ import librosa
18
+ from loguru import logger
19
+ import re
20
+
21
+
22
+ def _av_read(filepath, seek_time=0, duration=None):
23
+ if duration is not None:
24
+ sr = librosa.get_samplerate(filepath)
25
+ offset = seek_time
26
+ num_samples = int(duration * sr)
27
+ wav, _ = librosa.load(filepath, sr=sr, offset=offset, duration=duration)
28
+ else:
29
+ wav, sr = librosa.load(filepath, sr=None, offset=seek_time)
30
+
31
+ return wav, sr
32
+
33
+ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
34
+ duration: float = -1., pad: bool = True) -> tp.Tuple[torch.Tensor, int]:
35
+ """Read audio by picking the most appropriate backend tool based on the audio format.
36
+
37
+ Args:
38
+ filepath (str or Path): Path to audio file to read.
39
+ seek_time (float): Time at which to start reading in the file.
40
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
41
+ pad (bool): Pad output audio if not reaching expected duration.
42
+ Returns:
43
+ tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
44
+ """
45
+ fp = Path(filepath)
46
+ if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
47
+ # There is some bug with ffmpeg and reading flac
48
+ info = soundfile.info(filepath)
49
+ frames = -1 if duration <= 0 else int(duration * info.samplerate)
50
+ frame_offset = int(seek_time * info.samplerate)
51
+ wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
52
+ assert info.samplerate == sr, f"Mismatch of sample rates {info.samplerate} {sr}"
53
+ wav = torch.from_numpy(wav).t().contiguous()
54
+ if len(wav.shape) == 1:
55
+ wav = torch.unsqueeze(wav, 0)
56
+ elif (
57
+ fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
58
+ and duration <= 0 and seek_time == 0
59
+ ):
60
+ # Torchaudio is faster if we load an entire file at once.
61
+ wav, sr = librosa.load(fp, sr=None, mono=True)
62
+ else:
63
+ wav, sr = _av_read(filepath, seek_time, duration)
64
+ if pad and duration > 0:
65
+ expected_frames = int(duration * sr)
66
+ wav = F.pad(torch.tensor(wav), (0, expected_frames - wav.shape[-1]))
67
+ if not isinstance(wav, torch.Tensor):
68
+ wav = torch.tensor(wav)
69
+ return wav, sr
70
+
71
+ def random_seek_read(filepath, duration):
72
+ if duration > 0:
73
+ total_duration = librosa.get_duration(path=filepath)
74
+ acceptable_start = max(0, total_duration - duration)
75
+ wav, sr = audio_read(filepath, random.uniform(0, acceptable_start), duration, pad=True)
76
+ else:
77
+ wav, sr = audio_read(filepath, 0, -1, pad=False)
78
+ return wav, sr
79
+
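`random_seek_read` never reads past the end of the file: the start offset is drawn uniformly from `[0, total_duration - duration]`, and shorter files fall back to a zero start with padding inside `audio_read`. A standalone sketch of that window choice (the durations are illustrative):

```python
import random

total_duration = 200.0   # seconds in the source file (illustrative value)
duration = 30.0          # seconds requested per example

# If the file is shorter than `duration`, acceptable_start is 0 and audio_read pads.
acceptable_start = max(0, total_duration - duration)
seek_time = random.uniform(0, acceptable_start)
print(f"read {duration:.0f}s starting at {seek_time:.1f}s (start in [0, {acceptable_start:.0f}])")
```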
80
+ def safe_random_seek_read(filepath, duration, sample_rate):
81
+ try:
82
+ wav, sr = random_seek_read(filepath, duration)
83
+ if sr != sample_rate:
84
+ wav = resample(wav, sr, sample_rate)
85
+ sr = sample_rate
86
+ except Exception as e:
87
+ logger.error(f"Error reading {filepath}: {e}")
88
+ sr = sample_rate
89
+ wav = torch.zeros(int(sr * max(duration, 0)), dtype=torch.float32)
90
+ return wav, sr
91
+
92
+ def read_jsonlike(path: os.PathLike):
93
+ #json or jsonl
94
+ if str(path).endswith(".json"):
95
+ with open(path, 'r', encoding='utf8') as f:
96
+ data = json.load(f)
97
+ return data
98
+ elif str(path).endswith(".jsonl"):
99
+ with open(path, 'r', encoding='utf8') as f:
100
+ data = [json.loads(line) for line in f.readlines()]
101
+ return data
102
+ else:
103
+ raise ValueError("Unknown file format")
104
+
105
+ dist_prob_map = {
106
+ 1: (1.0,),
107
+ 2: (0.5, 0.5),
108
+ 3: (0.3, 0.4, 0.3),
109
+ 4: (0.2, 0.3, 0.3, 0.2),
110
+ 5: (0.2, 0.2, 0.3, 0.2, 0.1),
111
+ 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
112
+ 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
113
+ 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
114
+ 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
115
+ 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
116
+ }
117
+
118
+ dist_prob_map_low = {
119
+ 1: (1.0,),
120
+ 2: (0.8, 0.2),
121
+ 3: (0.8, 0.1, 0.1),
122
+ 4: (0.7, 0.1, 0.1, 0.1),
123
+ 5: (0.7, 0.1, 0.1, 0.05, 0.05),
124
+ 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
125
+ }
126
+
127
+
128
+ _bpm_range_rights = (
129
+ (40, '20-40'),
130
+ (60, '40-60'),
131
+ (66, '60-66'),
132
+ (76, '66-76'),
133
+ (108, '76-108'),
134
+ (120, '108-120'),
135
+ (168, '120-168'),
136
+ (176, '168-176'),
137
+ (200, '176-200')
138
+ )
139
+ _bpm_desc_map = {
140
+ '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
141
+ '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
142
+ '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
143
+ '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
144
+ '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
145
+ '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
146
+ '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
147
+ '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
148
+ '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
149
+ '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
150
+ }
151
+ _bpm_desc_map_zh = {
152
+ '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
153
+ '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
154
+ '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
155
+ '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
156
+ '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
157
+ '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
158
+ '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
159
+ '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
160
+ '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
161
+ '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
162
+ }
163
+ def get_bpm_range(bpm):
164
+ bpm = int(bpm)
165
+ for right, tag in _bpm_range_rights:
166
+ if bpm <= right:
167
+ return tag
168
+ return '>200'
169
+
170
+ def gen_bpm_descript(bpm, lang='en'):
171
+ bpm_range = get_bpm_range(bpm)
172
+ if lang == 'en':
173
+ return random.choice(_bpm_desc_map[bpm_range])
174
+ elif lang == 'zh':
175
+ return random.choice(_bpm_desc_map_zh[bpm_range])
176
+ else:
177
+ raise ValueError(f"Unknown language {lang}")
178
+
179
+ def read_translate(translate: Optional[Dict[str, os.PathLike]]):
180
+ if translate is None:
181
+ return None
182
+ return {k: read_jsonlike(path) for k, path in translate.items()}
183
+
184
+
185
+ def tags_to_desc(tag_list, sep=',') -> str:
186
+ if not isinstance(tag_list, Sequence):
187
+ return str(tag_list)
188
+ if isinstance(tag_list, str):
189
+ return tag_list
190
+ if len(tag_list) <= 0:
191
+ return ''
192
+ elif len(tag_list) <= 5:
193
+ probs = dist_prob_map[len(tag_list)]
194
+ tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
195
+ random.shuffle(tag_list)
196
+ tag_list = tag_list[:tags_num]
197
+ return sep.join(tag_list)
198
+ else:
199
+ probs = dist_prob_map[5]
200
+ tags_num = random.choices(range(1, 6), probs)[0]
201
+ random.shuffle(tag_list)
202
+ tag_list = tag_list[:tags_num]
203
+ return sep.join(tag_list)
204
+
205
+
206
+ class PromptTemplate:
207
+ def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
208
+ self.template_text = template_text
209
+ self.tag_map = tag_map
210
+ self.lang = lang
211
+
212
+ @property
213
+ def tags(self):
214
+ return tuple(self.tag_map.keys())
215
+
216
+ def apply(self, **kwargs):
217
+ for tag in list(kwargs.keys()):
218
+ if kwargs[tag] == '':
219
+ kwargs.pop(tag)
220
+ for tag in self.tags:
221
+ if tag in kwargs:
222
+ kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
223
+ else:
224
+ kwargs[tag] = ''
225
+ prompt = self.template_text.format(**kwargs)
226
+
227
+ return self.beautify(prompt)
228
+
229
+ def beautify(self, text):
230
+ if self.lang == 'en':
231
+ return self._beautify_en(text)
232
+ elif self.lang == 'zh':
233
+ return self._beautify_zh(text)
234
+ else:
235
+ raise ValueError(f'Unknown language {self.lang}')
236
+
237
+ @staticmethod
238
+ def _beautify_en(text):
239
+ # no continuous commas without content between them
240
+ text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
241
+ # no continuous whitespace
242
+ text = re.sub(r'\s+', ' ', text)
243
+ # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
244
+ text = re.sub(r'\s+,', r',', text)
245
+ text = re.sub(r',\s+', r', ', text)
246
+ # no whitespace before the full stop
247
+ text = re.sub(r'\s+\.', r'.', text)
248
+ # strip whitespace, comma, and replace ',.'
249
+ text = text.strip(' ,')
250
+ text = text.replace(',.', '.')
251
+ return text
252
+
253
+ @staticmethod
254
+ def _beautify_zh(text):
255
+ # no continuous commas without content between them
256
+ text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
257
+ text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
258
+ # assume there should be NO whitespace in Chinese
259
+ text = re.sub(r'\s+', r'', text)
260
+ # strip whitespace, comma, and replace ',。'
261
+ text = text.strip(', 、')
262
+ text = text.replace(',。', '。')
263
+ return text
264
+
265
+ def __repr__(self):
266
+ return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
267
+
268
+ __str__ = __repr__
269
+
270
+ def parse_prompt_template(prompt_template_text, lang='en'):
271
+ span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
272
+ tag_pattern = re.compile(r'{.+?}', re.DOTALL)
273
+
274
+ template_text = prompt_template_text.strip()
275
+ span_texts = span_pattern.findall(prompt_template_text)
276
+ tag_map = {}
277
+ for span_text in span_texts:
278
+ tag = tag_pattern.findall(span_text)[0].strip('{}')
279
+ tag_map[tag] = span_text
280
+ template_text = template_text.replace(span_text, '{'+tag+'}')
281
+
282
+ return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
283
+
284
+ def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
285
+ with open(path, 'r') as f:
286
+ lines = f.readlines()
287
+ cnt = 0
288
+ pts = []
289
+ for line in lines:
290
+ pt = parse_prompt_template(line, lang=lang)
291
+ cnt += 1
292
+ if len(pt.tags) < num:
293
+ logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
294
+ pts.append(pt)
295
+
296
+ return pts
297
+
298
+
299
+ class AudioStockDataset(Dataset):
300
+ def __init__(self,
301
+ num_examples:int,
302
+ metadata_path:str,
303
+ duration:float=60,
304
+ sr:int = 0,
305
+ return_path = False,
306
+ return_audio = True,
307
+ prompt_template_path: os.PathLike = None,
308
+ tag_types = [],
309
+ lang = 'en',
310
+ translate:Optional[Dict[str, os.PathLike]] = None
311
+ ):
312
+ self.duration = duration
313
+ self.MAX_DURATION = 360
314
+ self._load_metadata(metadata_path)
315
+ if num_examples > 0:
316
+ self.random_choose = True
317
+ self.dataset_len = num_examples
318
+ else:
319
+ self.random_choose = False
320
+ self.dataset_len = len(self.data)
321
+ self.sr = sr
322
+ self.return_path = return_path
323
+ self.return_audio = return_audio
324
+
325
+ self.use_dynamic_prompt = prompt_template_path is not None
326
+ if self.use_dynamic_prompt:
327
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
328
+ self.tag_types = tag_types
329
+
330
+ self.lang = lang
331
+ self.translate = read_translate(translate)
332
+
333
+ def _load_metadata(self, metadata_path):
334
+ total_len = 0; valid_len = 0
335
+ with open(metadata_path) as fp:
336
+ lines = fp.readlines()
337
+ self.data = []
338
+ for line in lines:
339
+ item = json.loads(line)
340
+ total_len += 1
341
+ if(item['duration']>self.duration and item['duration']<self.MAX_DURATION):
342
+ valid_len += 1
343
+ self.data.append(item)
344
+ print("Filtered metadata from {} to {} items".format(total_len, valid_len))
345
+ self.is_info_recorded = bool('Tags' in self.data[0])
346
+
347
+ def __len__(self):
348
+ return self.dataset_len
349
+
350
+ def __getitem__(self, idx):
351
+ first_try = True
352
+ try_cnt = 0
353
+ while True:
354
+ try:
355
+ if(self.random_choose or not first_try):
356
+ index2 = np.random.randint(0,len(self.data))
357
+ else:
358
+ index2 = idx
359
+ first_try = False
360
+ return self.getitem_main(index2)
361
+ except:
362
+ print("Error loading ", self.data[index2]["path"])
363
+ try_cnt += 1
364
+ if(try_cnt>10):
365
+ raise ValueError()
366
+
367
+ def getitem_main(self, idx):
368
+ path:str = self.data[idx]["path"]
369
+ json_path = path[:path.rfind('.')] + ".json"
370
+ if self.is_info_recorded:
371
+ item = self.data[idx]
372
+ else:
373
+ with open(json_path) as fp:
374
+ item:dict = json.load(fp)
375
+ description = self.generate_description(item)
376
+ if self.return_audio:
377
+ audio, sr = safe_random_seek_read(path, duration=self.duration, sample_rate=self.sr)
378
+ else:
379
+ audio = None
380
+ if self.return_path:
381
+ return audio, description, path
382
+ return audio, description
383
+
384
+
385
+
386
+ def generate_description(self, item):
387
+ if self.use_dynamic_prompt:
388
+ # dynamically generate prompt from given prompt template
389
+ prompt_template = random.choice(self.prompt_templates)
390
+ description = self.generate_description_dynamic(item, prompt_template)
391
+ else:
392
+ # use ordinary static prompt instead
393
+ description = self.generate_description_ordinary(item)
394
+ return description
395
+
396
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
397
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
398
+
399
+ if len(exists_tag) > 0:
400
+ probs = dist_prob_map[len(exists_tag)]
401
+ tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
402
+ random.shuffle(exists_tag)
403
+ tags = exists_tag[:tags_num]
404
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
405
+ tags_args = self.handle_BPM_tag(tags_args)
406
+ prompt = prompt_template.apply(**tags_args)
407
+ else:
408
+ # no strong tags, use all weak tags instead
409
+ prompt = prompt_template.apply()
410
+
411
+ return prompt
412
+
413
+ def tags_to_desc(self, tag_list, tag_type) -> str:
414
+ if self.lang == 'en':
415
+ return tags_to_desc(tag_list)
416
+ elif self.lang == 'zh':
417
+ if tag_type == 'BPM':
418
+ return tags_to_desc(tag_list, sep='、')
419
+ translator = self.translate[tag_type]
420
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
421
+ return tags_to_desc(translated_tag_list, sep='、')
422
+
423
+ def handle_BPM_tag(self, tags_args):
424
+ if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
425
+ bpm = tags_args["BPM"]
426
+ del tags_args["BPM"]
427
+ tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
428
+ for tag_type in tag_types_used:
429
+ tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
430
+ return tags_args
431
+
432
+ def generate_description_ordinary(self, data, thresh = 0.3):
433
+ if self.lang != 'en':
434
+ raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
435
+ description = f'a piece of music by {data["Artist"]}'
436
+
437
+ # Add genre if available
438
+ if data["Genre"] and random.random() > thresh:
439
+ genres = ', '.join(data["Genre"])
440
+ description += f', belonging to the {genres} genres'
441
+
442
+ # Add tags if available
443
+ if data["Tags"] and random.random() > thresh:
444
+ tags = ', '.join(data["Tags"])
445
+ description += f'. This track contains the tags: {tags}'
446
+
447
+ # Add moods if available
448
+ if data["Mood"] and random.random() > thresh:
449
+ moods = ', '.join(data["Mood"])
450
+ description += f'. This track conveys a {moods} mood.'
451
+
452
+ # Add instruments if available
453
+ if data["Instrument"] and random.random() > thresh:
454
+ instruments = ', '.join(data["Instrument"])
455
+ description += f'. It primarily features the following instruments: {instruments}'
456
+
457
+ # Add a period to end the description
458
+ description += '.'
459
+
460
+ return description
461
+
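A minimal usage sketch for the dataset above, assuming a hypothetical metadata JSONL and an English template file (neither path is part of this repo). Each template line holds optional bracketed spans around a {tag} placeholder, e.g. `A [{Genre} ]track[ featuring {Instrument}][, with a {Mood} mood][, around {BPM} BPM][ at a {BPMDescript} tempo].`; parse_prompt_template keeps a span only when its tag is filled in.

# Sketch only; the paths and tag names below are illustrative assumptions.
dataset = AudioStockDataset(
    num_examples=-1,                      # <= 0: iterate the filtered metadata once
    metadata_path="stock_meta.jsonl",     # JSONL lines with at least "path" and "duration"
    duration=60, sr=48000,
    prompt_template_path="templates_en.txt",
    tag_types=["Genre", "Mood", "Instrument", "BPM", "BPMDescript"],
    lang="en",
)
audio, description = dataset[0]           # (waveform tensor resampled to 48 kHz, text prompt)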
codeclm/tokenizer/Flow1dVAE/libs/fsq/fsq.py ADDED
@@ -0,0 +1,236 @@
1
+ """
2
+ Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
3
+ Code adapted from Jax version in Appendix A.1
4
+ """
5
+
6
+ from __future__ import annotations
7
+ from functools import wraps, partial
8
+ from contextlib import nullcontext
9
+ from typing import List, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import Module
14
+ from torch import Tensor, int32
15
+ from torch.amp import autocast
16
+
17
+ from einops import rearrange, pack, unpack
18
+
19
+ # helper functions
20
+
21
+ def exists(v):
22
+ return v is not None
23
+
24
+ def default(*args):
25
+ for arg in args:
26
+ if exists(arg):
27
+ return arg
28
+ return None
29
+
30
+ def maybe(fn):
31
+ @wraps(fn)
32
+ def inner(x, *args, **kwargs):
33
+ if not exists(x):
34
+ return x
35
+ return fn(x, *args, **kwargs)
36
+ return inner
37
+
38
+ def pack_one(t, pattern):
39
+ return pack([t], pattern)
40
+
41
+ def unpack_one(t, ps, pattern):
42
+ return unpack(t, ps, pattern)[0]
43
+
44
+ # tensor helpers
45
+
46
+ def round_ste(z: Tensor) -> Tensor:
47
+ """Round with straight through gradients."""
48
+ zhat = z.round()
49
+ return z + (zhat - z).detach()
50
+
51
+ # main class
52
+
53
+ class FSQ(Module):
54
+ def __init__(
55
+ self,
56
+ levels: List[int],
57
+ dim: int | None = None,
58
+ num_codebooks = 1,
59
+ keep_num_codebooks_dim: bool | None = None,
60
+ scale: float | None = None,
61
+ allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
62
+ channel_first: bool = False,
63
+ projection_has_bias: bool = True,
64
+ return_indices = True,
65
+ force_quantization_f32 = True
66
+ ):
67
+ super().__init__()
68
+ _levels = torch.tensor(levels, dtype=int32)
69
+ self.register_buffer("_levels", _levels, persistent = False)
70
+
71
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
72
+ self.register_buffer("_basis", _basis, persistent = False)
73
+
74
+ self.scale = scale
75
+
76
+ codebook_dim = len(levels)
77
+ self.codebook_dim = codebook_dim
78
+
79
+ effective_codebook_dim = codebook_dim * num_codebooks
80
+ self.num_codebooks = num_codebooks
81
+ self.effective_codebook_dim = effective_codebook_dim
82
+
83
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
84
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
85
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
86
+
87
+ self.dim = default(dim, len(_levels) * num_codebooks)
88
+
89
+ self.channel_first = channel_first
90
+
91
+ has_projections = self.dim != effective_codebook_dim
92
+ self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity()
93
+ self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity()
94
+
95
+ self.has_projections = has_projections
96
+
97
+ self.return_indices = return_indices
98
+ if return_indices:
99
+ self.codebook_size = self._levels.prod().item()
100
+ implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
101
+ self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
102
+
103
+ self.allowed_dtypes = allowed_dtypes
104
+ self.force_quantization_f32 = force_quantization_f32
105
+
106
+ def bound(self, z, eps: float = 1e-3):
107
+ """ Bound `z`, an array of shape (..., d). """
108
+ half_l = (self._levels - 1) * (1 + eps) / 2
109
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
110
+ shift = (offset / half_l).atanh()
111
+ return (z + shift).tanh() * half_l - offset
112
+
113
+ def quantize(self, z):
114
+ """ Quantizes z, returns quantized zhat, same shape as z. """
115
+ quantized = round_ste(self.bound(z))
116
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
117
+ return quantized / half_width
118
+
119
+ def _scale_and_shift(self, zhat_normalized):
120
+ half_width = self._levels // 2
121
+ return (zhat_normalized * half_width) + half_width
122
+
123
+ def _scale_and_shift_inverse(self, zhat):
124
+ half_width = self._levels // 2
125
+ return (zhat - half_width) / half_width
126
+
127
+ def _indices_to_codes(self, indices):
128
+ level_indices = self.indices_to_level_indices(indices)
129
+ codes = self._scale_and_shift_inverse(level_indices)
130
+ return codes
131
+
132
+ def codes_to_indices(self, zhat):
133
+ """ Converts a `code` to an index in the codebook. """
134
+ assert zhat.shape[-1] == self.codebook_dim
135
+ zhat = self._scale_and_shift(zhat)
136
+ return (zhat * self._basis).sum(dim=-1).to(int32)
137
+
138
+ def indices_to_level_indices(self, indices):
139
+ """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
140
+ indices = rearrange(indices, '... -> ... 1')
141
+ codes_non_centered = (indices // self._basis) % self._levels
142
+ return codes_non_centered
143
+
144
+ def indices_to_codes(self, indices):
145
+ """ Inverse of `codes_to_indices`. """
146
+ assert exists(indices)
147
+
148
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
149
+
150
+ codes = self._indices_to_codes(indices)
151
+
152
+ if self.keep_num_codebooks_dim:
153
+ codes = rearrange(codes, '... c d -> ... (c d)')
154
+
155
+ codes = self.project_out(codes)
156
+
157
+ if is_img_or_video or self.channel_first:
158
+ codes = rearrange(codes, 'b ... d -> b d ...')
159
+
160
+ return codes
161
+
162
+ def forward(self, z):
163
+ """
164
+ einstein notation
165
+ b - batch
166
+ n - sequence (or flattened spatial dimensions)
167
+ d - feature dimension
168
+ c - number of codebook dim
169
+ """
170
+
171
+ is_img_or_video = z.ndim >= 4
172
+ need_move_channel_last = is_img_or_video or self.channel_first
173
+
174
+ # standardize image or video into (batch, seq, dimension)
175
+
176
+ if need_move_channel_last:
177
+ z = rearrange(z, 'b d ... -> b ... d')
178
+ z, ps = pack_one(z, 'b * d')
179
+
180
+ assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
181
+
182
+ z = self.project_in(z)
183
+
184
+ z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
185
+
186
+ # whether to force quantization step to be full precision or not
187
+
188
+ force_f32 = self.force_quantization_f32
189
+ quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
190
+
191
+ with quantization_context():
192
+ orig_dtype = z.dtype
193
+
194
+ if force_f32 and orig_dtype not in self.allowed_dtypes:
195
+ z = z.float()
196
+
197
+ codes = self.quantize(z)
198
+
199
+ # returning indices could be optional
200
+
201
+ indices = None
202
+
203
+ if self.return_indices:
204
+ indices = self.codes_to_indices(codes)
205
+
206
+ codes = rearrange(codes, 'b n c d -> b n (c d)')
207
+
208
+ codes = codes.type(orig_dtype)
209
+
210
+ # project out
211
+
212
+ out = self.project_out(codes)
213
+
214
+ # reconstitute image or video dimensions
215
+
216
+ if need_move_channel_last:
217
+ out = unpack_one(out, ps, 'b * d')
218
+ out = rearrange(out, 'b ... d -> b d ...')
219
+
220
+ indices = maybe(unpack_one)(indices, ps, 'b * c')
221
+
222
+ if not self.keep_num_codebooks_dim and self.return_indices:
223
+ indices = maybe(rearrange)(indices, '... 1 -> ...')
224
+
225
+ # return quantized output and indices
226
+
227
+ return out, indices
228
+
229
+
230
+ if __name__ == '__main__':
231
+ # test
232
+ fsq = FSQ([4, 4, 4],dim=1024)
233
+ z = torch.randn(2, 3, 1024)
234
+ out, indices = fsq(z)
235
+ print(out.shape, indices.shape)
236
+ # print(out, indices)
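A small round-trip sketch for the quantizer above (levels and shapes are illustrative): indices returned by forward() can be mapped back to the same quantized vectors via indices_to_codes().

import torch

fsq = FSQ([8, 5, 5, 5], dim=512)              # codebook size = 8*5*5*5 = 1000 entries
x = torch.randn(2, 50, 512)                   # (batch, sequence, feature)
quantized, indices = fsq(x)                   # indices: (2, 50) integers in [0, 999]
recovered = fsq.indices_to_codes(indices)     # decode indices back through project_out
assert torch.allclose(recovered, quantized, atol=1e-5)   # should agree up to float precision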
codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py ADDED
@@ -0,0 +1,366 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+
34
+ import typing as tp
35
+
36
+ from einops import rearrange, repeat
37
+ import torch
38
+ from torch import nn
39
+ import torch.nn.functional as F
40
+
41
+ # from .. import distrib
42
+
43
+
44
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
45
+ return val if val is not None else d
46
+
47
+
48
+ def ema_inplace(moving_avg, new, decay: float):
49
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
50
+
51
+
52
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
53
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
54
+
55
+
56
+ def uniform_init(*shape: int):
57
+ t = torch.empty(shape)
58
+ nn.init.kaiming_uniform_(t)
59
+ return t
60
+
61
+
62
+ def sample_vectors(samples, num: int):
63
+ num_samples, device = samples.shape[0], samples.device
64
+
65
+ if num_samples >= num:
66
+ indices = torch.randperm(num_samples, device=device)[:num]
67
+ else:
68
+ indices = torch.randint(0, num_samples, (num,), device=device)
69
+
70
+ return samples[indices]
71
+
72
+
73
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
74
+ dim, dtype = samples.shape[-1], samples.dtype
75
+
76
+ means = sample_vectors(samples, num_clusters)
77
+
78
+ for _ in range(num_iters):
79
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
80
+ means, "c d -> () c d"
81
+ )
82
+ dists = -(diffs ** 2).sum(dim=-1)
83
+
84
+ buckets = dists.max(dim=-1).indices
85
+ bins = torch.bincount(buckets, minlength=num_clusters)
86
+ zero_mask = bins == 0
87
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
88
+
89
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
90
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
91
+ new_means = new_means / bins_min_clamped[..., None]
92
+
93
+ means = torch.where(zero_mask[..., None], means, new_means)
94
+
95
+ return means, bins
96
+
97
+
98
+ class EuclideanCodebook(nn.Module):
99
+ """Codebook with Euclidean distance.
100
+ Args:
101
+ dim (int): Dimension.
102
+ codebook_size (int): Codebook size.
103
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
104
+ If set to true, run the k-means algorithm on the first training batch and use
105
+ the learned centroids as initialization.
106
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
107
+ decay (float): Decay for exponential moving average over the codebooks.
108
+ epsilon (float): Epsilon value for numerical stability.
109
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
110
+ that have an exponential moving average cluster size less than the specified threshold with
111
+ randomly selected vector from the current batch.
112
+ """
113
+ def __init__(
114
+ self,
115
+ dim: int,
116
+ codebook_size: int,
117
+ kmeans_init: int = False,
118
+ kmeans_iters: int = 10,
119
+ decay: float = 0.99,
120
+ epsilon: float = 1e-5,
121
+ threshold_ema_dead_code: int = 2,
122
+ ):
123
+ super().__init__()
124
+ self.decay = decay
125
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
126
+ embed = init_fn(codebook_size, dim)
127
+
128
+ self.codebook_size = codebook_size
129
+
130
+ self.kmeans_iters = kmeans_iters
131
+ self.epsilon = epsilon
132
+ self.threshold_ema_dead_code = threshold_ema_dead_code
133
+
134
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
135
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
136
+ self.register_buffer("embed", embed)
137
+ self.register_buffer("embed_avg", embed.clone())
138
+
139
+ self.runed_steps = 0
140
+ self.stop_steps = 50_000
141
+
142
+ @torch.jit.ignore
143
+ def init_embed_(self, data):
144
+ if self.inited:
145
+ return
146
+
147
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
148
+ self.embed.data.copy_(embed)
149
+ self.embed_avg.data.copy_(embed.clone())
150
+ self.cluster_size.data.copy_(cluster_size)
151
+ self.inited.data.copy_(torch.Tensor([True]))
152
+ # Make sure all buffers across workers are in sync after initialization
153
+ # distrib.broadcast_tensors(self.buffers())  # the `distrib` import is commented out above
154
+
155
+ def replace_(self, samples, mask):
156
+ modified_codebook = torch.where(
157
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
158
+ )
159
+ self.embed.data.copy_(modified_codebook)
160
+
161
+ def expire_codes_(self, batch_samples):
162
+ if self.threshold_ema_dead_code == 0:
163
+ return
164
+
165
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
166
+ if not torch.any(expired_codes):
167
+ return
168
+
169
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
170
+ self.replace_(batch_samples, mask=expired_codes)
171
+ # distrib.broadcast_tensors(self.buffers())
172
+
173
+ def preprocess(self, x):
174
+ x = rearrange(x, "... d -> (...) d")
175
+ return x
176
+
177
+ def quantize(self, x):
178
+ embed = self.embed.t()
179
+ dist = -(
180
+ x.pow(2).sum(1, keepdim=True)
181
+ - 2 * x @ embed
182
+ + embed.pow(2).sum(0, keepdim=True)
183
+ )
184
+ embed_ind = dist.max(dim=-1).indices
185
+ return embed_ind
186
+
187
+ def postprocess_emb(self, embed_ind, shape):
188
+ return embed_ind.view(*shape[:-1])
189
+
190
+ def dequantize(self, embed_ind):
191
+ quantize = F.embedding(embed_ind, self.embed)
192
+ return quantize
193
+
194
+ def encode(self, x):
195
+ shape = x.shape
196
+ # pre-process
197
+ x = self.preprocess(x)
198
+ # quantize
199
+ embed_ind = self.quantize(x)
200
+ # post-process
201
+ embed_ind = self.postprocess_emb(embed_ind, shape)
202
+ return embed_ind
203
+
204
+ def decode(self, embed_ind):
205
+ quantize = self.dequantize(embed_ind)
206
+ return quantize
207
+
208
+ def forward(self, x):
209
+ shape, dtype = x.shape, x.dtype
210
+ x = self.preprocess(x)
211
+ # self.init_embed_(x)
212
+
213
+ embed_ind = self.quantize(x)
214
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
215
+ embed_ind = self.postprocess_emb(embed_ind, shape)
216
+ quantize = self.dequantize(embed_ind)
217
+ self.runed_steps += 1
218
+
219
+ if self.training and self.runed_steps < self.stop_steps:
220
+ # We do the expiry of code at that point as buffers are in sync
221
+ # and all the workers will take the same decision.
222
+ self.expire_codes_(x)
223
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
224
+ embed_sum = x.t() @ embed_onehot
225
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
226
+ cluster_size = (
227
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
228
+ * self.cluster_size.sum()
229
+ )
230
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
231
+ self.embed.data.copy_(embed_normalized)
232
+
233
+ return quantize, embed_ind
234
+
235
+
236
+ class VectorQuantization(nn.Module):
237
+ """Vector quantization implementation.
238
+ Currently supports only euclidean distance.
239
+ Args:
240
+ dim (int): Dimension
241
+ codebook_size (int): Codebook size
242
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
243
+ decay (float): Decay for exponential moving average over the codebooks.
244
+ epsilon (float): Epsilon value for numerical stability.
245
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
246
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
247
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
248
+ that have an exponential moving average cluster size less than the specified threshold with
249
+ randomly selected vector from the current batch.
250
+ commitment_weight (float): Weight for commitment loss.
251
+ """
252
+ def __init__(
253
+ self,
254
+ dim: int,
255
+ codebook_size: int,
256
+ codebook_dim: tp.Optional[int] = None,
257
+ decay: float = 0.99,
258
+ epsilon: float = 1e-5,
259
+ kmeans_init: bool = True,
260
+ kmeans_iters: int = 50,
261
+ threshold_ema_dead_code: int = 2,
262
+ commitment_weight: float = 1.,
263
+ ):
264
+ super().__init__()
265
+ _codebook_dim: int = default(codebook_dim, dim)
266
+
267
+ requires_projection = _codebook_dim != dim
268
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
269
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
270
+
271
+ self.epsilon = epsilon
272
+ self.commitment_weight = commitment_weight
273
+
274
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
275
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
276
+ decay=decay, epsilon=epsilon,
277
+ threshold_ema_dead_code=threshold_ema_dead_code)
278
+ self.codebook_size = codebook_size
279
+
280
+ @property
281
+ def codebook(self):
282
+ return self._codebook.embed
283
+
284
+ def encode(self, x):
285
+ x = rearrange(x, "b d n -> b n d")
286
+ x = self.project_in(x)
287
+ embed_in = self._codebook.encode(x)
288
+ return embed_in
289
+
290
+ def decode(self, embed_ind):
291
+ quantize = self._codebook.decode(embed_ind)
292
+ quantize = self.project_out(quantize)
293
+ quantize = rearrange(quantize, "b n d -> b d n")
294
+ return quantize
295
+
296
+ def forward(self, x, do_debug=False):
297
+ device = x.device
298
+ x = rearrange(x, "b d n -> b n d")
299
+ x = self.project_in(x)
300
+
301
+ quantize, embed_ind = self._codebook(x)
302
+
303
+ if self.training:
304
+ quantize = x + (quantize - x).detach()
305
+
306
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
307
+
308
+ if self.training:
309
+ if self.commitment_weight > 0:
310
+ commit_loss = F.mse_loss(quantize.detach(), x)
311
+ loss = loss + commit_loss * self.commitment_weight
312
+ quantize = self.project_out(quantize)
313
+ quantize = rearrange(quantize, "b n d -> b d n")
314
+ return quantize, embed_ind, loss
315
+
316
+
317
+ class ResidualVectorQuantization(nn.Module):
318
+ """Residual vector quantization implementation.
319
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
320
+ """
321
+ def __init__(self, *, num_quantizers, **kwargs):
322
+ super().__init__()
323
+ self.layers = nn.ModuleList(
324
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
325
+ )
326
+
327
+ def forward(self, x, n_q: tp.Optional[int] = None):
328
+ quantized_out = 0.0
329
+ residual = x
330
+
331
+ all_losses = []
332
+ all_indices = []
333
+
334
+ n_q = n_q or len(self.layers)
335
+
336
+ for layerinx, layer in enumerate(self.layers[:n_q]):
337
+ print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.))
338
+ quantized, indices, loss = layer(residual)
339
+ residual = residual - quantized
340
+ quantized_out = quantized_out + quantized
341
+
342
+ all_indices.append(indices)
343
+ all_losses.append(loss)
344
+
345
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
346
+ return quantized_out, out_indices, out_losses
347
+
348
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
349
+ residual = x
350
+ all_indices = []
351
+ n_q = n_q or len(self.layers)
352
+ for layer in self.layers[:n_q]:
353
+ indices = layer.encode(residual)
354
+ quantized = layer.decode(indices)
355
+ residual = residual - quantized
356
+ all_indices.append(indices)
357
+ out_indices = torch.stack(all_indices)
358
+ return out_indices
359
+
360
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
361
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
362
+ for i, indices in enumerate(q_indices):
363
+ layer = self.layers[i]
364
+ quantized = layer.decode(indices)
365
+ quantized_out = quantized_out + quantized
366
+ return quantized_out
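A brief usage sketch for the residual quantizer above (sizes are illustrative): encode() stacks one index map per layer, and decode() sums the per-layer embeddings back into a latent. kmeans_init is turned off here because the k-means initialization call is disabled in the codebook's forward pass above.

import torch

rvq = ResidualVectorQuantization(num_quantizers=4, dim=256, codebook_size=1024,
                                 kmeans_init=False)
x = torch.randn(2, 256, 50)      # (batch, dim, time) latent
codes = rvq.encode(x)            # (n_q, batch, time) integer indices
x_hat = rvq.decode(codes)        # (batch, dim, time) approximate reconstruction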
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize.py ADDED
@@ -0,0 +1,268 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ def WNConv1d(*args, **kwargs):
11
+ return weight_norm(nn.Conv1d(*args, **kwargs))
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = (
65
+ z_e + (z_q - z_e).detach()
66
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
67
+
68
+ z_q = self.out_proj(z_q)
69
+
70
+ return z_q, commitment_loss, codebook_loss, indices, z_e
71
+
72
+ def embed_code(self, embed_id):
73
+ return F.embedding(embed_id, self.codebook.weight)
74
+
75
+ def decode_code(self, embed_id):
76
+ return self.embed_code(embed_id).transpose(1, 2)
77
+
78
+ def decode_latents(self, latents):
79
+ encodings = rearrange(latents, "b d t -> (b t) d")
80
+ codebook = self.codebook.weight # codebook: (N x D)
81
+
82
+ # L2 normalize encodings and codebook (ViT-VQGAN)
83
+ encodings = F.normalize(encodings)
84
+ codebook = F.normalize(codebook)
85
+
86
+ # Compute euclidean distance with codebook
87
+ dist = (
88
+ encodings.pow(2).sum(1, keepdim=True)
89
+ - 2 * encodings @ codebook.t()
90
+ + codebook.pow(2).sum(1, keepdim=True).t()
91
+ )
92
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
93
+ z_q = self.decode_code(indices)
94
+ return z_q, indices
95
+
96
+
97
+ class ResidualVectorQuantize(nn.Module):
98
+ """
99
+ Introduced in SoundStream: An end2end neural audio codec
100
+ https://arxiv.org/abs/2107.03312
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ input_dim: int = 512,
106
+ n_codebooks: int = 9,
107
+ codebook_size: int = 1024,
108
+ codebook_dim: Union[int, list] = 8,
109
+ quantizer_dropout: float = 0.0,
110
+ ):
111
+ super().__init__()
112
+ if isinstance(codebook_dim, int):
113
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114
+
115
+ self.n_codebooks = n_codebooks
116
+ self.codebook_dim = codebook_dim
117
+ self.codebook_size = codebook_size
118
+
119
+ self.quantizers = nn.ModuleList(
120
+ [
121
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122
+ for i in range(n_codebooks)
123
+ ]
124
+ )
125
+ self.quantizer_dropout = quantizer_dropout
126
+
127
+ def forward(self, z, n_quantizers: int = None):
128
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
129
+ the corresponding codebook vectors
130
+ Parameters
131
+ ----------
132
+ z : Tensor[B x D x T]
133
+ n_quantizers : int, optional
134
+ No. of quantizers to use
135
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
137
+ when in training mode, and a random number of quantizers is used.
138
+ Returns
139
+ -------
140
+ dict
141
+ A dictionary with the following keys:
142
+
143
+ "z" : Tensor[B x D x T]
144
+ Quantized continuous representation of input
145
+ "codes" : Tensor[B x N x T]
146
+ Codebook indices for each codebook
147
+ (quantized discrete representation of input)
148
+ "latents" : Tensor[B x N*D x T]
149
+ Projected latents (continuous representation of input before quantization)
150
+ "vq/commitment_loss" : Tensor[1]
151
+ Commitment loss to train encoder to predict vectors closer to codebook
152
+ entries
153
+ "vq/codebook_loss" : Tensor[1]
154
+ Codebook loss to update the codebook
155
+ """
156
+ z_q = 0
157
+ residual = z
158
+ commitment_loss = 0
159
+ codebook_loss = 0
160
+
161
+ codebook_indices = []
162
+ latents = []
163
+
164
+ if n_quantizers is None:
165
+ n_quantizers = self.n_codebooks
166
+ if self.training:
167
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
170
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
171
+ n_quantizers = n_quantizers.to(z.device)
172
+
173
+ for i, quantizer in enumerate(self.quantizers):
174
+ if self.training is False and i >= n_quantizers:
175
+ break
176
+
177
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178
+ residual
179
+ )
180
+
181
+ # Create mask to apply quantizer dropout
182
+ mask = (
183
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184
+ )
185
+ z_q = z_q + z_q_i * mask[:, None, None]
186
+ residual = residual - z_q_i
187
+
188
+ # Sum losses
189
+ commitment_loss += (commitment_loss_i * mask).mean()
190
+ codebook_loss += (codebook_loss_i * mask).mean()
191
+
192
+ codebook_indices.append(indices_i)
193
+ latents.append(z_e_i)
194
+
195
+ codes = torch.stack(codebook_indices, dim=1)
196
+ latents = torch.cat(latents, dim=1)
197
+
198
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
199
+ for n in range(encodings.shape[1]):
200
+ print("Layer {}, Ratio of unused vectors: {:.1f}".format(n,
201
+ (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
202
+ ))
203
+
204
+ return z_q, codes, latents, commitment_loss, codebook_loss
205
+
206
+ def from_codes(self, codes: torch.Tensor):
207
+ """Given the quantized codes, reconstruct the continuous representation
208
+ Parameters
209
+ ----------
210
+ codes : Tensor[B x N x T]
211
+ Quantized discrete representation of input
212
+ Returns
213
+ -------
214
+ Tensor[B x D x T]
215
+ Quantized continuous representation of input
216
+ """
217
+ z_q = 0.0
218
+ z_p = []
219
+ n_codebooks = codes.shape[1]
220
+ for i in range(n_codebooks):
221
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
222
+ z_p.append(z_p_i)
223
+
224
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
225
+ z_q = z_q + z_q_i
226
+ return z_q, torch.cat(z_p, dim=1), codes
227
+
228
+ def from_latents(self, latents: torch.Tensor):
229
+ """Given the unquantized latents, reconstruct the
230
+ continuous representation after quantization.
231
+
232
+ Parameters
233
+ ----------
234
+ latents : Tensor[B x N x T]
235
+ Continuous representation of input after projection
236
+
237
+ Returns
238
+ -------
239
+ Tensor[B x D x T]
240
+ Quantized representation of full-projected space
241
+ Tensor[B x D x T]
242
+ Quantized representation of latent space
243
+ """
244
+ z_q = 0
245
+ z_p = []
246
+ codes = []
247
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
248
+
249
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
250
+ 0
251
+ ]
252
+ for i in range(n_codebooks):
253
+ j, k = dims[i], dims[i + 1]
254
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
255
+ z_p.append(z_p_i)
256
+ codes.append(codes_i)
257
+
258
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
259
+ z_q = z_q + z_q_i
260
+
261
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
262
+
263
+
264
+ if __name__ == "__main__":
265
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
266
+ x = torch.randn(16, 512, 80)
267
+ z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
268
+ print(latents.shape)
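A short sketch of rebuilding the latent from codes with the module above (sizes are illustrative): after a forward pass, from_codes() reconstructs the continuous representation from the integer indices alone and should match the quantized output.

import torch

rvq2 = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024)
rvq2.eval()                                                # skip training-time quantizer dropout
signal = torch.randn(1, 512, 80)
z_q, codes, latents, commit_loss, cb_loss = rvq2(signal)   # codes: (1, 9, 80)
z_rec, z_p, _ = rvq2.from_codes(codes)                     # z_rec should equal z_q numerically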
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize2.py ADDED
@@ -0,0 +1,290 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ def WNConv1d(*args, **kwargs):
11
+ return weight_norm(nn.Conv1d(*args, **kwargs))
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
34
+ self.stale_tolerance = stale_tolerance
35
+
36
+ def forward(self, z):
37
+ """Quantized the input tensor using a fixed codebook and returns
38
+ the corresponding codebook vectors
39
+
40
+ Parameters
41
+ ----------
42
+ z : Tensor[B x D x T]
43
+
44
+ Returns
45
+ -------
46
+ Tensor[B x D x T]
47
+ Quantized continuous representation of input
48
+ Tensor[1]
49
+ Commitment loss to train encoder to predict vectors closer to codebook
50
+ entries
51
+ Tensor[1]
52
+ Codebook loss to update the codebook
53
+ Tensor[B x T]
54
+ Codebook indices (quantized discrete representation of input)
55
+ Tensor[B x D x T]
56
+ Projected latents (continuous representation of input before quantization)
57
+ """
58
+
59
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
60
+ z_e = self.in_proj(z) # z_e : (B x D x T)
61
+ z_q, indices = self.decode_latents(z_e)
62
+
63
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
64
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
65
+
66
+ z_q = (
67
+ z_e + (z_q - z_e).detach()
68
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
69
+
70
+ z_q = self.out_proj(z_q)
71
+
72
+ return z_q, commitment_loss, codebook_loss, indices, z_e
73
+
74
+ def embed_code(self, embed_id):
75
+ return F.embedding(embed_id, self.codebook.weight)
76
+
77
+ def decode_code(self, embed_id):
78
+ return self.embed_code(embed_id).transpose(1, 2)
79
+
80
+ def decode_latents(self, latents):
81
+ encodings = rearrange(latents, "b d t -> (b t) d")
82
+ codebook = self.codebook.weight # codebook: (N x D)
83
+
84
+ # L2 normalize encodings and codebook (ViT-VQGAN)
85
+ encodings = F.normalize(encodings)
86
+ codebook = F.normalize(codebook)
87
+
88
+ # Compute euclidean distance with codebook
89
+ dist = (
90
+ encodings.pow(2).sum(1, keepdim=True)
91
+ - 2 * encodings @ codebook.t()
92
+ + codebook.pow(2).sum(1, keepdim=True).t()
93
+ )
94
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
95
+ z_q = self.decode_code(indices)
96
+
97
+ if(self.training):
98
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
99
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
100
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
101
+
102
+ # random replace codes that haven't been used for a while
103
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
104
+ if replace_code.sum(-1) > 0:
105
+ print("Replace {} codes".format(replace_code.sum(-1)))
106
+ random_input_idx = torch.randperm(encodings.shape[0])
107
+ random_input = encodings[random_input_idx].view(encodings.shape)
108
+ if random_input.shape[0] < self.codebook_size:
109
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
110
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
111
+
112
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
113
+ self.stale_counter = self.stale_counter * (1 - replace_code)
114
+
115
+ return z_q, indices
116
+
117
+
118
+ class ResidualVectorQuantize(nn.Module):
119
+ """
120
+ Introduced in SoundStream: An end2end neural audio codec
121
+ https://arxiv.org/abs/2107.03312
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ input_dim: int = 512,
127
+ n_codebooks: int = 9,
128
+ codebook_size: int = 1024,
129
+ codebook_dim: Union[int, list] = 8,
130
+ quantizer_dropout: float = 0.0,
131
+ stale_tolerance: int = 100,
132
+ ):
133
+ super().__init__()
134
+ if isinstance(codebook_dim, int):
135
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
136
+
137
+ self.n_codebooks = n_codebooks
138
+ self.codebook_dim = codebook_dim
139
+ self.codebook_size = codebook_size
140
+
141
+ self.quantizers = nn.ModuleList(
142
+ [
143
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
144
+ for i in range(n_codebooks)
145
+ ]
146
+ )
147
+ self.quantizer_dropout = quantizer_dropout
148
+
149
+ def forward(self, z, n_quantizers: int = None):
150
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
151
+ the corresponding codebook vectors
152
+ Parameters
153
+ ----------
154
+ z : Tensor[B x D x T]
155
+ n_quantizers : int, optional
156
+ No. of quantizers to use
157
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
158
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
159
+ when in training mode, and a random number of quantizers is used.
160
+ Returns
161
+ -------
162
+ dict
163
+ A dictionary with the following keys:
164
+
165
+ "z" : Tensor[B x D x T]
166
+ Quantized continuous representation of input
167
+ "codes" : Tensor[B x N x T]
168
+ Codebook indices for each codebook
169
+ (quantized discrete representation of input)
170
+ "latents" : Tensor[B x N*D x T]
171
+ Projected latents (continuous representation of input before quantization)
172
+ "vq/commitment_loss" : Tensor[1]
173
+ Commitment loss to train encoder to predict vectors closer to codebook
174
+ entries
175
+ "vq/codebook_loss" : Tensor[1]
176
+ Codebook loss to update the codebook
177
+ """
178
+ z_q = 0
179
+ residual = z
180
+ commitment_loss = 0
181
+ codebook_loss = 0
182
+
183
+ codebook_indices = []
184
+ latents = []
185
+
186
+ if n_quantizers is None:
187
+ n_quantizers = self.n_codebooks
188
+ if self.training:
189
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
190
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
191
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
192
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
193
+ n_quantizers = n_quantizers.to(z.device)
194
+
195
+ for i, quantizer in enumerate(self.quantizers):
196
+ if self.training is False and i >= n_quantizers:
197
+ break
198
+
199
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
200
+ residual
201
+ )
202
+
203
+ # Create mask to apply quantizer dropout
204
+ mask = (
205
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
206
+ )
207
+ z_q = z_q + z_q_i * mask[:, None, None]
208
+ residual = residual - z_q_i
209
+
210
+ # Sum losses
211
+ commitment_loss += (commitment_loss_i * mask).mean()
212
+ codebook_loss += (codebook_loss_i * mask).mean()
213
+
214
+ codebook_indices.append(indices_i)
215
+ latents.append(z_e_i)
216
+
217
+ codes = torch.stack(codebook_indices, dim=1)
218
+ latents = torch.cat(latents, dim=1)
219
+
220
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
221
+ for n in range(encodings.shape[1]):
222
+ print("Layer {}, Ratio of unused vectors: {:.1f}".format(n,
223
+ (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
224
+ ))
225
+
226
+ return z_q, codes, latents, commitment_loss, codebook_loss
227
+
228
+ def from_codes(self, codes: torch.Tensor):
229
+ """Given the quantized codes, reconstruct the continuous representation
230
+ Parameters
231
+ ----------
232
+ codes : Tensor[B x N x T]
233
+ Quantized discrete representation of input
234
+ Returns
235
+ -------
236
+ Tensor[B x D x T]
237
+ Quantized continuous representation of input
238
+ """
239
+ z_q = 0.0
240
+ z_p = []
241
+ n_codebooks = codes.shape[1]
242
+ for i in range(n_codebooks):
243
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
244
+ z_p.append(z_p_i)
245
+
246
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
247
+ z_q = z_q + z_q_i
248
+ return z_q, torch.cat(z_p, dim=1), codes
249
+
250
+ def from_latents(self, latents: torch.Tensor):
251
+ """Given the unquantized latents, reconstruct the
252
+ continuous representation after quantization.
253
+
254
+ Parameters
255
+ ----------
256
+ latents : Tensor[B x N x T]
257
+ Continuous representation of input after projection
258
+
259
+ Returns
260
+ -------
261
+ Tensor[B x D x T]
262
+ Quantized representation of full-projected space
263
+ Tensor[B x D x T]
264
+ Quantized representation of latent space
265
+ """
266
+ z_q = 0
267
+ z_p = []
268
+ codes = []
269
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
270
+
271
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
272
+ 0
273
+ ]
274
+ for i in range(n_codebooks):
275
+ j, k = dims[i], dims[i + 1]
276
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
277
+ z_p.append(z_p_i)
278
+ codes.append(codes_i)
279
+
280
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
281
+ z_q = z_q + z_q_i
282
+
283
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
284
+
285
+
286
+ if __name__ == "__main__":
287
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
288
+ x = torch.randn(16, 512, 80)
289
+ z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
290
+ print(latents.shape)
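A sketch of the stale-code bookkeeping added in this variant (sizes are illustrative): in training mode, decode_latents() counts how long each codebook entry goes unused and re-seeds entries that stay stale for stale_tolerance consecutive steps.

import torch

rvq_stale = ResidualVectorQuantize(input_dim=512, n_codebooks=4, codebook_size=1024,
                                   codebook_dim=8, stale_tolerance=100)
rvq_stale.train()                                 # stale counters only update in training mode
batch = torch.randn(8, 512, 80)
z_q, codes, latents, commit_loss, cb_loss = rvq_stale(batch)
print(rvq_stale.quantizers[0].stale_counter.sum())   # entries unused in this step so far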
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3.py ADDED
@@ -0,0 +1,299 @@
1
+ # compared with `descript_quantize2`, we use rvq & random_dropout
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ return weight_norm(nn.Conv1d(*args, **kwargs))
13
+
14
+ class VectorQuantize(nn.Module):
15
+ """
16
+ Implementation of VQ similar to Karpathy's repo:
17
+ https://github.com/karpathy/deep-vector-quantization
18
+ Additionally uses following tricks from Improved VQGAN
19
+ (https://arxiv.org/pdf/2110.04627.pdf):
20
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
21
+ for improved codebook usage
22
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
23
+ improves training stability
24
+ """
25
+
26
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
27
+ super().__init__()
28
+ self.codebook_size = codebook_size
29
+ self.codebook_dim = codebook_dim
30
+
31
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
32
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
33
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
34
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
35
+ self.stale_tolerance = stale_tolerance
36
+
37
+ def forward(self, z):
38
+ """Quantizes the input tensor using a fixed codebook and returns
39
+ the corresponding codebook vectors
40
+
41
+ Parameters
42
+ ----------
43
+ z : Tensor[B x D x T]
44
+
45
+ Returns
46
+ -------
47
+ Tensor[B x D x T]
48
+ Quantized continuous representation of input
49
+ Tensor[1]
50
+ Commitment loss to train encoder to predict vectors closer to codebook
51
+ entries
52
+ Tensor[1]
53
+ Codebook loss to update the codebook
54
+ Tensor[B x T]
55
+ Codebook indices (quantized discrete representation of input)
56
+ Tensor[B x D x T]
57
+ Projected latents (continuous representation of input before quantization)
58
+ """
59
+
60
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
61
+ z_e = self.in_proj(z) # z_e : (B x D x T)
62
+ z_q, indices = self.decode_latents(z_e)
63
+
64
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
65
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
66
+
67
+ z_q = (
68
+ z_e + (z_q - z_e).detach()
69
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
70
+
71
+ z_q = self.out_proj(z_q)
72
+
73
+ return z_q, commitment_loss, codebook_loss, indices, z_e
74
+
75
+ def embed_code(self, embed_id):
76
+ return F.embedding(embed_id, self.codebook.weight)
77
+
78
+ def decode_code(self, embed_id):
79
+ return self.embed_code(embed_id).transpose(1, 2)
80
+
81
+ def decode_latents(self, latents):
82
+ encodings = rearrange(latents, "b d t -> (b t) d")
83
+ codebook = self.codebook.weight # codebook: (N x D)
84
+
85
+ # L2 normalize encodings and codebook (ViT-VQGAN)
86
+ encodings = F.normalize(encodings)
87
+ codebook = F.normalize(codebook)
88
+
89
+ # Compute euclidean distance with codebook
90
+ dist = (
91
+ encodings.pow(2).sum(1, keepdim=True)
92
+ - 2 * encodings @ codebook.t()
93
+ + codebook.pow(2).sum(1, keepdim=True).t()
94
+ )
95
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
96
+ z_q = self.decode_code(indices)
97
+
98
+ if(self.training):
99
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
100
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
101
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
102
+
103
+ # random replace codes that haven't been used for a while
104
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
105
+ if replace_code.sum(-1) > 0:
106
+ print("Replace {} codes".format(replace_code.sum(-1)))
107
+ random_input_idx = torch.randperm(encodings.shape[0])
108
+ random_input = encodings[random_input_idx].view(encodings.shape)
109
+ if random_input.shape[0] < self.codebook_size:
110
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
111
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
112
+
113
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
114
+ self.stale_counter = self.stale_counter * (1 - replace_code)
115
+
116
+ return z_q, indices
117
+
118
+
119
+ class ResidualVectorQuantize(nn.Module):
120
+ """
121
+ Introduced in SoundStream: An end2end neural audio codec
122
+ https://arxiv.org/abs/2107.03312
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ input_dim: int = 512,
128
+ n_codebooks: int = 9,
129
+ codebook_size: int = 1024,
130
+ codebook_dim: Union[int, list] = 8,
131
+ quantizer_dropout: float = 0.0,
132
+ stale_tolerance: int = 100,
133
+ ):
134
+ super().__init__()
135
+ if isinstance(codebook_dim, int):
136
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
137
+
138
+ self.n_codebooks = n_codebooks
139
+ self.codebook_dim = codebook_dim
140
+ self.codebook_size = codebook_size
141
+
142
+ self.quantizers = nn.ModuleList(
143
+ [
144
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
145
+ for i in range(n_codebooks)
146
+ ]
147
+ )
148
+ self.quantizer_dropout = quantizer_dropout
149
+
150
+ def forward(self, z, n_quantizers: int = None):
151
+ """Quantizes the input tensor using a fixed set of `n` codebooks and returns
152
+ the corresponding codebook vectors
153
+ Parameters
154
+ ----------
155
+ z : Tensor[B x D x T]
156
+ n_quantizers : int, optional
157
+ No. of quantizers to use
158
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
159
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
160
+ when in training mode, and a random number of quantizers is used.
161
+ Returns
162
+ -------
163
+ dict
164
+ A dictionary with the following keys:
165
+
166
+ "z" : Tensor[B x D x T]
167
+ Quantized continuous representation of input
168
+ "codes" : Tensor[B x N x T]
169
+ Codebook indices for each codebook
170
+ (quantized discrete representation of input)
171
+ "latents" : Tensor[B x N*D x T]
172
+ Projected latents (continuous representation of input before quantization)
173
+ "vq/commitment_loss" : Tensor[1]
174
+ Commitment loss to train encoder to predict vectors closer to codebook
175
+ entries
176
+ "vq/codebook_loss" : Tensor[1]
177
+ Codebook loss to update the codebook
178
+ """
179
+ z_q = 0
180
+ residual = z
181
+ commitment_loss = 0
182
+ codebook_loss = 0
183
+
184
+ codebook_indices = []
185
+ latents = []
186
+
187
+ if n_quantizers is None:
188
+ n_quantizers = self.n_codebooks
189
+ if self.training:
190
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
191
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
192
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
193
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
194
+ n_quantizers = n_quantizers.to(z.device)
195
+ else:
196
+ n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
197
+ n_quantizers = n_quantizers.to(z.device)
198
+
199
+ for i, quantizer in enumerate(self.quantizers):
200
+ # if self.training is False and i >= n_quantizers:
201
+ # break
202
+
203
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
204
+ residual
205
+ )
206
+
207
+ # Create mask to apply quantizer dropout
208
+ mask = (
209
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
210
+ )
211
+ z_q = z_q + z_q_i * mask[:, None, None]
212
+ residual = residual - z_q_i
213
+
214
+ # Sum losses
215
+ commitment_loss += (commitment_loss_i * mask).mean()
216
+ codebook_loss += (codebook_loss_i * mask).mean()
217
+
218
+ codebook_indices.append(indices_i)
219
+ latents.append(z_e_i)
220
+
221
+ codes = torch.stack(codebook_indices, dim=1)
222
+ latents = torch.cat(latents, dim=1)
223
+
224
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
225
+ # for n in range(encodings.shape[1]):
226
+ # print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
227
+ # (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
228
+ # ))
229
+
230
+ return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
231
+
232
+ def from_codes(self, codes: torch.Tensor):
233
+ """Given the quantized codes, reconstruct the continuous representation
234
+ Parameters
235
+ ----------
236
+ codes : Tensor[B x N x T]
237
+ Quantized discrete representation of input
238
+ Returns
239
+ -------
240
+ Tensor[B x D x T]
241
+ Quantized continuous representation of input
242
+ """
243
+ z_q = 0.0
244
+ z_p = []
245
+ n_codebooks = codes.shape[1]
246
+ for i in range(n_codebooks):
247
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
248
+ z_p.append(z_p_i)
249
+
250
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
251
+ z_q = z_q + z_q_i
252
+ return z_q, torch.cat(z_p, dim=1), codes
253
+
254
+ def from_latents(self, latents: torch.Tensor):
255
+ """Given the unquantized latents, reconstruct the
256
+ continuous representation after quantization.
257
+
258
+ Parameters
259
+ ----------
260
+ latents : Tensor[B x N x T]
261
+ Continuous representation of input after projection
262
+
263
+ Returns
264
+ -------
265
+ Tensor[B x D x T]
266
+ Quantized representation of full-projected space
267
+ Tensor[B x D x T]
268
+ Quantized representation of latent space
269
+ """
270
+ z_q = 0
271
+ z_p = []
272
+ codes = []
273
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
274
+
275
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
276
+ 0
277
+ ]
278
+ for i in range(n_codebooks):
279
+ j, k = dims[i], dims[i + 1]
280
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
281
+ z_p.append(z_p_i)
282
+ codes.append(codes_i)
283
+
284
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
285
+ z_q = z_q + z_q_i
286
+
287
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
288
+
289
+
290
+ if __name__ == "__main__":
291
+ rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
292
+ x = torch.randn(16, 1024, 80)
293
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
294
+ print(quantized_prompt_embeds.shape)
295
+ print(codes.shape)
296
+ # w/o reconstruction
297
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0
298
+ # w/ reconstruction
299
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
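For completeness, a hedged usage sketch of this module's public API. The import path is an assumption (run from the Flow1dVAE directory or adjust to your PYTHONPATH); everything else mirrors the signatures above:

# Usage sketch for descript_quantize3 (assumed import path).
import torch
from libs.rvq.descript_quantize3 import ResidualVectorQuantize

rvq = ResidualVectorQuantize(input_dim=1024, n_codebooks=4, codebook_size=1024, codebook_dim=32)
rvq.eval()  # eval mode: no quantizer dropout, no stale-code replacement

x = torch.randn(2, 1024, 80)  # B x D x T
with torch.no_grad():
    z_q, codes, latents, _, _, rvq_usage = rvq(x, n_quantizers=4)
    # Reconstruct the continuous representation from the discrete codes alone.
    z_q_from_codes, z_p, _ = rvq.from_codes(codes)

print(codes.shape)  # (2, 4, 80): B x N x T
# Should be ~0 (up to floating-point error), since from_codes replays the sum of out_proj outputs.
print((z_q - z_q_from_codes).abs().max())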
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer.py ADDED
@@ -0,0 +1,303 @@
+ # compared with `descript_quantize2`, we use rvq & random_dropout
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+ import random
11
+
12
+ def WNConv1d(*args, **kwargs):
13
+ return weight_norm(nn.Conv1d(*args, **kwargs))
14
+
15
+ class VectorQuantize(nn.Module):
16
+ """
17
+ Implementation of VQ similar to Karpathy's repo:
18
+ https://github.com/karpathy/deep-vector-quantization
19
+ Additionally uses following tricks from Improved VQGAN
20
+ (https://arxiv.org/pdf/2110.04627.pdf):
21
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
22
+ for improved codebook usage
23
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
24
+ improves training stability
25
+ """
26
+
27
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
28
+ super().__init__()
29
+ self.codebook_size = codebook_size
30
+ self.codebook_dim = codebook_dim
31
+
32
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
33
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
34
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
35
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
36
+ self.stale_tolerance = stale_tolerance
37
+
38
+ def forward(self, z):
39
+ """Quantizes the input tensor using a fixed codebook and returns
40
+ the corresponding codebook vectors
41
+
42
+ Parameters
43
+ ----------
44
+ z : Tensor[B x D x T]
45
+
46
+ Returns
47
+ -------
48
+ Tensor[B x D x T]
49
+ Quantized continuous representation of input
50
+ Tensor[1]
51
+ Commitment loss to train encoder to predict vectors closer to codebook
52
+ entries
53
+ Tensor[1]
54
+ Codebook loss to update the codebook
55
+ Tensor[B x T]
56
+ Codebook indices (quantized discrete representation of input)
57
+ Tensor[B x D x T]
58
+ Projected latents (continuous representation of input before quantization)
59
+ """
60
+
61
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
62
+ z_e = self.in_proj(z) # z_e : (B x D x T)
63
+ z_q, indices = self.decode_latents(z_e)
64
+
65
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
66
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
67
+
68
+ z_q = (
69
+ z_e + (z_q - z_e).detach()
70
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
71
+
72
+ z_q = self.out_proj(z_q)
73
+
74
+ return z_q, commitment_loss, codebook_loss, indices, z_e
75
+
76
+ def embed_code(self, embed_id):
77
+ return F.embedding(embed_id, self.codebook.weight)
78
+
79
+ def decode_code(self, embed_id):
80
+ return self.embed_code(embed_id).transpose(1, 2)
81
+
82
+ def decode_latents(self, latents):
83
+ encodings = rearrange(latents, "b d t -> (b t) d")
84
+ codebook = self.codebook.weight # codebook: (N x D)
85
+
86
+ # L2 normalize encodings and codebook (ViT-VQGAN)
87
+ encodings = F.normalize(encodings)
88
+ codebook = F.normalize(codebook)
89
+
90
+ # Compute euclidean distance with codebook
91
+ dist = (
92
+ encodings.pow(2).sum(1, keepdim=True)
93
+ - 2 * encodings @ codebook.t()
94
+ + codebook.pow(2).sum(1, keepdim=True).t()
95
+ )
96
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
97
+ z_q = self.decode_code(indices)
98
+
99
+ if(self.training):
100
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
101
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
102
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
103
+
104
+ # random replace codes that haven't been used for a while
105
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
106
+ if replace_code.sum(-1) > 0:
107
+ print("Replace {} codes".format(replace_code.sum(-1)))
108
+ random_input_idx = torch.randperm(encodings.shape[0])
109
+ random_input = encodings[random_input_idx].view(encodings.shape)
110
+ if random_input.shape[0] < self.codebook_size:
111
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
112
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
113
+
114
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
115
+ self.stale_counter = self.stale_counter * (1 - replace_code)
116
+
117
+ return z_q, indices
118
+
119
+
120
+ class ResidualVectorQuantize(nn.Module):
121
+ """
122
+ Introduced in SoundStream: An end2end neural audio codec
123
+ https://arxiv.org/abs/2107.03312
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ input_dim: int = 512,
129
+ n_codebooks: int = 9,
130
+ codebook_size: int = 1024,
131
+ codebook_dim: Union[int, list] = 8,
132
+ quantizer_dropout: float = 0.0,
133
+ stale_tolerance: int = 100,
134
+ ):
135
+ super().__init__()
136
+ if isinstance(codebook_dim, int):
137
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
138
+
139
+ self.n_codebooks = n_codebooks
140
+ self.codebook_dim = codebook_dim
141
+ self.codebook_size = codebook_size
142
+
143
+ self.quantizers = nn.ModuleList(
144
+ [
145
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
146
+ for i in range(n_codebooks)
147
+ ]
148
+ )
149
+ self.quantizer_dropout = quantizer_dropout
150
+
151
+ def forward(self, z, n_quantizers: int = None):
152
+ """Quantizes the input tensor using a fixed set of `n` codebooks and returns
153
+ the corresponding codebook vectors
154
+ Parameters
155
+ ----------
156
+ z : Tensor[B x D x T]
157
+ n_quantizers : int, optional
158
+ No. of quantizers to use
159
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
160
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
161
+ when in training mode, and a random number of quantizers is used.
162
+ Returns
163
+ -------
164
+ dict
165
+ A dictionary with the following keys:
166
+
167
+ "z" : Tensor[B x D x T]
168
+ Quantized continuous representation of input
169
+ "codes" : Tensor[B x N x T]
170
+ Codebook indices for each codebook
171
+ (quantized discrete representation of input)
172
+ "latents" : Tensor[B x N*D x T]
173
+ Projected latents (continuous representation of input before quantization)
174
+ "vq/commitment_loss" : Tensor[1]
175
+ Commitment loss to train encoder to predict vectors closer to codebook
176
+ entries
177
+ "vq/codebook_loss" : Tensor[1]
178
+ Codebook loss to update the codebook
179
+ """
180
+ z_q = 0
181
+ residual = z
182
+ commitment_loss = 0
183
+ codebook_loss = 0
184
+
185
+ codebook_indices = []
186
+ latents = []
187
+
188
+ if n_quantizers is None:
189
+ n_quantizers = self.n_codebooks
190
+ if self.training:
191
+ random_num = random.random()
192
+ if random_num<0.6:
193
+ n_quantizers = torch.ones((z.shape[0],)) * 1
194
+ elif random_num<0.8:
195
+ n_quantizers = torch.ones((z.shape[0],)) * 2
196
+ else:
197
+ n_quantizers = torch.ones((z.shape[0],)) * 4
198
+ n_quantizers = n_quantizers.to(z.device)
199
+ else:
200
+ n_quantizers = torch.ones((z.shape[0],)) * n_quantizers
201
+ n_quantizers = n_quantizers.to(z.device)
202
+
203
+ for i, quantizer in enumerate(self.quantizers):
204
+ # if self.training is False and i >= n_quantizers:
205
+ # break
206
+
207
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
208
+ residual
209
+ )
210
+
211
+ # Create mask to apply quantizer dropout
212
+ mask = (
213
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
214
+ )
215
+ z_q = z_q + z_q_i * mask[:, None, None]
216
+ residual = residual - z_q_i
217
+
218
+ # Sum losses
219
+ commitment_loss += (commitment_loss_i * mask).mean()
220
+ codebook_loss += (codebook_loss_i * mask).mean()
221
+
222
+ codebook_indices.append(indices_i)
223
+ latents.append(z_e_i)
224
+
225
+ codes = torch.stack(codebook_indices, dim=1)
226
+ latents = torch.cat(latents, dim=1)
227
+
228
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
229
+ for n in range(encodings.shape[1]):
230
+ print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
231
+ (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
232
+ ))
233
+
234
+ return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
235
+
236
+ def from_codes(self, codes: torch.Tensor):
237
+ """Given the quantized codes, reconstruct the continuous representation
238
+ Parameters
239
+ ----------
240
+ codes : Tensor[B x N x T]
241
+ Quantized discrete representation of input
242
+ Returns
243
+ -------
244
+ Tensor[B x D x T]
245
+ Quantized continuous representation of input
246
+ """
247
+ z_q = 0.0
248
+ z_p = []
249
+ n_codebooks = codes.shape[1]
250
+ for i in range(n_codebooks):
251
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
252
+ z_p.append(z_p_i)
253
+
254
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
255
+ z_q = z_q + z_q_i
256
+ return z_q, torch.cat(z_p, dim=1), codes
257
+
258
+ def from_latents(self, latents: torch.Tensor):
259
+ """Given the unquantized latents, reconstruct the
260
+ continuous representation after quantization.
261
+
262
+ Parameters
263
+ ----------
264
+ latents : Tensor[B x N x T]
265
+ Continuous representation of input after projection
266
+
267
+ Returns
268
+ -------
269
+ Tensor[B x D x T]
270
+ Quantized representation of full-projected space
271
+ Tensor[B x D x T]
272
+ Quantized representation of latent space
273
+ """
274
+ z_q = 0
275
+ z_p = []
276
+ codes = []
277
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
278
+
279
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
280
+ 0
281
+ ]
282
+ for i in range(n_codebooks):
283
+ j, k = dims[i], dims[i + 1]
284
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
285
+ z_p.append(z_p_i)
286
+ codes.append(codes_i)
287
+
288
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
289
+ z_q = z_q + z_q_i
290
+
291
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
292
+
293
+
294
+ if __name__ == "__main__":
295
+ rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
296
+ x = torch.randn(16, 1024, 80)
297
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
298
+ print(quantized_prompt_embeds.shape)
299
+ print(codes.shape)
300
+ # w/o reconstruction
301
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0
302
+ # w/ reconstruction
303
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
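A short, self-contained check of the codebook schedule used in this variant's training branch (roughly 60% of steps use 1 codebook, 20% use 2, 20% use all 4); purely illustrative:

# Empirically verify the 1/2/4-codebook sampling schedule (illustration only).
import random
from collections import Counter

def sample_n_codebooks():
    r = random.random()
    if r < 0.6:
        return 1
    elif r < 0.8:
        return 2
    return 4

counts = Counter(sample_n_codebooks() for _ in range(10_000))
print({k: v / 10_000 for k, v in sorted(counts.items())})  # roughly {1: 0.6, 2: 0.2, 4: 0.2}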
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_freezelayer1.py ADDED
@@ -0,0 +1,301 @@
+ # compared with `descript_quantize2`, we use rvq & random_dropout
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+ import random
11
+
12
+ def WNConv1d(*args, **kwargs):
13
+ return weight_norm(nn.Conv1d(*args, **kwargs))
14
+
15
+ class VectorQuantize(nn.Module):
16
+ """
17
+ Implementation of VQ similar to Karpathy's repo:
18
+ https://github.com/karpathy/deep-vector-quantization
19
+ Additionally uses following tricks from Improved VQGAN
20
+ (https://arxiv.org/pdf/2110.04627.pdf):
21
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
22
+ for improved codebook usage
23
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
24
+ improves training stability
25
+ """
26
+
27
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
28
+ super().__init__()
29
+ self.codebook_size = codebook_size
30
+ self.codebook_dim = codebook_dim
31
+
32
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
33
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
34
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
35
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
36
+ self.stale_tolerance = stale_tolerance
37
+
38
+ def forward(self, z):
39
+ """Quantizes the input tensor using a fixed codebook and returns
40
+ the corresponding codebook vectors
41
+
42
+ Parameters
43
+ ----------
44
+ z : Tensor[B x D x T]
45
+
46
+ Returns
47
+ -------
48
+ Tensor[B x D x T]
49
+ Quantized continuous representation of input
50
+ Tensor[1]
51
+ Commitment loss to train encoder to predict vectors closer to codebook
52
+ entries
53
+ Tensor[1]
54
+ Codebook loss to update the codebook
55
+ Tensor[B x T]
56
+ Codebook indices (quantized discrete representation of input)
57
+ Tensor[B x D x T]
58
+ Projected latents (continuous representation of input before quantization)
59
+ """
60
+
61
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
62
+ z_e = self.in_proj(z) # z_e : (B x D x T)
63
+ z_q, indices = self.decode_latents(z_e)
64
+
65
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
66
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
67
+
68
+ z_q = (
69
+ z_e + (z_q - z_e).detach()
70
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
71
+
72
+ z_q = self.out_proj(z_q)
73
+
74
+ return z_q, commitment_loss, codebook_loss, indices, z_e
75
+
76
+ def embed_code(self, embed_id):
77
+ return F.embedding(embed_id, self.codebook.weight)
78
+
79
+ def decode_code(self, embed_id):
80
+ return self.embed_code(embed_id).transpose(1, 2)
81
+
82
+ def decode_latents(self, latents):
83
+ encodings = rearrange(latents, "b d t -> (b t) d")
84
+ codebook = self.codebook.weight # codebook: (N x D)
85
+
86
+ # L2 normalize encodings and codebook (ViT-VQGAN)
87
+ encodings = F.normalize(encodings)
88
+ codebook = F.normalize(codebook)
89
+
90
+ # Compute euclidean distance with codebook
91
+ dist = (
92
+ encodings.pow(2).sum(1, keepdim=True)
93
+ - 2 * encodings @ codebook.t()
94
+ + codebook.pow(2).sum(1, keepdim=True).t()
95
+ )
96
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
97
+ z_q = self.decode_code(indices)
98
+
99
+ if(self.training):
100
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
101
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
102
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
103
+
104
+ # random replace codes that haven't been used for a while
105
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
106
+ if replace_code.sum(-1) > 0:
107
+ print("Replace {} codes".format(replace_code.sum(-1)))
108
+ random_input_idx = torch.randperm(encodings.shape[0])
109
+ random_input = encodings[random_input_idx].view(encodings.shape)
110
+ if random_input.shape[0] < self.codebook_size:
111
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
112
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
113
+
114
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
115
+ self.stale_counter = self.stale_counter * (1 - replace_code)
116
+
117
+ return z_q, indices
118
+
119
+
120
+ class ResidualVectorQuantize(nn.Module):
121
+ """
122
+ Introduced in SoundStream: An end2end neural audio codec
123
+ https://arxiv.org/abs/2107.03312
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ input_dim: int = 512,
129
+ n_codebooks: int = 9,
130
+ codebook_size: int = 1024,
131
+ codebook_dim: Union[int, list] = 8,
132
+ quantizer_dropout: float = 0.0,
133
+ stale_tolerance: int = 100,
134
+ ):
135
+ super().__init__()
136
+ if isinstance(codebook_dim, int):
137
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
138
+
139
+ self.n_codebooks = n_codebooks
140
+ self.codebook_dim = codebook_dim
141
+ self.codebook_size = codebook_size
142
+
143
+ self.quantizers = nn.ModuleList(
144
+ [
145
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
146
+ for i in range(n_codebooks)
147
+ ]
148
+ )
149
+ self.quantizer_dropout = quantizer_dropout
150
+
151
+ def forward(self, z, n_quantizers: int = None):
152
+ """Quantizes the input tensor using a fixed set of `n` codebooks and returns
153
+ the corresponding codebook vectors
154
+ Parameters
155
+ ----------
156
+ z : Tensor[B x D x T]
157
+ n_quantizers : int, optional
158
+ No. of quantizers to use
159
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
160
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
161
+ when in training mode, and a random number of quantizers is used.
162
+ Returns
163
+ -------
164
+ dict
165
+ A dictionary with the following keys:
166
+
167
+ "z" : Tensor[B x D x T]
168
+ Quantized continuous representation of input
169
+ "codes" : Tensor[B x N x T]
170
+ Codebook indices for each codebook
171
+ (quantized discrete representation of input)
172
+ "latents" : Tensor[B x N*D x T]
173
+ Projected latents (continuous representation of input before quantization)
174
+ "vq/commitment_loss" : Tensor[1]
175
+ Commitment loss to train encoder to predict vectors closer to codebook
176
+ entries
177
+ "vq/codebook_loss" : Tensor[1]
178
+ Codebook loss to update the codebook
179
+ """
180
+ z_q = 0
181
+ residual = z
182
+ commitment_loss = 0
183
+ codebook_loss = 0
184
+
185
+ codebook_indices = []
186
+ latents = []
187
+
188
+ if n_quantizers is None:
189
+ n_quantizers = self.n_codebooks
190
+ if self.training:
191
+ random_num = random.random()
192
+ if random_num<0.6:
193
+ n_quantizers = torch.ones((z.shape[0],)) * 2
194
+ else:
195
+ n_quantizers = torch.ones((z.shape[0],)) * 4
196
+ n_quantizers = n_quantizers.to(z.device)
197
+ else:
198
+ n_quantizers = torch.ones((z.shape[0],)) * n_quantizers
199
+ n_quantizers = n_quantizers.to(z.device)
200
+
201
+ for i, quantizer in enumerate(self.quantizers):
202
+ # if self.training is False and i >= n_quantizers:
203
+ # break
204
+
205
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
206
+ residual
207
+ )
208
+
209
+ # Create mask to apply quantizer dropout
210
+ mask = (
211
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
212
+ )
213
+ z_q = z_q + z_q_i * mask[:, None, None]
214
+ residual = residual - z_q_i
215
+
216
+ # Sum losses
217
+ commitment_loss += (commitment_loss_i * mask).mean()
218
+ codebook_loss += (codebook_loss_i * mask).mean()
219
+
220
+ codebook_indices.append(indices_i)
221
+ latents.append(z_e_i)
222
+
223
+ codes = torch.stack(codebook_indices, dim=1)
224
+ latents = torch.cat(latents, dim=1)
225
+
226
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
227
+ # for n in range(encodings.shape[1]):
228
+ # print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
229
+ # (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
230
+ # ))
231
+
232
+ return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
233
+
234
+ def from_codes(self, codes: torch.Tensor):
235
+ """Given the quantized codes, reconstruct the continuous representation
236
+ Parameters
237
+ ----------
238
+ codes : Tensor[B x N x T]
239
+ Quantized discrete representation of input
240
+ Returns
241
+ -------
242
+ Tensor[B x D x T]
243
+ Quantized continuous representation of input
244
+ """
245
+ z_q = 0.0
246
+ z_p = []
247
+ n_codebooks = codes.shape[1]
248
+ for i in range(n_codebooks):
249
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
250
+ z_p.append(z_p_i)
251
+
252
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
253
+ z_q = z_q + z_q_i
254
+ return z_q, torch.cat(z_p, dim=1), codes
255
+
256
+ def from_latents(self, latents: torch.Tensor):
257
+ """Given the unquantized latents, reconstruct the
258
+ continuous representation after quantization.
259
+
260
+ Parameters
261
+ ----------
262
+ latents : Tensor[B x N x T]
263
+ Continuous representation of input after projection
264
+
265
+ Returns
266
+ -------
267
+ Tensor[B x D x T]
268
+ Quantized representation of full-projected space
269
+ Tensor[B x D x T]
270
+ Quantized representation of latent space
271
+ """
272
+ z_q = 0
273
+ z_p = []
274
+ codes = []
275
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
276
+
277
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
278
+ 0
279
+ ]
280
+ for i in range(n_codebooks):
281
+ j, k = dims[i], dims[i + 1]
282
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
283
+ z_p.append(z_p_i)
284
+ codes.append(codes_i)
285
+
286
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
287
+ z_q = z_q + z_q_i
288
+
289
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
290
+
291
+
292
+ if __name__ == "__main__":
293
+ rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
294
+ x = torch.randn(16, 1024, 80)
295
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
296
+ print(quantized_prompt_embeds.shape)
297
+ print(codes.shape)
298
+ # w/o reconstruction
299
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0
300
+ # w/ reconstruction
301
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
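The `freezelayer1` name suggests the first codebook is meant to stay fixed while the remaining ones train; nothing in this file actually freezes it, so that presumably happens in the training script. A minimal sketch of one way it might be done, assuming `rvq` is an instance of this file's `ResidualVectorQuantize` (hypothetical helper, not part of the repo):

# Hypothetical freezing of the first quantizer (not defined in this file).
def freeze_first_quantizer(rvq):
    # Stop gradient updates for the first codebook and its in/out projections.
    for p in rvq.quantizers[0].parameters():
        p.requires_grad_(False)
    # Also skip its stale-code replacement branch; note a later rvq.train()
    # call would flip this submodule back to training mode.
    rvq.quantizers[0].eval()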
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_return_layer.py ADDED
@@ -0,0 +1,305 @@
+ # compared with `descript_quantize2`, we use rvq & random_dropout
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+ import random
11
+
12
+ def WNConv1d(*args, **kwargs):
13
+ return weight_norm(nn.Conv1d(*args, **kwargs))
14
+
15
+ class VectorQuantize(nn.Module):
16
+ """
17
+ Implementation of VQ similar to Karpathy's repo:
18
+ https://github.com/karpathy/deep-vector-quantization
19
+ Additionally uses following tricks from Improved VQGAN
20
+ (https://arxiv.org/pdf/2110.04627.pdf):
21
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
22
+ for improved codebook usage
23
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
24
+ improves training stability
25
+ """
26
+
27
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
28
+ super().__init__()
29
+ self.codebook_size = codebook_size
30
+ self.codebook_dim = codebook_dim
31
+
32
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
33
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
34
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
35
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
36
+ self.stale_tolerance = stale_tolerance
37
+
38
+ def forward(self, z):
39
+ """Quantizes the input tensor using a fixed codebook and returns
40
+ the corresponding codebook vectors
41
+
42
+ Parameters
43
+ ----------
44
+ z : Tensor[B x D x T]
45
+
46
+ Returns
47
+ -------
48
+ Tensor[B x D x T]
49
+ Quantized continuous representation of input
50
+ Tensor[1]
51
+ Commitment loss to train encoder to predict vectors closer to codebook
52
+ entries
53
+ Tensor[1]
54
+ Codebook loss to update the codebook
55
+ Tensor[B x T]
56
+ Codebook indices (quantized discrete representation of input)
57
+ Tensor[B x D x T]
58
+ Projected latents (continuous representation of input before quantization)
59
+ """
60
+
61
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
62
+ z_e = self.in_proj(z) # z_e : (B x D x T)
63
+ z_q, indices = self.decode_latents(z_e)
64
+
65
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
66
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
67
+
68
+ z_q = (
69
+ z_e + (z_q - z_e).detach()
70
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
71
+
72
+ z_q = self.out_proj(z_q)
73
+
74
+ return z_q, commitment_loss, codebook_loss, indices, z_e
75
+
76
+ def embed_code(self, embed_id):
77
+ return F.embedding(embed_id, self.codebook.weight)
78
+
79
+ def decode_code(self, embed_id):
80
+ return self.embed_code(embed_id).transpose(1, 2)
81
+
82
+ def decode_latents(self, latents):
83
+ encodings = rearrange(latents, "b d t -> (b t) d")
84
+ codebook = self.codebook.weight # codebook: (N x D)
85
+
86
+ # L2 normalize encodings and codebook (ViT-VQGAN)
87
+ encodings = F.normalize(encodings)
88
+ codebook = F.normalize(codebook)
89
+
90
+ # Compute euclidean distance with codebook
91
+ dist = (
92
+ encodings.pow(2).sum(1, keepdim=True)
93
+ - 2 * encodings @ codebook.t()
94
+ + codebook.pow(2).sum(1, keepdim=True).t()
95
+ )
96
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
97
+ z_q = self.decode_code(indices)
98
+
99
+ if(self.training):
100
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
101
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
102
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
103
+
104
+ # random replace codes that haven't been used for a while
105
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
106
+ if replace_code.sum(-1) > 0:
107
+ print("Replace {} codes".format(replace_code.sum(-1)))
108
+ random_input_idx = torch.randperm(encodings.shape[0])
109
+ random_input = encodings[random_input_idx].view(encodings.shape)
110
+ if random_input.shape[0] < self.codebook_size:
111
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
112
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
113
+
114
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
115
+ self.stale_counter = self.stale_counter * (1 - replace_code)
116
+
117
+ return z_q, indices
118
+
119
+
120
+ class ResidualVectorQuantize(nn.Module):
121
+ """
122
+ Introduced in SoundStream: An end2end neural audio codec
123
+ https://arxiv.org/abs/2107.03312
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ input_dim: int = 512,
129
+ n_codebooks: int = 9,
130
+ codebook_size: int = 1024,
131
+ codebook_dim: Union[int, list] = 8,
132
+ quantizer_dropout: float = 0.0,
133
+ stale_tolerance: int = 100,
134
+ ):
135
+ super().__init__()
136
+ if isinstance(codebook_dim, int):
137
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
138
+
139
+ self.n_codebooks = n_codebooks
140
+ self.codebook_dim = codebook_dim
141
+ self.codebook_size = codebook_size
142
+
143
+ self.quantizers = nn.ModuleList(
144
+ [
145
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
146
+ for i in range(n_codebooks)
147
+ ]
148
+ )
149
+ self.quantizer_dropout = quantizer_dropout
150
+
151
+ def forward(self, z, n_quantizers: int = None):
152
+ """Quantizes the input tensor using a fixed set of `n` codebooks and returns
153
+ the corresponding codebook vectors
154
+ Parameters
155
+ ----------
156
+ z : Tensor[B x D x T]
157
+ n_quantizers : int, optional
158
+ No. of quantizers to use
159
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
160
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
161
+ when in training mode, and a random number of quantizers is used.
162
+ Returns
163
+ -------
164
+ dict
165
+ A dictionary with the following keys:
166
+
167
+ "z" : Tensor[B x D x T]
168
+ Quantized continuous representation of input
169
+ "codes" : Tensor[B x N x T]
170
+ Codebook indices for each codebook
171
+ (quantized discrete representation of input)
172
+ "latents" : Tensor[B x N*D x T]
173
+ Projected latents (continuous representation of input before quantization)
174
+ "vq/commitment_loss" : Tensor[1]
175
+ Commitment loss to train encoder to predict vectors closer to codebook
176
+ entries
177
+ "vq/codebook_loss" : Tensor[1]
178
+ Codebook loss to update the codebook
179
+ """
180
+ z_q = 0
181
+ residual = z
182
+ commitment_loss = 0
183
+ codebook_loss = 0
184
+ layer = self.n_codebooks
185
+ codebook_indices = []
186
+ latents = []
187
+
188
+ if n_quantizers is None:
189
+ n_quantizers = self.n_codebooks
190
+ if self.training:
191
+ random_num = random.random()
192
+ if random_num<0.6:
193
+ n_quantizers = torch.ones((z.shape[0],)) * 1
+ layer = 1
194
+ elif random_num<0.8:
195
+ n_quantizers = torch.ones((z.shape[0],)) * 2
196
+ layer = 2
197
+ else:
198
+ n_quantizers = torch.ones((z.shape[0],)) * 4
199
+ layer = 4
200
+ n_quantizers = n_quantizers.to(z.device)
201
+ else:
202
+ n_quantizers = torch.ones((z.shape[0],)) * n_quantizers
203
+ n_quantizers = n_quantizers.to(z.device)
204
+
205
+ for i, quantizer in enumerate(self.quantizers):
206
+ # if self.training is False and i >= n_quantizers:
207
+ # break
208
+
209
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
210
+ residual
211
+ )
212
+
213
+ # Create mask to apply quantizer dropout
214
+ mask = (
215
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
216
+ )
217
+ z_q = z_q + z_q_i * mask[:, None, None]
218
+ residual = residual - z_q_i
219
+
220
+ # Sum losses
221
+ commitment_loss += (commitment_loss_i * mask).mean()
222
+ codebook_loss += (codebook_loss_i * mask).mean()
223
+
224
+ codebook_indices.append(indices_i)
225
+ latents.append(z_e_i)
226
+
227
+ codes = torch.stack(codebook_indices, dim=1)
228
+ latents = torch.cat(latents, dim=1)
229
+
230
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
231
+ for n in range(encodings.shape[1]):
232
+ print("Layer {}, Ratio of unused vectors: {:.1f}".format(n,
233
+ (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
234
+ ))
235
+
236
+ return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1, layer
237
+
238
+ def from_codes(self, codes: torch.Tensor):
239
+ """Given the quantized codes, reconstruct the continuous representation
240
+ Parameters
241
+ ----------
242
+ codes : Tensor[B x N x T]
243
+ Quantized discrete representation of input
244
+ Returns
245
+ -------
246
+ Tensor[B x D x T]
247
+ Quantized continuous representation of input
248
+ """
249
+ z_q = 0.0
250
+ z_p = []
251
+ n_codebooks = codes.shape[1]
252
+ for i in range(n_codebooks):
253
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
254
+ z_p.append(z_p_i)
255
+
256
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
257
+ z_q = z_q + z_q_i
258
+ return z_q, torch.cat(z_p, dim=1), codes
259
+
260
+ def from_latents(self, latents: torch.Tensor):
261
+ """Given the unquantized latents, reconstruct the
262
+ continuous representation after quantization.
263
+
264
+ Parameters
265
+ ----------
266
+ latents : Tensor[B x N x T]
267
+ Continuous representation of input after projection
268
+
269
+ Returns
270
+ -------
271
+ Tensor[B x D x T]
272
+ Quantized representation of full-projected space
273
+ Tensor[B x D x T]
274
+ Quantized representation of latent space
275
+ """
276
+ z_q = 0
277
+ z_p = []
278
+ codes = []
279
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
280
+
281
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
282
+ 0
283
+ ]
284
+ for i in range(n_codebooks):
285
+ j, k = dims[i], dims[i + 1]
286
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
287
+ z_p.append(z_p_i)
288
+ codes.append(codes_i)
289
+
290
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
291
+ z_q = z_q + z_q_i
292
+
293
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
294
+
295
+
296
+ if __name__ == "__main__":
297
+ rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
298
+ x = torch.randn(16, 1024, 80)
299
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage, layer = rvq(x)
300
+ print(quantized_prompt_embeds.shape)
301
+ print(codes.shape)
302
+ # w/o reconstruction
303
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0
304
+ # w/ reconstruction
305
+ loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
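A tiny worked example of the `rvq_usage` value returned by these forward passes: `n_quantizers.clamp(max=self.n_codebooks).long() - 1` turns the per-sample codebook count into a zero-based index of the last active codebook (illustrative values only):

# How the per-sample codebook count maps to the returned rvq_usage index.
import torch

n_codebooks = 4
n_quantizers = torch.tensor([1., 2., 4., 4.])  # active codebooks per sample
rvq_usage = n_quantizers.clamp(max=n_codebooks).long() - 1
print(rvq_usage.tolist())  # [0, 1, 3, 3] -> index of the last codebook used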