Spaces:
Running
on
L40S
Running
on
L40S
hainazhu
commited on
Commit
·
258fd02
1
Parent(s):
51fab49
Add application file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +10 -33
- .gitignore +3 -0
- Dockerfile +13 -0
- LICENSE +211 -0
- README.md +63 -6
- app.py +140 -0
- codeclm/models/__init__.py +11 -0
- codeclm/models/builders.py +139 -0
- codeclm/models/codeclm.py +303 -0
- codeclm/models/levo.py +224 -0
- codeclm/models/llama/__init__.py +90 -0
- codeclm/models/llama/configuration_llama.py +182 -0
- codeclm/models/llama/convert_llama_weights_to_hf.py +318 -0
- codeclm/models/llama/modeling_llama.py +1243 -0
- codeclm/models/llama/tokenization_llama.py +426 -0
- codeclm/models/llama/tokenization_llama_fast.py +264 -0
- codeclm/models/lm_levo.py +546 -0
- codeclm/modules/conditioners.py +883 -0
- codeclm/modules/pattern.py +351 -0
- codeclm/modules/streaming.py +112 -0
- codeclm/tokenizer/Flow1dVAE/audio.py +304 -0
- codeclm/tokenizer/Flow1dVAE/cal_token_stat.py +19 -0
- codeclm/tokenizer/Flow1dVAE/compare_model_weight.py +13 -0
- codeclm/tokenizer/Flow1dVAE/configs/models/transformer2D_wocross_inch112_1x4_multi_large.json +26 -0
- codeclm/tokenizer/Flow1dVAE/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json +14 -0
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py +121 -0
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py +94 -0
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py +70 -0
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py +46 -0
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py +86 -0
- codeclm/tokenizer/Flow1dVAE/generate_1rvq.py +283 -0
- codeclm/tokenizer/Flow1dVAE/generate_2rvq.py +294 -0
- codeclm/tokenizer/Flow1dVAE/generate_4rvq.py +293 -0
- codeclm/tokenizer/Flow1dVAE/generate_septoken.py +302 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py +1278 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py +372 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py +830 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py +994 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py +313 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py +313 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py +313 -0
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py +461 -0
- codeclm/tokenizer/Flow1dVAE/libs/fsq/fsq.py +236 -0
- codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py +366 -0
- codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize.py +268 -0
- codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize2.py +290 -0
- codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3.py +299 -0
- codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer.py +303 -0
- codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_freezelayer1.py +301 -0
- codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_return_layer.py +305 -0
.gitattributes
CHANGED
@@ -1,35 +1,12 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
*.
|
33 |
-
*.
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
third_party/demucs/ckpt/htdemucs.pth filter=lfs diff=lfs merge=lfs -text
|
2 |
+
ckpt/100000_dpo.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
4 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
5 |
+
ckpt/vae/autoencoder_music_1320k.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e filter=lfs diff=lfs merge=lfs -text
|
7 |
+
codeclm/tokenizer/Flow1dVAE/third_party/wespeaker/voxceleb_resnet34_LM/voxceleb_resnet34_LM.onnx filter=lfs diff=lfs merge=lfs -text
|
8 |
+
codeclm/tokenizer/Flow1dVAE/third_party/wespeaker/voxceleb_resnet34_LM/voxceleb_resnet34_LM.pt filter=lfs diff=lfs merge=lfs -text
|
9 |
+
third_party/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
10 |
+
ckpt/60000_alnew.pt filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
launchs/
|
2 |
+
**/__pycache__
|
3 |
+
sample/generated/
|
Dockerfile
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM juhayna/song-generation-levo:v0.1
|
2 |
+
|
3 |
+
RUN useradd -m -u 1000 user
|
4 |
+
USER user
|
5 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
6 |
+
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
11 |
+
|
12 |
+
COPY --chown=user . /app
|
13 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
LICENSE
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Tencent is pleased to support the open source community by making SongGeneration available.
|
2 |
+
|
3 |
+
Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
|
4 |
+
|
5 |
+
SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
6 |
+
|
7 |
+
|
8 |
+
License Terms of SongGeneration:
|
9 |
+
--------------------------------------------------------------------
|
10 |
+
|
11 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
- You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
|
14 |
+
|
15 |
+
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components.
|
18 |
+
|
19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
20 |
+
|
21 |
+
|
22 |
+
Other dependencies and licenses:
|
23 |
+
|
24 |
+
|
25 |
+
Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein:
|
26 |
+
--------------------------------------------------------------------
|
27 |
+
1. stable_audio_tools
|
28 |
+
Copyright (c) 2023 Stability AI
|
29 |
+
|
30 |
+
|
31 |
+
Terms of the MIT:
|
32 |
+
--------------------------------------------------------------------
|
33 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
34 |
+
|
35 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
36 |
+
|
37 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
38 |
+
|
39 |
+
For the license of other third party components, please refer to the following URL:
|
40 |
+
https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES
|
41 |
+
|
42 |
+
|
43 |
+
Open Source Software Licensed under the MIT License:
|
44 |
+
--------------------------------------------------------------------
|
45 |
+
1. demucs
|
46 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
47 |
+
|
48 |
+
|
49 |
+
A copy of the MIT is included in this file.
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
54 |
+
--------------------------------------------------------------------
|
55 |
+
1. torch
|
56 |
+
From PyTorch:
|
57 |
+
|
58 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
59 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
60 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
61 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
62 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
63 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
64 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
65 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
66 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
67 |
+
|
68 |
+
From Caffe2:
|
69 |
+
|
70 |
+
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
71 |
+
|
72 |
+
All contributions by Facebook:
|
73 |
+
Copyright (c) 2016 Facebook Inc.
|
74 |
+
|
75 |
+
All contributions by Google:
|
76 |
+
Copyright (c) 2015 Google Inc.
|
77 |
+
All rights reserved.
|
78 |
+
|
79 |
+
All contributions by Yangqing Jia:
|
80 |
+
Copyright (c) 2015 Yangqing Jia
|
81 |
+
All rights reserved.
|
82 |
+
|
83 |
+
All contributions by Kakao Brain:
|
84 |
+
Copyright 2019-2020 Kakao Brain
|
85 |
+
|
86 |
+
All contributions by Cruise LLC:
|
87 |
+
Copyright (c) 2022 Cruise LLC.
|
88 |
+
All rights reserved.
|
89 |
+
|
90 |
+
All contributions from Caffe:
|
91 |
+
Copyright(c) 2013, 2014, 2015, the respective contributors
|
92 |
+
All rights reserved.
|
93 |
+
|
94 |
+
All other contributions:
|
95 |
+
Copyright(c) 2015, 2016 the respective contributors
|
96 |
+
All rights reserved.
|
97 |
+
|
98 |
+
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
99 |
+
copyright over their contributions to Caffe2. The project versioning records
|
100 |
+
all such contribution and copyright details. If a contributor wants to further
|
101 |
+
mark their specific copyright on a particular contribution, they should
|
102 |
+
indicate their copyright solely in the commit message of the change when it is
|
103 |
+
committed.
|
104 |
+
|
105 |
+
All rights reserved.
|
106 |
+
|
107 |
+
|
108 |
+
Terms of the BSD 3-Clause:
|
109 |
+
--------------------------------------------------------------------
|
110 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
111 |
+
|
112 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
113 |
+
|
114 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
115 |
+
|
116 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
117 |
+
|
118 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
119 |
+
|
120 |
+
For the license of other third party components, please refer to the following URL:
|
121 |
+
https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE
|
122 |
+
|
123 |
+
|
124 |
+
Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein:
|
125 |
+
--------------------------------------------------------------------
|
126 |
+
1. torchaudio
|
127 |
+
Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
|
128 |
+
All rights reserved.
|
129 |
+
|
130 |
+
|
131 |
+
Terms of the BSD 2-Clause:
|
132 |
+
--------------------------------------------------------------------
|
133 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
134 |
+
|
135 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
136 |
+
|
137 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
138 |
+
|
139 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
140 |
+
|
141 |
+
For the license of other third party components, please refer to the following URL:
|
142 |
+
https://github.com/pytorch/audio/blob/v2.0.2/LICENSE
|
143 |
+
|
144 |
+
|
145 |
+
Open Source Software License under the Apache License Version 2.0:
|
146 |
+
--------------------------------------------------------------------
|
147 |
+
1. huggingface-hub
|
148 |
+
Copyright (c) huggingface-hub original author and authors
|
149 |
+
|
150 |
+
2. transformers
|
151 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
152 |
+
|
153 |
+
|
154 |
+
Terms of the Apache License Version 2.0:
|
155 |
+
--------------------------------------------------------------------
|
156 |
+
Apache License
|
157 |
+
|
158 |
+
Version 2.0, January 2004
|
159 |
+
|
160 |
+
http://www.apache.org/licenses/
|
161 |
+
|
162 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
163 |
+
1. Definitions.
|
164 |
+
|
165 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
166 |
+
|
167 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
168 |
+
|
169 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
170 |
+
|
171 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
172 |
+
|
173 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
174 |
+
|
175 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
176 |
+
|
177 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
178 |
+
|
179 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
180 |
+
|
181 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
182 |
+
|
183 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
184 |
+
|
185 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
186 |
+
|
187 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
188 |
+
|
189 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
190 |
+
|
191 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
192 |
+
|
193 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
194 |
+
|
195 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
196 |
+
|
197 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
198 |
+
|
199 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
200 |
+
|
201 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
202 |
+
|
203 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
204 |
+
|
205 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
206 |
+
|
207 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
208 |
+
|
209 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
210 |
+
|
211 |
+
END OF TERMS AND CONDITIONS
|
README.md
CHANGED
@@ -1,11 +1,68 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: purple
|
5 |
-
colorTo:
|
6 |
sdk: docker
|
7 |
-
|
8 |
-
short_description: Demo interface for the LeVo song generation model.
|
9 |
---
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: LeVo Song Generation
|
3 |
+
emoji: 🎵
|
4 |
colorFrom: purple
|
5 |
+
colorTo: gray
|
6 |
sdk: docker
|
7 |
+
app_port: 7860
|
|
|
8 |
---
|
9 |
|
10 |
+
|
11 |
+
# SongGeration:
|
12 |
+
|
13 |
+
This repository is the official code repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. You can find our paper on [here](https://arxiv.org/). The demo page is available [here](https://levo-demo.github.io/).
|
14 |
+
|
15 |
+
In this repository, we provide the SongGeration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset. Specifically, we have released the model and inference code corresponding to the SFT + auto-DPO version.
|
16 |
+
|
17 |
+
## Installation
|
18 |
+
|
19 |
+
## Start from scatch
|
20 |
+
You can install the necessary dependencies using the `requirements.txt` file with Python 3.8.12:
|
21 |
+
|
22 |
+
```bash
|
23 |
+
pip install -r requirements.txt
|
24 |
+
```
|
25 |
+
|
26 |
+
then install flash attention from wget
|
27 |
+
|
28 |
+
```bash
|
29 |
+
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl -P /home/
|
30 |
+
pip install /home/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
31 |
+
```
|
32 |
+
|
33 |
+
## Start with docker
|
34 |
+
```bash
|
35 |
+
docker pull juhayna/song-generation-levo:v0.1
|
36 |
+
docker run -it --gpus all --network=host juhayna/song-generation-levo:v0.1 /bin/bash
|
37 |
+
```
|
38 |
+
|
39 |
+
## Inference
|
40 |
+
|
41 |
+
Please note that all the two folder below must be downloaded completely for the model to load correctly, which is sourced from [here](https://huggingface.co/waytan22/SongGeneration)
|
42 |
+
|
43 |
+
- Save `ckpt` to the root directory
|
44 |
+
- Save `third_party` to the root directory
|
45 |
+
|
46 |
+
Then run inference, use the following command:
|
47 |
+
|
48 |
+
```bash
|
49 |
+
sh generate.sh sample/lyric.jsonl sample/generate
|
50 |
+
```
|
51 |
+
- Input keys in the `sample/lyric.jsonl`
|
52 |
+
- `idx`: name of the generate song file
|
53 |
+
- `descriptions`: text description, can be None or specified gender, timbre, genre, mood, instrument and BPM
|
54 |
+
- `prompt_audio_path`: reference audio path, can be None or 10s song audio path
|
55 |
+
- `gt_lyric`: lyrics, it needs to follow the format of '\[Structure\] Text', supported structures can be found in `conf/vocab.yaml`
|
56 |
+
|
57 |
+
- Outputs of the loader `sample/generate`:
|
58 |
+
- `audio`: generated audio files
|
59 |
+
- `jsonl`: output jsonls
|
60 |
+
- `token`: Token corresponding to the generated audio files
|
61 |
+
|
62 |
+
## Note
|
63 |
+
|
64 |
+
Since the model is trained based on data longer than 1 minute, if the given lyrics are too short, the model will automatically fill in the lyrics to extend the duration.
|
65 |
+
|
66 |
+
## License
|
67 |
+
|
68 |
+
The code and weights in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.
|
app.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
from datetime import datetime
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import librosa
|
9 |
+
|
10 |
+
EXAMPLE_DESC = """female, dark, pop, sad, piano and drums, the bpm is 125."""
|
11 |
+
EXAMPLE_LYRICS = """
|
12 |
+
[intro-short]
|
13 |
+
|
14 |
+
[verse]
|
15 |
+
夜晚的街灯闪烁.
|
16 |
+
我漫步在熟悉的角落.
|
17 |
+
回忆像潮水般涌来.
|
18 |
+
你的笑容如此清晰.
|
19 |
+
在心头无法抹去.
|
20 |
+
那些曾经的甜蜜.
|
21 |
+
如今只剩我独自回忆.
|
22 |
+
|
23 |
+
[bridge]
|
24 |
+
手机屏幕亮起.
|
25 |
+
是你发来的消息.
|
26 |
+
简单的几个字.
|
27 |
+
却让我泪流满面.
|
28 |
+
曾经的拥抱温暖.
|
29 |
+
如今却变得遥远.
|
30 |
+
我多想回到从前.
|
31 |
+
重新拥有你的陪伴.
|
32 |
+
|
33 |
+
[chorus]
|
34 |
+
回忆的温度还在.
|
35 |
+
你却已不在.
|
36 |
+
我的心被爱填满.
|
37 |
+
却又被思念刺痛.
|
38 |
+
R&B的节奏奏响.
|
39 |
+
我的心却在流浪.
|
40 |
+
没有你的日子.
|
41 |
+
我该如何继续向前.
|
42 |
+
|
43 |
+
[outro-short]
|
44 |
+
""".strip()
|
45 |
+
|
46 |
+
|
47 |
+
# 模拟歌曲生成函数
|
48 |
+
def generate_song(description, lyric, prompt_audio=None):
|
49 |
+
# 这里模拟生成过程 - 实际应用中替换为你的模型调用
|
50 |
+
print(f"Generating song with description: {description}")
|
51 |
+
print(f"Lyrics provided: {lyric}")
|
52 |
+
if prompt_audio is not None:
|
53 |
+
print("Using prompt audio for generation")
|
54 |
+
|
55 |
+
# 从文件中加载示例音频
|
56 |
+
audio_path = "./sample/example.mp3"
|
57 |
+
audio_data, sample_rate = librosa.load(audio_path, sr=None) # 保持原始采样率
|
58 |
+
|
59 |
+
|
60 |
+
# 创建输入配置的JSON
|
61 |
+
input_config = {
|
62 |
+
"description": description,
|
63 |
+
"lyric": lyric,
|
64 |
+
"has_prompt_audio": prompt_audio is not None,
|
65 |
+
"timestamp": datetime.now().isoformat(),
|
66 |
+
}
|
67 |
+
|
68 |
+
return (sample_rate, audio_data), json.dumps(input_config, indent=2)
|
69 |
+
|
70 |
+
# 创建Gradio界面
|
71 |
+
with gr.Blocks(title="LeVo Demo Space") as demo:
|
72 |
+
gr.Markdown("# 🎵 LeVo Demo Space")
|
73 |
+
gr.Markdown("Demo interface for the LeVo song generation model. Provide a description, lyrics, and optionally an audio prompt to generate a custom song.")
|
74 |
+
|
75 |
+
with gr.Row():
|
76 |
+
with gr.Column():
|
77 |
+
description = gr.Textbox(
|
78 |
+
label="Song Description",
|
79 |
+
placeholder="Describe the style, mood, and characteristics of the song...",
|
80 |
+
lines=1,
|
81 |
+
max_lines=2,
|
82 |
+
value=EXAMPLE_DESC,
|
83 |
+
)
|
84 |
+
lyric = gr.Textbox(
|
85 |
+
label="Lyrics",
|
86 |
+
placeholder="Enter the lyrics for the song...",
|
87 |
+
lines=5,
|
88 |
+
max_lines=8,
|
89 |
+
value=EXAMPLE_LYRICS,
|
90 |
+
)
|
91 |
+
|
92 |
+
with gr.Tabs(elem_id="extra-tabs"):
|
93 |
+
with gr.Tab("Audio Prompt"):
|
94 |
+
prompt_audio = gr.Audio(
|
95 |
+
label="Prompt Audio (Optional)",
|
96 |
+
type="filepath",
|
97 |
+
elem_id="audio-prompt"
|
98 |
+
)
|
99 |
+
with gr.Tab("Advanced Config"):
|
100 |
+
text_prompt = gr.Textbox(
|
101 |
+
label="Text Prompt",
|
102 |
+
placeholder="Enter the Text Prompt, eg: emotional piano pop",
|
103 |
+
)
|
104 |
+
|
105 |
+
generate_btn = gr.Button("Generate Song", variant="primary")
|
106 |
+
|
107 |
+
with gr.Column():
|
108 |
+
output_audio = gr.Audio(label="Generated Song", type="numpy")
|
109 |
+
output_json = gr.JSON(label="Input Configuration")
|
110 |
+
|
111 |
+
# 示例按钮
|
112 |
+
examples = gr.Examples(
|
113 |
+
examples=[
|
114 |
+
["An uplifting pop song with catchy melodies"],
|
115 |
+
["Melancholic piano ballad"],
|
116 |
+
],
|
117 |
+
inputs=[description],
|
118 |
+
label="Description examples"
|
119 |
+
)
|
120 |
+
|
121 |
+
examples = gr.Examples(
|
122 |
+
examples=[
|
123 |
+
["Shine bright like the stars above\nYou're the one that I'm dreaming of"],
|
124 |
+
["The rain keeps falling on my window pane\nReminding me of love that's gone away"],
|
125 |
+
],
|
126 |
+
inputs=[lyric],
|
127 |
+
label="Lyrics examples"
|
128 |
+
)
|
129 |
+
|
130 |
+
# 生成按钮点击事件
|
131 |
+
generate_btn.click(
|
132 |
+
fn=generate_song,
|
133 |
+
inputs=[description, lyric, prompt_audio],
|
134 |
+
outputs=[output_audio, output_json]
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
# 启动应用
|
139 |
+
if __name__ == "__main__":
|
140 |
+
demo.launch()
|
codeclm/models/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
|
8 |
+
"""
|
9 |
+
# flake8: noqa
|
10 |
+
from . import builders
|
11 |
+
from .codeclm import CodecLM
|
codeclm/models/builders.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
All the functions to build the relevant models and modules
|
3 |
+
from the Hydra config.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import typing as tp
|
7 |
+
|
8 |
+
import omegaconf
|
9 |
+
import torch
|
10 |
+
from codeclm.utils.utils import dict_from_config
|
11 |
+
from codeclm.modules.pattern import (
|
12 |
+
CodebooksPatternProvider,
|
13 |
+
DelayedPatternProvider,
|
14 |
+
)
|
15 |
+
from codeclm.modules.conditioners import (
|
16 |
+
BaseConditioner,
|
17 |
+
QwTokenizerConditioner,
|
18 |
+
QwTextConditioner,
|
19 |
+
PhonemeTokenizerConditioner,
|
20 |
+
QuantizedEmbeddingConditioner,
|
21 |
+
ConditionerProvider,
|
22 |
+
ConditionFuser,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
def get_audio_tokenizer_model(checkpoint_path: str, cfg: omegaconf.DictConfig):
|
27 |
+
from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
|
28 |
+
"""Instantiate a compression model."""
|
29 |
+
if checkpoint_path is None:
|
30 |
+
return None
|
31 |
+
if checkpoint_path.startswith('//pretrained/'):
|
32 |
+
name = checkpoint_path.split('/', 3)[-1]
|
33 |
+
return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
|
34 |
+
elif checkpoint_path == "":
|
35 |
+
return None
|
36 |
+
else:
|
37 |
+
name = checkpoint_path
|
38 |
+
return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode)
|
39 |
+
|
40 |
+
def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
|
41 |
+
"""Instantiate a LM."""
|
42 |
+
lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
|
43 |
+
|
44 |
+
# n_q: number of RVQ
|
45 |
+
code_depth = lm_kwargs['code_depth']
|
46 |
+
q_modeling = lm_kwargs.pop('q_modeling', None)
|
47 |
+
|
48 |
+
# conditioner
|
49 |
+
condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg)
|
50 |
+
|
51 |
+
# codebook pattern: delay
|
52 |
+
codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
|
53 |
+
if codebooks_pattern_cfg.modeling is None:
|
54 |
+
assert q_modeling is not None, \
|
55 |
+
"LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
|
56 |
+
codebooks_pattern_cfg = omegaconf.OmegaConf.create(
|
57 |
+
{'modeling': q_modeling, 'delay': {'delays': list(range(code_depth))}}
|
58 |
+
)
|
59 |
+
pattern_provider = get_codebooks_pattern_provider(code_depth, codebooks_pattern_cfg)
|
60 |
+
|
61 |
+
# condition dropout
|
62 |
+
attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
|
63 |
+
cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
|
64 |
+
cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
|
65 |
+
|
66 |
+
# condition fuser
|
67 |
+
fuser = get_condition_fuser(cfg)
|
68 |
+
lm_type = lm_kwargs['lm_type'] # YCY: For consistency, choose different lm.py based on lm_type
|
69 |
+
if lm_type == 'Llama':
|
70 |
+
from .lm_levo import LmModel
|
71 |
+
return LmModel(
|
72 |
+
pattern_provider=pattern_provider,
|
73 |
+
condition_provider=condition_provider,
|
74 |
+
fuser=fuser,
|
75 |
+
cfg_dropout=cfg_prob,
|
76 |
+
cfg_coef=cfg_coef,
|
77 |
+
attribute_dropout=attribute_dropout,
|
78 |
+
cfg=cfg,
|
79 |
+
**lm_kwargs
|
80 |
+
).to('cpu')
|
81 |
+
else:
|
82 |
+
raise KeyError(f"Unexpected LM model {lm_type}")
|
83 |
+
|
84 |
+
|
85 |
+
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
|
86 |
+
"""Instantiate a conditioning model."""
|
87 |
+
cfg = getattr(cfg, 'conditioners')
|
88 |
+
dict_cfg = {} if cfg is None else dict_from_config(cfg)
|
89 |
+
conditioners: tp.Dict[str, BaseConditioner] = {}
|
90 |
+
condition_provider_args = dict_cfg.pop('args', {})
|
91 |
+
|
92 |
+
for cond, cond_cfg in dict_cfg.items():
|
93 |
+
model_type = cond_cfg['model']
|
94 |
+
model_args = cond_cfg[model_type]
|
95 |
+
if model_type == 'QwTokenizer':
|
96 |
+
conditioners[str(cond)] = QwTokenizerConditioner(
|
97 |
+
output_dim=output_dim,
|
98 |
+
**model_args
|
99 |
+
)
|
100 |
+
elif model_type == "QwTextTokenizer":
|
101 |
+
conditioners[str(cond)] = QwTextConditioner(
|
102 |
+
output_dim=output_dim,
|
103 |
+
**model_args
|
104 |
+
)
|
105 |
+
elif model_type == 'PhonemeTokenizer':
|
106 |
+
conditioners[str(cond)] = PhonemeTokenizerConditioner(
|
107 |
+
output_dim=output_dim,
|
108 |
+
**model_args
|
109 |
+
)
|
110 |
+
elif model_type == "qt_embedding":
|
111 |
+
conditioners[str(cond)] = QuantizedEmbeddingConditioner(
|
112 |
+
dim=output_dim,
|
113 |
+
**model_args
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
raise ValueError(f"Unrecognized conditioning model: {model_type}")
|
117 |
+
conditioner = ConditionerProvider(conditioners, **condition_provider_args)
|
118 |
+
return conditioner
|
119 |
+
|
120 |
+
|
121 |
+
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
|
122 |
+
"""Instantiate a condition fuser object."""
|
123 |
+
fuser_cfg = getattr(cfg, 'fuser')
|
124 |
+
fuser_methods = ['sum', 'prepend']
|
125 |
+
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
|
126 |
+
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
127 |
+
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
128 |
+
return fuser
|
129 |
+
|
130 |
+
|
131 |
+
def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
|
132 |
+
"""Instantiate a codebooks pattern provider object."""
|
133 |
+
pattern_providers = {
|
134 |
+
'delay': DelayedPatternProvider,
|
135 |
+
}
|
136 |
+
name = cfg.modeling
|
137 |
+
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
|
138 |
+
klass = pattern_providers[name]
|
139 |
+
return klass(code_depth, **kwargs)
|
codeclm/models/codeclm.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Main model for using CodecLM. This will combine all the required components
|
3 |
+
and provide easy access to the generation API.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import typing as tp
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
|
12 |
+
from .lm_levo import LmModel
|
13 |
+
from ..modules.conditioners import ConditioningAttributes, AudioCondition
|
14 |
+
from ..utils.autocast import TorchAutocast
|
15 |
+
import torch
|
16 |
+
from torch.nn import functional as F
|
17 |
+
import torchaudio
|
18 |
+
# from optim.ema import EMA
|
19 |
+
|
20 |
+
|
21 |
+
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
22 |
+
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
23 |
+
|
24 |
+
class CodecLM:
|
25 |
+
"""CodecLM main model with convenient generation API.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
name (str): name of the model.
|
29 |
+
compression_model (CompressionModel): Compression model
|
30 |
+
used to map audio to invertible discrete representations.
|
31 |
+
lm (LMModel): Language model over discrete representations.
|
32 |
+
max_duration (float, optional): maximum duration the model can produce,
|
33 |
+
otherwise, inferred from the training params.
|
34 |
+
"""
|
35 |
+
def __init__(self, name: str, audiotokenizer: AudioTokenizer, lm: LmModel,
|
36 |
+
max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None):
|
37 |
+
self.name = name
|
38 |
+
self.audiotokenizer = audiotokenizer
|
39 |
+
self.lm = lm
|
40 |
+
self.seperate_tokenizer = seperate_tokenizer
|
41 |
+
# import pdb; pdb.set_trace()
|
42 |
+
if max_duration is None:
|
43 |
+
if hasattr(lm, 'cfg'):
|
44 |
+
max_duration = lm.cfg.dataset.segment_duration # type: ignore
|
45 |
+
else:
|
46 |
+
raise ValueError("You must provide max_duration when building directly CodecLM")
|
47 |
+
assert max_duration is not None
|
48 |
+
|
49 |
+
self.max_duration: float = max_duration
|
50 |
+
self.device = next(iter(lm.parameters())).device
|
51 |
+
self.generation_params: dict = {}
|
52 |
+
# self.set_generation_params(duration=15) # 15 seconds by default
|
53 |
+
self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
|
54 |
+
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
55 |
+
if self.device.type == 'cpu':
|
56 |
+
self.autocast = TorchAutocast(enabled=False)
|
57 |
+
else:
|
58 |
+
self.autocast = TorchAutocast(enabled=False)
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
@property
|
63 |
+
def frame_rate(self) -> float:
|
64 |
+
"""Roughly the number of AR steps per seconds."""
|
65 |
+
return self.audiotokenizer.frame_rate
|
66 |
+
|
67 |
+
@property
|
68 |
+
def sample_rate(self) -> int:
|
69 |
+
"""Sample rate of the generated audio."""
|
70 |
+
return self.audiotokenizer.sample_rate
|
71 |
+
|
72 |
+
@property
|
73 |
+
def audio_channels(self) -> int:
|
74 |
+
"""Audio channels of the generated audio."""
|
75 |
+
return self.audiotokenizer.channels
|
76 |
+
|
77 |
+
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
78 |
+
top_p: float = 0.0, temperature: float = 1.0,
|
79 |
+
duration: float = 30.0, cfg_coef: float = 3.0,
|
80 |
+
extend_stride: float = 18, record_tokens: bool = False,
|
81 |
+
record_window: int = 50):
|
82 |
+
"""Set the generation parameters for CodecLM.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
86 |
+
top_k (int, optional): top_k used for sampling. Defaults to 250.
|
87 |
+
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
|
88 |
+
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
89 |
+
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
|
90 |
+
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
91 |
+
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
92 |
+
instead of batching together the two. This has some impact on how things
|
93 |
+
are padded but seems to have little impact in practice.
|
94 |
+
extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
|
95 |
+
should we extend the audio each time. Larger values will mean less context is
|
96 |
+
preserved, and shorter value will require extra computations.
|
97 |
+
"""
|
98 |
+
assert extend_stride <= self.max_duration, "Cannot stride by more than max generation duration."
|
99 |
+
self.extend_stride = extend_stride
|
100 |
+
self.duration = duration
|
101 |
+
self.generation_params = {
|
102 |
+
'use_sampling': use_sampling,
|
103 |
+
'temp': temperature,
|
104 |
+
'top_k': top_k,
|
105 |
+
'top_p': top_p,
|
106 |
+
'cfg_coef': cfg_coef,
|
107 |
+
'record_tokens': record_tokens,
|
108 |
+
'record_window': record_window,
|
109 |
+
}
|
110 |
+
|
111 |
+
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
112 |
+
"""Override the default progress callback."""
|
113 |
+
self._progress_callback = progress_callback
|
114 |
+
|
115 |
+
# Inference
|
116 |
+
def generate(self, lyrics: tp.List[str],
|
117 |
+
descriptions: tp.List[str],
|
118 |
+
melody_wavs: torch.Tensor = None,
|
119 |
+
melody_is_wav: bool = True,
|
120 |
+
vocal_wavs: torch.Tensor = None,
|
121 |
+
bgm_wavs: torch.Tensor = None,
|
122 |
+
return_tokens: bool = False,
|
123 |
+
) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
124 |
+
"""Generate samples conditioned on text and melody.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
128 |
+
melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
|
129 |
+
melody conditioning. Should have shape [B, C, T] with B matching the description length,
|
130 |
+
C=1 or 2. It can be [C, T] if there is a single description. It can also be
|
131 |
+
a list of [C, T] tensors.
|
132 |
+
melody_sample_rate: (int): Sample rate of the melody waveforms.
|
133 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
134 |
+
"""
|
135 |
+
if melody_wavs is not None:
|
136 |
+
if melody_wavs.dim() == 2:
|
137 |
+
melody_wavs = melody_wavs[None]
|
138 |
+
if melody_wavs.dim() != 3:
|
139 |
+
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
140 |
+
melody_wavs = list(melody_wavs)
|
141 |
+
if vocal_wavs is not None:
|
142 |
+
if vocal_wavs.dim() == 2:
|
143 |
+
vocal_wavs = vocal_wavs[None]
|
144 |
+
if vocal_wavs.dim() != 3:
|
145 |
+
raise ValueError("Vocal wavs should have a shape [B, C, T].")
|
146 |
+
vocal_wavs = list(vocal_wavs)
|
147 |
+
if bgm_wavs is not None:
|
148 |
+
if bgm_wavs.dim() == 2:
|
149 |
+
bgm_wavs = bgm_wavs[None]
|
150 |
+
if bgm_wavs.dim() != 3:
|
151 |
+
raise ValueError("BGM wavs should have a shape [B, C, T].")
|
152 |
+
bgm_wavs = list(bgm_wavs)
|
153 |
+
|
154 |
+
texts, audio_qt_embs = self._prepare_tokens_and_attributes(lyrics=lyrics, melody_wavs=melody_wavs, vocal_wavs=vocal_wavs, bgm_wavs=bgm_wavs, melody_is_wav=melody_is_wav)
|
155 |
+
tokens = self._generate_tokens(texts, descriptions, audio_qt_embs)
|
156 |
+
|
157 |
+
if (tokens == self.lm.eos_token_id).any():
|
158 |
+
length = torch.nonzero(torch.eq(tokens, self.lm.eos_token_id))[:,-1].min()
|
159 |
+
tokens = tokens[...,:length]
|
160 |
+
|
161 |
+
if return_tokens:
|
162 |
+
return tokens
|
163 |
+
else:
|
164 |
+
out = self.generate_audio(tokens)
|
165 |
+
return out
|
166 |
+
|
167 |
+
|
168 |
+
@torch.no_grad()
|
169 |
+
def _prepare_tokens_and_attributes(
|
170 |
+
self,
|
171 |
+
lyrics: tp.Sequence[tp.Optional[str]],
|
172 |
+
melody_wavs: tp.Optional[MelodyList] = None,
|
173 |
+
vocal_wavs: tp.Optional[MelodyList] = None,
|
174 |
+
bgm_wavs: tp.Optional[MelodyList] = None,
|
175 |
+
melody_is_wav = True
|
176 |
+
) -> tp.Tuple[tp.List[str], tp.List[torch.Tensor]]:
|
177 |
+
"""Prepare model inputs.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
181 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
182 |
+
melody_wavs (torch.Tensor, optional): A batch of waveforms
|
183 |
+
used as melody conditioning. Defaults to None.
|
184 |
+
"""
|
185 |
+
assert len(lyrics) == 1
|
186 |
+
texts = [lyric for lyric in lyrics]
|
187 |
+
audio_qt_embs = []
|
188 |
+
target_melody_token_len = self.lm.cfg.prompt_len * self.audiotokenizer.frame_rate
|
189 |
+
# import pdb; pdb.set_trace()
|
190 |
+
if melody_wavs is None:
|
191 |
+
melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
|
192 |
+
elif melody_wavs is not None:
|
193 |
+
if 'prompt_audio' not in self.lm.condition_provider.conditioners:
|
194 |
+
raise RuntimeError("This model doesn't support melody conditioning. "
|
195 |
+
"Use the `melody` model.")
|
196 |
+
assert len(melody_wavs) == len(texts), \
|
197 |
+
f"number of melody wavs must match number of descriptions! " \
|
198 |
+
f"got melody len={len(melody_wavs)}, and descriptions len={len(texts)}"
|
199 |
+
if type(melody_wavs) == list:
|
200 |
+
melody_wavs = torch.stack(melody_wavs, dim=0)
|
201 |
+
melody_wavs = melody_wavs.to(self.device)
|
202 |
+
if melody_is_wav:
|
203 |
+
melody_tokens, scale = self.audiotokenizer.encode(melody_wavs)
|
204 |
+
else:
|
205 |
+
melody_tokens = melody_wavs
|
206 |
+
if melody_tokens.shape[-1] > target_melody_token_len:
|
207 |
+
melody_tokens = melody_tokens[...,:target_melody_token_len]
|
208 |
+
elif melody_tokens.shape[-1] < target_melody_token_len:
|
209 |
+
melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
210 |
+
if self.seperate_tokenizer is not None:
|
211 |
+
if vocal_wavs is not None:
|
212 |
+
if type(vocal_wavs) == list:
|
213 |
+
vocal_wavs = torch.stack(vocal_wavs, dim=0)
|
214 |
+
if bgm_wavs is None:
|
215 |
+
use_bgm = False
|
216 |
+
bgm_wavs = torch.zeros_like(vocal_wavs)
|
217 |
+
bgm_wavs[:, 0] = 1.0
|
218 |
+
bgm_wavs[:, 1:] = torch.randn_like(bgm_wavs[:, 1:])* 0.0003
|
219 |
+
else:
|
220 |
+
use_bgm = True
|
221 |
+
if type(bgm_wavs) == list:
|
222 |
+
bgm_wavs = torch.stack(bgm_wavs, dim=0)
|
223 |
+
vocal_wavs = vocal_wavs.to(self.device)
|
224 |
+
bgm_wavs = bgm_wavs.to(self.device)
|
225 |
+
vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
|
226 |
+
assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
|
227 |
+
f"vocal and bgm tokens should have a shape [B, C, T]! " \
|
228 |
+
f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
|
229 |
+
assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
|
230 |
+
f"vocal and bgm tokens should have the same length! " \
|
231 |
+
f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
|
232 |
+
if not use_bgm:
|
233 |
+
bgm_tokens = torch.full_like(bgm_tokens, 16385)
|
234 |
+
if bgm_tokens.shape[-1] > target_melody_token_len:
|
235 |
+
bgm_tokens = bgm_tokens[...,:target_melody_token_len]
|
236 |
+
elif bgm_tokens.shape[-1] < target_melody_token_len:
|
237 |
+
bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
238 |
+
if vocal_tokens.shape[-1] > target_melody_token_len:
|
239 |
+
vocal_tokens = vocal_tokens[...,:target_melody_token_len]
|
240 |
+
elif vocal_tokens.shape[-1] < target_melody_token_len:
|
241 |
+
vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
|
242 |
+
else:
|
243 |
+
bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
|
244 |
+
vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
|
245 |
+
|
246 |
+
melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
|
247 |
+
assert melody_tokens.shape[-1] == target_melody_token_len
|
248 |
+
audio_qt_embs = melody_tokens.long()
|
249 |
+
return texts, audio_qt_embs
|
250 |
+
|
251 |
+
|
252 |
+
|
253 |
+
def _generate_tokens(self,
|
254 |
+
texts: tp.Optional[tp.List[str]] = None,
|
255 |
+
descriptions: tp.Optional[tp.List[str]] = None,
|
256 |
+
audio_qt_embs: tp.Optional[tp.List[torch.Tensor]] = None) -> torch.Tensor:
|
257 |
+
"""Generate discrete audio tokens given audio prompt and/or conditions.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
|
261 |
+
prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
|
262 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
263 |
+
Returns:
|
264 |
+
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
265 |
+
"""
|
266 |
+
total_gen_len = int(self.duration * self.frame_rate)
|
267 |
+
current_gen_offset: int = 0
|
268 |
+
|
269 |
+
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
270 |
+
generated_tokens += current_gen_offset
|
271 |
+
if self._progress_callback is not None:
|
272 |
+
# Note that total_gen_len might be quite wrong depending on the
|
273 |
+
# codebook pattern used, but with delay it is almost accurate.
|
274 |
+
self._progress_callback(generated_tokens, total_gen_len)
|
275 |
+
else:
|
276 |
+
print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
277 |
+
|
278 |
+
if self.duration <= self.max_duration:
|
279 |
+
# generate by sampling from LM, simple case.
|
280 |
+
with self.autocast:
|
281 |
+
gen_tokens = self.lm.generate(texts=texts,
|
282 |
+
descriptions=descriptions,
|
283 |
+
audio_qt_embs=audio_qt_embs,
|
284 |
+
max_gen_len=total_gen_len,
|
285 |
+
**self.generation_params)
|
286 |
+
else:
|
287 |
+
raise NotImplementedError(f"duration {self.duration} < max duration {self.max_duration}")
|
288 |
+
return gen_tokens
|
289 |
+
|
290 |
+
@torch.no_grad()
|
291 |
+
def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None):
|
292 |
+
"""Generate Audio from tokens"""
|
293 |
+
assert gen_tokens.dim() == 3
|
294 |
+
if self.seperate_tokenizer is not None:
|
295 |
+
gen_tokens_song = gen_tokens[:, [0], :]
|
296 |
+
gen_tokens_vocal = gen_tokens[:, [1], :]
|
297 |
+
gen_tokens_bgm = gen_tokens[:, [2], :]
|
298 |
+
# gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt)
|
299 |
+
gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt)
|
300 |
+
return gen_audio_seperate
|
301 |
+
else:
|
302 |
+
gen_audio = self.audiotokenizer.decode(gen_tokens, prompt)
|
303 |
+
return gen_audio
|
codeclm/models/levo.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .llama.modeling_llama import LlamaConfig, CausalLMOutputWithPast, BaseModelOutputWithPast, LlamaDecoderLayer, LlamaRMSNorm
|
3 |
+
from .llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLM_base
|
4 |
+
from .llama.modeling_llama import LlamaModel as LlamaModel_base
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from typing import Union, Optional, Tuple, List
|
9 |
+
from packaging import version
|
10 |
+
import transformers
|
11 |
+
"""
|
12 |
+
Wrap the original Llama model for potential customized changes.
|
13 |
+
"""
|
14 |
+
|
15 |
+
"""main class"""
|
16 |
+
class CausalLM(LlamaForCausalLM_base):
|
17 |
+
def __init__(self, config):
|
18 |
+
super().__init__(config)
|
19 |
+
self.model = LmModel(config)
|
20 |
+
self.vocab_size = config.vocab_size
|
21 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
22 |
+
|
23 |
+
def forward(
|
24 |
+
self,
|
25 |
+
input_ids: torch.LongTensor = None,
|
26 |
+
attention_mask: Optional[torch.Tensor] = None,
|
27 |
+
position_ids: Optional[torch.LongTensor] = None,
|
28 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
29 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
30 |
+
labels: Optional[torch.LongTensor] = None,
|
31 |
+
use_cache: Optional[bool] = None,
|
32 |
+
output_attentions: Optional[bool] = None,
|
33 |
+
output_hidden_states: Optional[bool] = None,
|
34 |
+
return_dict: Optional[bool] = None,
|
35 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
36 |
+
|
37 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
38 |
+
output_hidden_states = (
|
39 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
40 |
+
)
|
41 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
42 |
+
|
43 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
44 |
+
outputs = self.model(
|
45 |
+
input_ids=input_ids,
|
46 |
+
attention_mask=attention_mask,
|
47 |
+
position_ids=position_ids,
|
48 |
+
past_key_values=past_key_values,
|
49 |
+
inputs_embeds=inputs_embeds,
|
50 |
+
use_cache=use_cache,
|
51 |
+
output_attentions=output_attentions,
|
52 |
+
output_hidden_states=output_hidden_states,
|
53 |
+
return_dict=return_dict,
|
54 |
+
)
|
55 |
+
|
56 |
+
hidden_states = outputs[0]
|
57 |
+
if self.config.pretraining_tp > 1:
|
58 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
59 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
60 |
+
logits = torch.cat(logits, dim=-1)
|
61 |
+
else:
|
62 |
+
logits = self.lm_head(hidden_states)
|
63 |
+
logits = logits.float()
|
64 |
+
|
65 |
+
loss = None
|
66 |
+
if labels is not None:
|
67 |
+
# Shift so that tokens < n predict n
|
68 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
69 |
+
shift_labels = labels[..., 1:].contiguous()
|
70 |
+
# Flatten the tokens
|
71 |
+
loss_fct = nn.CrossEntropyLoss()
|
72 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
73 |
+
shift_labels = shift_labels.view(-1)
|
74 |
+
# Enable model parallelism
|
75 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
76 |
+
loss = loss_fct(shift_logits, shift_labels)
|
77 |
+
|
78 |
+
if not return_dict:
|
79 |
+
output = (logits,) + outputs[1:]
|
80 |
+
return (loss,) + output if loss is not None else output
|
81 |
+
|
82 |
+
return CausalLMOutputWithPast(
|
83 |
+
loss=loss,
|
84 |
+
logits=logits,
|
85 |
+
past_key_values=outputs.past_key_values,
|
86 |
+
hidden_states=hidden_states,
|
87 |
+
attentions=outputs.attentions,
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
"""Submodel class"""
|
92 |
+
class LmModel(LlamaModel_base):
|
93 |
+
def __init__(self, config: LlamaConfig):
|
94 |
+
super().__init__(config)
|
95 |
+
self.padding_idx = config.pad_token_id
|
96 |
+
self.vocab_size = config.vocab_size
|
97 |
+
layer_cls = LlamaDecoderLayer # cross attention decoder layer can be overwritten here
|
98 |
+
|
99 |
+
assert version.parse(transformers.__version__) < version.parse("4.40")
|
100 |
+
|
101 |
+
self.layers = nn.ModuleList([layer_cls(config) for _ in range(config.num_hidden_layers)])
|
102 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
103 |
+
|
104 |
+
self.gradient_checkpointing = False
|
105 |
+
# Initialize weights and apply final processing
|
106 |
+
self.post_init()
|
107 |
+
self.gradient_checkpointing_disable()
|
108 |
+
|
109 |
+
def forward(
|
110 |
+
self,
|
111 |
+
input_ids: torch.LongTensor = None,
|
112 |
+
attention_mask: Optional[torch.Tensor] = None,
|
113 |
+
position_ids: Optional[torch.LongTensor] = None,
|
114 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
115 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
116 |
+
use_cache: Optional[bool] = None,
|
117 |
+
output_attentions: Optional[bool] = None,
|
118 |
+
output_hidden_states: Optional[bool] = None,
|
119 |
+
return_dict: Optional[bool] = None,
|
120 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
121 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
122 |
+
output_hidden_states = (
|
123 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
124 |
+
)
|
125 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
126 |
+
|
127 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
128 |
+
|
129 |
+
# retrieve input_ids and inputs_embeds
|
130 |
+
if input_ids is not None and inputs_embeds is not None:
|
131 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
132 |
+
elif input_ids is not None:
|
133 |
+
batch_size, seq_length = input_ids.shape
|
134 |
+
elif inputs_embeds is not None:
|
135 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
136 |
+
else:
|
137 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
138 |
+
|
139 |
+
seq_length_with_past = seq_length
|
140 |
+
past_key_values_length = 0
|
141 |
+
|
142 |
+
if past_key_values is not None:
|
143 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
144 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
145 |
+
|
146 |
+
if position_ids is None:
|
147 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
148 |
+
position_ids = torch.arange(
|
149 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
150 |
+
)
|
151 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
152 |
+
else:
|
153 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
154 |
+
|
155 |
+
# embed positions
|
156 |
+
if attention_mask is None:
|
157 |
+
attention_mask = torch.ones(
|
158 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
159 |
+
)
|
160 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
161 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
162 |
+
)
|
163 |
+
|
164 |
+
hidden_states = inputs_embeds
|
165 |
+
|
166 |
+
if self.gradient_checkpointing and self.training:
|
167 |
+
if use_cache:
|
168 |
+
use_cache = False
|
169 |
+
|
170 |
+
# decoder layers
|
171 |
+
all_hidden_states = () if output_hidden_states else None
|
172 |
+
all_self_attns = () if output_attentions else None
|
173 |
+
next_decoder_cache = () if use_cache else None
|
174 |
+
|
175 |
+
for idx, decoder_layer in enumerate(self.layers):
|
176 |
+
if output_hidden_states:
|
177 |
+
all_hidden_states += (hidden_states,)
|
178 |
+
|
179 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
180 |
+
|
181 |
+
layer_args = (hidden_states, attention_mask, position_ids,)
|
182 |
+
|
183 |
+
if self.gradient_checkpointing and self.training:
|
184 |
+
|
185 |
+
def create_custom_forward(module):
|
186 |
+
def custom_forward(*inputs):
|
187 |
+
# None for past_key_value
|
188 |
+
return module(*inputs, past_key_value, output_attentions)
|
189 |
+
|
190 |
+
return custom_forward
|
191 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
192 |
+
create_custom_forward(decoder_layer), *layer_args
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
|
196 |
+
layer_outputs = decoder_layer(*layer_args,
|
197 |
+
past_key_value=past_key_value,
|
198 |
+
output_attentions=output_attentions,
|
199 |
+
use_cache=use_cache)
|
200 |
+
|
201 |
+
hidden_states = layer_outputs[0]
|
202 |
+
|
203 |
+
if use_cache:
|
204 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
205 |
+
|
206 |
+
if output_attentions:
|
207 |
+
all_self_attns += (layer_outputs[1],)
|
208 |
+
|
209 |
+
hidden_states = self.norm(hidden_states)
|
210 |
+
|
211 |
+
# add hidden states from the last decoder layer
|
212 |
+
if output_hidden_states:
|
213 |
+
all_hidden_states += (hidden_states,)
|
214 |
+
|
215 |
+
next_cache = next_decoder_cache if use_cache else None
|
216 |
+
if not return_dict:
|
217 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
218 |
+
return BaseModelOutputWithPast(
|
219 |
+
last_hidden_state=hidden_states,
|
220 |
+
past_key_values=next_cache,
|
221 |
+
hidden_states=all_hidden_states,
|
222 |
+
attentions=all_self_attns,
|
223 |
+
)
|
224 |
+
|
codeclm/models/llama/__init__.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import TYPE_CHECKING
|
15 |
+
|
16 |
+
from transformers.utils import (
|
17 |
+
OptionalDependencyNotAvailable,
|
18 |
+
_LazyModule,
|
19 |
+
is_sentencepiece_available,
|
20 |
+
is_tokenizers_available,
|
21 |
+
is_torch_available,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
_import_structure = {
|
26 |
+
"configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"],
|
27 |
+
}
|
28 |
+
|
29 |
+
try:
|
30 |
+
if not is_sentencepiece_available():
|
31 |
+
raise OptionalDependencyNotAvailable()
|
32 |
+
except OptionalDependencyNotAvailable:
|
33 |
+
pass
|
34 |
+
else:
|
35 |
+
_import_structure["tokenization_llama"] = ["LlamaTokenizer"]
|
36 |
+
|
37 |
+
try:
|
38 |
+
if not is_tokenizers_available():
|
39 |
+
raise OptionalDependencyNotAvailable()
|
40 |
+
except OptionalDependencyNotAvailable:
|
41 |
+
pass
|
42 |
+
else:
|
43 |
+
_import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
|
44 |
+
|
45 |
+
try:
|
46 |
+
if not is_torch_available():
|
47 |
+
raise OptionalDependencyNotAvailable()
|
48 |
+
except OptionalDependencyNotAvailable:
|
49 |
+
pass
|
50 |
+
else:
|
51 |
+
_import_structure["modeling_llama"] = [
|
52 |
+
"LlamaForCausalLM",
|
53 |
+
"LlamaModel",
|
54 |
+
"LlamaPreTrainedModel",
|
55 |
+
"LlamaForSequenceClassification",
|
56 |
+
]
|
57 |
+
|
58 |
+
|
59 |
+
if TYPE_CHECKING:
|
60 |
+
from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig
|
61 |
+
|
62 |
+
try:
|
63 |
+
if not is_sentencepiece_available():
|
64 |
+
raise OptionalDependencyNotAvailable()
|
65 |
+
except OptionalDependencyNotAvailable:
|
66 |
+
pass
|
67 |
+
else:
|
68 |
+
from .tokenization_llama import LlamaTokenizer
|
69 |
+
|
70 |
+
try:
|
71 |
+
if not is_tokenizers_available():
|
72 |
+
raise OptionalDependencyNotAvailable()
|
73 |
+
except OptionalDependencyNotAvailable:
|
74 |
+
pass
|
75 |
+
else:
|
76 |
+
from .tokenization_llama_fast import LlamaTokenizerFast
|
77 |
+
|
78 |
+
try:
|
79 |
+
if not is_torch_available():
|
80 |
+
raise OptionalDependencyNotAvailable()
|
81 |
+
except OptionalDependencyNotAvailable:
|
82 |
+
pass
|
83 |
+
else:
|
84 |
+
from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
|
85 |
+
|
86 |
+
|
87 |
+
else:
|
88 |
+
import sys
|
89 |
+
|
90 |
+
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
codeclm/models/llama/configuration_llama.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" LLaMA model configuration"""
|
21 |
+
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from transformers.utils import logging
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.get_logger(__name__)
|
27 |
+
|
28 |
+
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
29 |
+
|
30 |
+
|
31 |
+
class LlamaConfig(PretrainedConfig):
|
32 |
+
r"""
|
33 |
+
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
|
34 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
35 |
+
defaults will yield a similar configuration to that of the LLaMA-7B.
|
36 |
+
|
37 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
38 |
+
documentation from [`PretrainedConfig`] for more information.
|
39 |
+
|
40 |
+
|
41 |
+
Args:
|
42 |
+
vocab_size (`int`, *optional*, defaults to 32000):
|
43 |
+
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
|
44 |
+
`inputs_ids` passed when calling [`LlamaModel`]
|
45 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
46 |
+
Dimension of the hidden representations.
|
47 |
+
intermediate_size (`int`, *optional*, defaults to 11008):
|
48 |
+
Dimension of the MLP representations.
|
49 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
50 |
+
Number of hidden layers in the Transformer encoder.
|
51 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
52 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
53 |
+
num_key_value_heads (`int`, *optional*):
|
54 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
55 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
56 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
57 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
58 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
59 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
60 |
+
`num_attention_heads`.
|
61 |
+
pretraining_tp (`int`, *optional*, defaults to `1`):
|
62 |
+
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
63 |
+
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
64 |
+
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
65 |
+
issue](https://github.com/pytorch/pytorch/issues/76232).
|
66 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
67 |
+
The non-linear activation function (function or string) in the decoder.
|
68 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
69 |
+
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
|
70 |
+
Llama 2 up to 4096, CodeLlama up to 16384.
|
71 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
72 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
73 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-12):
|
74 |
+
The epsilon used by the rms normalization layers.
|
75 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
76 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
77 |
+
relevant if `config.is_decoder=True`.
|
78 |
+
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
79 |
+
Whether to tie weight embeddings
|
80 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
81 |
+
The base period of the RoPE embeddings.
|
82 |
+
rope_scaling (`Dict`, *optional*):
|
83 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
84 |
+
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
85 |
+
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
86 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
87 |
+
these scaling strategies behave:
|
88 |
+
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
89 |
+
experimental feature, subject to breaking API changes in future versions.
|
90 |
+
attention_bias (`bool`, defaults to `False`):
|
91 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
92 |
+
|
93 |
+
Example:
|
94 |
+
|
95 |
+
```python
|
96 |
+
>>> from transformers import LlamaModel, LlamaConfig
|
97 |
+
|
98 |
+
>>> # Initializing a LLaMA llama-7b style configuration
|
99 |
+
>>> configuration = LlamaConfig()
|
100 |
+
|
101 |
+
>>> # Initializing a model from the llama-7b style configuration
|
102 |
+
>>> model = LlamaModel(configuration)
|
103 |
+
|
104 |
+
>>> # Accessing the model configuration
|
105 |
+
>>> configuration = model.config
|
106 |
+
```"""
|
107 |
+
model_type = "llama"
|
108 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
vocab_size=32000,
|
113 |
+
hidden_size=4096,
|
114 |
+
intermediate_size=11008,
|
115 |
+
num_hidden_layers=32,
|
116 |
+
num_attention_heads=32,
|
117 |
+
num_key_value_heads=None,
|
118 |
+
hidden_act="silu",
|
119 |
+
max_position_embeddings=2048,
|
120 |
+
initializer_range=0.02,
|
121 |
+
rms_norm_eps=1e-6,
|
122 |
+
use_cache=True,
|
123 |
+
pad_token_id=None,
|
124 |
+
bos_token_id=1,
|
125 |
+
eos_token_id=2,
|
126 |
+
pretraining_tp=1,
|
127 |
+
tie_word_embeddings=False,
|
128 |
+
rope_theta=10000.0,
|
129 |
+
rope_scaling=None,
|
130 |
+
attention_bias=False,
|
131 |
+
**kwargs,
|
132 |
+
):
|
133 |
+
self.vocab_size = vocab_size
|
134 |
+
self.max_position_embeddings = max_position_embeddings
|
135 |
+
self.hidden_size = hidden_size
|
136 |
+
self.intermediate_size = intermediate_size
|
137 |
+
self.num_hidden_layers = num_hidden_layers
|
138 |
+
self.num_attention_heads = num_attention_heads
|
139 |
+
|
140 |
+
# for backward compatibility
|
141 |
+
if num_key_value_heads is None:
|
142 |
+
num_key_value_heads = num_attention_heads
|
143 |
+
|
144 |
+
self.num_key_value_heads = num_key_value_heads
|
145 |
+
self.hidden_act = hidden_act
|
146 |
+
self.initializer_range = initializer_range
|
147 |
+
self.rms_norm_eps = rms_norm_eps
|
148 |
+
self.pretraining_tp = pretraining_tp
|
149 |
+
self.use_cache = use_cache
|
150 |
+
self.rope_theta = rope_theta
|
151 |
+
self.rope_scaling = rope_scaling
|
152 |
+
self._rope_scaling_validation()
|
153 |
+
self.attention_bias = attention_bias
|
154 |
+
|
155 |
+
super().__init__(
|
156 |
+
pad_token_id=pad_token_id,
|
157 |
+
bos_token_id=bos_token_id,
|
158 |
+
eos_token_id=eos_token_id,
|
159 |
+
tie_word_embeddings=tie_word_embeddings,
|
160 |
+
**kwargs,
|
161 |
+
)
|
162 |
+
|
163 |
+
def _rope_scaling_validation(self):
|
164 |
+
"""
|
165 |
+
Validate the `rope_scaling` configuration.
|
166 |
+
"""
|
167 |
+
if self.rope_scaling is None:
|
168 |
+
return
|
169 |
+
|
170 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
171 |
+
raise ValueError(
|
172 |
+
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
173 |
+
f"got {self.rope_scaling}"
|
174 |
+
)
|
175 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
176 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
177 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
178 |
+
raise ValueError(
|
179 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
180 |
+
)
|
181 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
182 |
+
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
codeclm/models/llama/convert_llama_weights_to_hf.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import argparse
|
15 |
+
import gc
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import shutil
|
19 |
+
import warnings
|
20 |
+
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
24 |
+
|
25 |
+
|
26 |
+
try:
|
27 |
+
from transformers import LlamaTokenizerFast
|
28 |
+
except ImportError as e:
|
29 |
+
warnings.warn(e)
|
30 |
+
warnings.warn(
|
31 |
+
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
|
32 |
+
)
|
33 |
+
LlamaTokenizerFast = None
|
34 |
+
|
35 |
+
"""
|
36 |
+
Sample usage:
|
37 |
+
|
38 |
+
```
|
39 |
+
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
40 |
+
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
41 |
+
```
|
42 |
+
|
43 |
+
Thereafter, models can be loaded via:
|
44 |
+
|
45 |
+
```py
|
46 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
47 |
+
|
48 |
+
model = LlamaForCausalLM.from_pretrained("/output/path")
|
49 |
+
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
50 |
+
```
|
51 |
+
|
52 |
+
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
53 |
+
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
54 |
+
"""
|
55 |
+
|
56 |
+
NUM_SHARDS = {
|
57 |
+
"7B": 1,
|
58 |
+
"7Bf": 1,
|
59 |
+
"13B": 2,
|
60 |
+
"13Bf": 2,
|
61 |
+
"34B": 4,
|
62 |
+
"30B": 4,
|
63 |
+
"65B": 8,
|
64 |
+
"70B": 8,
|
65 |
+
"70Bf": 8,
|
66 |
+
}
|
67 |
+
|
68 |
+
|
69 |
+
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
|
70 |
+
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
|
71 |
+
|
72 |
+
|
73 |
+
def read_json(path):
|
74 |
+
with open(path, "r") as f:
|
75 |
+
return json.load(f)
|
76 |
+
|
77 |
+
|
78 |
+
def write_json(text, path):
|
79 |
+
with open(path, "w") as f:
|
80 |
+
json.dump(text, f)
|
81 |
+
|
82 |
+
|
83 |
+
def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
|
84 |
+
# for backward compatibility, before you needed the repo to be called `my_repo/model_size`
|
85 |
+
if not os.path.isfile(os.path.join(input_base_path, "params.json")):
|
86 |
+
input_base_path = os.path.join(input_base_path, model_size)
|
87 |
+
|
88 |
+
os.makedirs(model_path, exist_ok=True)
|
89 |
+
tmp_model_path = os.path.join(model_path, "tmp")
|
90 |
+
os.makedirs(tmp_model_path, exist_ok=True)
|
91 |
+
|
92 |
+
params = read_json(os.path.join(input_base_path, "params.json"))
|
93 |
+
num_shards = NUM_SHARDS[model_size]
|
94 |
+
n_layers = params["n_layers"]
|
95 |
+
n_heads = params["n_heads"]
|
96 |
+
n_heads_per_shard = n_heads // num_shards
|
97 |
+
dim = params["dim"]
|
98 |
+
dims_per_head = dim // n_heads
|
99 |
+
base = params.get("rope_theta", 10000.0)
|
100 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
101 |
+
if base > 10000.0:
|
102 |
+
max_position_embeddings = 16384
|
103 |
+
else:
|
104 |
+
max_position_embeddings = 2048
|
105 |
+
|
106 |
+
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
|
107 |
+
if tokenizer_path is not None:
|
108 |
+
tokenizer = tokenizer_class(tokenizer_path)
|
109 |
+
tokenizer.save_pretrained(model_path)
|
110 |
+
vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
|
111 |
+
|
112 |
+
if "n_kv_heads" in params:
|
113 |
+
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
|
114 |
+
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
|
115 |
+
key_value_dim = dim // num_key_value_heads
|
116 |
+
else: # compatibility with other checkpoints
|
117 |
+
num_key_value_heads = n_heads
|
118 |
+
num_local_key_value_heads = n_heads_per_shard
|
119 |
+
key_value_dim = dim
|
120 |
+
|
121 |
+
# permute for sliced rotary
|
122 |
+
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
|
123 |
+
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
124 |
+
|
125 |
+
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
126 |
+
# Load weights
|
127 |
+
if model_size == "7B":
|
128 |
+
# Not sharded
|
129 |
+
# (The sharded implementation would also work, but this is simpler.)
|
130 |
+
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
|
131 |
+
else:
|
132 |
+
# Sharded
|
133 |
+
loaded = [
|
134 |
+
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
|
135 |
+
for i in range(num_shards)
|
136 |
+
]
|
137 |
+
param_count = 0
|
138 |
+
index_dict = {"weight_map": {}}
|
139 |
+
for layer_i in range(n_layers):
|
140 |
+
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
|
141 |
+
if model_size == "7B":
|
142 |
+
# Unsharded
|
143 |
+
state_dict = {
|
144 |
+
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
145 |
+
loaded[f"layers.{layer_i}.attention.wq.weight"]
|
146 |
+
),
|
147 |
+
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
148 |
+
loaded[f"layers.{layer_i}.attention.wk.weight"]
|
149 |
+
),
|
150 |
+
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
|
151 |
+
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
|
152 |
+
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
|
153 |
+
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
|
154 |
+
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
|
155 |
+
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
|
156 |
+
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
|
157 |
+
}
|
158 |
+
else:
|
159 |
+
# Sharded
|
160 |
+
# Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
|
161 |
+
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
|
162 |
+
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
|
163 |
+
|
164 |
+
state_dict = {
|
165 |
+
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
|
166 |
+
f"layers.{layer_i}.attention_norm.weight"
|
167 |
+
].clone(),
|
168 |
+
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
|
169 |
+
f"layers.{layer_i}.ffn_norm.weight"
|
170 |
+
].clone(),
|
171 |
+
}
|
172 |
+
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
|
173 |
+
torch.cat(
|
174 |
+
[
|
175 |
+
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
176 |
+
for i in range(num_shards)
|
177 |
+
],
|
178 |
+
dim=0,
|
179 |
+
).reshape(dim, dim)
|
180 |
+
)
|
181 |
+
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
|
182 |
+
torch.cat(
|
183 |
+
[
|
184 |
+
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
|
185 |
+
num_local_key_value_heads, dims_per_head, dim
|
186 |
+
)
|
187 |
+
for i in range(num_shards)
|
188 |
+
],
|
189 |
+
dim=0,
|
190 |
+
).reshape(key_value_dim, dim),
|
191 |
+
num_key_value_heads,
|
192 |
+
key_value_dim,
|
193 |
+
dim,
|
194 |
+
)
|
195 |
+
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
|
196 |
+
[
|
197 |
+
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
|
198 |
+
num_local_key_value_heads, dims_per_head, dim
|
199 |
+
)
|
200 |
+
for i in range(num_shards)
|
201 |
+
],
|
202 |
+
dim=0,
|
203 |
+
).reshape(key_value_dim, dim)
|
204 |
+
|
205 |
+
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
|
206 |
+
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
|
207 |
+
)
|
208 |
+
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
|
209 |
+
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
|
210 |
+
)
|
211 |
+
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
|
212 |
+
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
|
213 |
+
)
|
214 |
+
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
|
215 |
+
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
|
216 |
+
)
|
217 |
+
|
218 |
+
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
219 |
+
for k, v in state_dict.items():
|
220 |
+
index_dict["weight_map"][k] = filename
|
221 |
+
param_count += v.numel()
|
222 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
223 |
+
|
224 |
+
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
|
225 |
+
if model_size == "7B":
|
226 |
+
# Unsharded
|
227 |
+
state_dict = {
|
228 |
+
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
|
229 |
+
"model.norm.weight": loaded["norm.weight"],
|
230 |
+
"lm_head.weight": loaded["output.weight"],
|
231 |
+
}
|
232 |
+
else:
|
233 |
+
state_dict = {
|
234 |
+
"model.norm.weight": loaded[0]["norm.weight"],
|
235 |
+
"model.embed_tokens.weight": torch.cat(
|
236 |
+
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
|
237 |
+
),
|
238 |
+
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
|
239 |
+
}
|
240 |
+
|
241 |
+
for k, v in state_dict.items():
|
242 |
+
index_dict["weight_map"][k] = filename
|
243 |
+
param_count += v.numel()
|
244 |
+
torch.save(state_dict, os.path.join(tmp_model_path, filename))
|
245 |
+
|
246 |
+
# Write configs
|
247 |
+
index_dict["metadata"] = {"total_size": param_count * 2}
|
248 |
+
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
|
249 |
+
ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1
|
250 |
+
multiple_of = params["multiple_of"] if "multiple_of" in params else 256
|
251 |
+
config = LlamaConfig(
|
252 |
+
hidden_size=dim,
|
253 |
+
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
|
254 |
+
num_attention_heads=params["n_heads"],
|
255 |
+
num_hidden_layers=params["n_layers"],
|
256 |
+
rms_norm_eps=params["norm_eps"],
|
257 |
+
num_key_value_heads=num_key_value_heads,
|
258 |
+
vocab_size=vocab_size,
|
259 |
+
rope_theta=base,
|
260 |
+
max_position_embeddings=max_position_embeddings,
|
261 |
+
)
|
262 |
+
config.save_pretrained(tmp_model_path)
|
263 |
+
|
264 |
+
# Make space so we can load the model properly now.
|
265 |
+
del state_dict
|
266 |
+
del loaded
|
267 |
+
gc.collect()
|
268 |
+
|
269 |
+
print("Loading the checkpoint in a Llama model.")
|
270 |
+
model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
|
271 |
+
# Avoid saving this as part of the config.
|
272 |
+
del model.config._name_or_path
|
273 |
+
model.config.torch_dtype = torch.float16
|
274 |
+
print("Saving in the Transformers format.")
|
275 |
+
model.save_pretrained(model_path, safe_serialization=safe_serialization)
|
276 |
+
shutil.rmtree(tmp_model_path)
|
277 |
+
|
278 |
+
|
279 |
+
def write_tokenizer(tokenizer_path, input_tokenizer_path):
|
280 |
+
# Initialize the tokenizer based on the `spm` model
|
281 |
+
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
|
282 |
+
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
|
283 |
+
tokenizer = tokenizer_class(input_tokenizer_path)
|
284 |
+
tokenizer.save_pretrained(tokenizer_path)
|
285 |
+
|
286 |
+
|
287 |
+
def main():
|
288 |
+
parser = argparse.ArgumentParser()
|
289 |
+
parser.add_argument(
|
290 |
+
"--input_dir",
|
291 |
+
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
292 |
+
)
|
293 |
+
parser.add_argument(
|
294 |
+
"--model_size",
|
295 |
+
choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"],
|
296 |
+
help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama",
|
297 |
+
)
|
298 |
+
parser.add_argument(
|
299 |
+
"--output_dir",
|
300 |
+
help="Location to write HF model and tokenizer",
|
301 |
+
)
|
302 |
+
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
|
303 |
+
args = parser.parse_args()
|
304 |
+
spm_path = os.path.join(args.input_dir, "tokenizer.model")
|
305 |
+
if args.model_size != "tokenizer_only":
|
306 |
+
write_model(
|
307 |
+
model_path=args.output_dir,
|
308 |
+
input_base_path=args.input_dir,
|
309 |
+
model_size=args.model_size,
|
310 |
+
safe_serialization=args.safe_serialization,
|
311 |
+
tokenizer_path=spm_path,
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
write_tokenizer(args.output_dir, spm_path)
|
315 |
+
|
316 |
+
|
317 |
+
if __name__ == "__main__":
|
318 |
+
main()
|
codeclm/models/llama/modeling_llama.py
ADDED
@@ -0,0 +1,1243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
""" PyTorch LLaMA model."""
|
21 |
+
import math
|
22 |
+
from typing import List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.nn.functional as F
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from torch import nn
|
28 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
29 |
+
|
30 |
+
from transformers.activations import ACT2FN
|
31 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
32 |
+
from transformers.modeling_utils import PreTrainedModel
|
33 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
34 |
+
from transformers.utils import (
|
35 |
+
add_start_docstrings,
|
36 |
+
add_start_docstrings_to_model_forward,
|
37 |
+
is_flash_attn_available,
|
38 |
+
logging,
|
39 |
+
replace_return_docstrings,
|
40 |
+
)
|
41 |
+
from .configuration_llama import LlamaConfig
|
42 |
+
|
43 |
+
|
44 |
+
if is_flash_attn_available():
|
45 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
46 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
47 |
+
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
_CONFIG_FOR_DOC = "LlamaConfig"
|
52 |
+
|
53 |
+
|
54 |
+
def _get_unpad_data(padding_mask):
|
55 |
+
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
|
56 |
+
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
|
57 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
58 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
59 |
+
return (
|
60 |
+
indices,
|
61 |
+
cu_seqlens,
|
62 |
+
max_seqlen_in_batch,
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
67 |
+
def _make_causal_mask(
|
68 |
+
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Make causal mask used for bi-directional self-attention.
|
72 |
+
"""
|
73 |
+
bsz, tgt_len = input_ids_shape
|
74 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
75 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
76 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
77 |
+
mask = mask.to(dtype)
|
78 |
+
|
79 |
+
if past_key_values_length > 0:
|
80 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
81 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
82 |
+
|
83 |
+
|
84 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
85 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
86 |
+
"""
|
87 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
88 |
+
"""
|
89 |
+
bsz, src_len = mask.size()
|
90 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
91 |
+
|
92 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
93 |
+
|
94 |
+
inverted_mask = 1.0 - expanded_mask
|
95 |
+
|
96 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
97 |
+
|
98 |
+
|
99 |
+
class LlamaRMSNorm(nn.Module):
|
100 |
+
def __init__(self, hidden_size, eps=1e-6):
|
101 |
+
"""
|
102 |
+
LlamaRMSNorm is equivalent to T5LayerNorm
|
103 |
+
"""
|
104 |
+
super().__init__()
|
105 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
106 |
+
self.variance_epsilon = eps
|
107 |
+
|
108 |
+
def forward(self, hidden_states):
|
109 |
+
input_dtype = hidden_states.dtype
|
110 |
+
hidden_states = hidden_states.to(torch.float32)
|
111 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
112 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
113 |
+
return self.weight * hidden_states.to(input_dtype)
|
114 |
+
|
115 |
+
|
116 |
+
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
117 |
+
|
118 |
+
|
119 |
+
class LlamaRotaryEmbedding(nn.Module):
|
120 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
121 |
+
super().__init__()
|
122 |
+
|
123 |
+
self.dim = dim
|
124 |
+
self.max_position_embeddings = max_position_embeddings
|
125 |
+
self.base = base
|
126 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
127 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
128 |
+
|
129 |
+
# Build here to make `torch.jit.trace` work.
|
130 |
+
self._set_cos_sin_cache(
|
131 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
132 |
+
)
|
133 |
+
|
134 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
135 |
+
self.max_seq_len_cached = seq_len
|
136 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
137 |
+
|
138 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
139 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
140 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
141 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
142 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
143 |
+
|
144 |
+
def forward(self, x, seq_len=None):
|
145 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
146 |
+
if seq_len > self.max_seq_len_cached:
|
147 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
148 |
+
|
149 |
+
return (
|
150 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
151 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
152 |
+
)
|
153 |
+
|
154 |
+
|
155 |
+
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
156 |
+
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
157 |
+
|
158 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
159 |
+
self.scaling_factor = scaling_factor
|
160 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
161 |
+
|
162 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
163 |
+
self.max_seq_len_cached = seq_len
|
164 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
165 |
+
t = t / self.scaling_factor
|
166 |
+
|
167 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
168 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
169 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
170 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
171 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
172 |
+
|
173 |
+
|
174 |
+
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
175 |
+
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
176 |
+
|
177 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
178 |
+
self.scaling_factor = scaling_factor
|
179 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
180 |
+
|
181 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
182 |
+
self.max_seq_len_cached = seq_len
|
183 |
+
|
184 |
+
if seq_len > self.max_position_embeddings:
|
185 |
+
base = self.base * (
|
186 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
187 |
+
) ** (self.dim / (self.dim - 2))
|
188 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
189 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
190 |
+
|
191 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
192 |
+
|
193 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
194 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
195 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
196 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
197 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
198 |
+
|
199 |
+
|
200 |
+
def rotate_half(x):
|
201 |
+
"""Rotates half the hidden dims of the input."""
|
202 |
+
x1 = x[..., : x.shape[-1] // 2]
|
203 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
204 |
+
return torch.cat((-x2, x1), dim=-1)
|
205 |
+
|
206 |
+
|
207 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
208 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
209 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
210 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
211 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
212 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
213 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
214 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
215 |
+
return q_embed, k_embed
|
216 |
+
|
217 |
+
|
218 |
+
class LlamaMLP(nn.Module):
|
219 |
+
def __init__(self, config):
|
220 |
+
super().__init__()
|
221 |
+
self.config = config
|
222 |
+
self.hidden_size = config.hidden_size
|
223 |
+
self.intermediate_size = config.intermediate_size
|
224 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
225 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
226 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
227 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
if self.config.pretraining_tp > 1:
|
231 |
+
slice = self.intermediate_size // self.config.pretraining_tp
|
232 |
+
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
233 |
+
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
234 |
+
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
235 |
+
|
236 |
+
gate_proj = torch.cat(
|
237 |
+
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
|
238 |
+
)
|
239 |
+
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
240 |
+
|
241 |
+
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
242 |
+
down_proj = [
|
243 |
+
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
|
244 |
+
]
|
245 |
+
down_proj = sum(down_proj)
|
246 |
+
else:
|
247 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
248 |
+
|
249 |
+
return down_proj
|
250 |
+
|
251 |
+
|
252 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
253 |
+
"""
|
254 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
255 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
256 |
+
"""
|
257 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
258 |
+
if n_rep == 1:
|
259 |
+
return hidden_states
|
260 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
261 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
262 |
+
|
263 |
+
|
264 |
+
class LlamaAttention(nn.Module):
|
265 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
266 |
+
|
267 |
+
def __init__(self, config: LlamaConfig):
|
268 |
+
super().__init__()
|
269 |
+
self.config = config
|
270 |
+
self.hidden_size = config.hidden_size
|
271 |
+
self.num_heads = config.num_attention_heads
|
272 |
+
self.head_dim = self.hidden_size // self.num_heads
|
273 |
+
self.num_key_value_heads = config.num_key_value_heads
|
274 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
275 |
+
self.max_position_embeddings = config.max_position_embeddings
|
276 |
+
self.rope_theta = config.rope_theta
|
277 |
+
|
278 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
279 |
+
raise ValueError(
|
280 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
281 |
+
f" and `num_heads`: {self.num_heads})."
|
282 |
+
)
|
283 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
284 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
285 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
286 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
287 |
+
self._init_rope()
|
288 |
+
|
289 |
+
def _init_rope(self):
|
290 |
+
if self.config.rope_scaling is None:
|
291 |
+
self.rotary_emb = LlamaRotaryEmbedding(
|
292 |
+
self.head_dim,
|
293 |
+
max_position_embeddings=self.max_position_embeddings,
|
294 |
+
base=self.rope_theta,
|
295 |
+
)
|
296 |
+
else:
|
297 |
+
scaling_type = self.config.rope_scaling["type"]
|
298 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
299 |
+
if scaling_type == "linear":
|
300 |
+
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
301 |
+
self.head_dim,
|
302 |
+
max_position_embeddings=self.max_position_embeddings,
|
303 |
+
scaling_factor=scaling_factor,
|
304 |
+
base=self.rope_theta,
|
305 |
+
)
|
306 |
+
elif scaling_type == "dynamic":
|
307 |
+
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
308 |
+
self.head_dim,
|
309 |
+
max_position_embeddings=self.max_position_embeddings,
|
310 |
+
scaling_factor=scaling_factor,
|
311 |
+
base=self.rope_theta,
|
312 |
+
)
|
313 |
+
else:
|
314 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
315 |
+
|
316 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
317 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
318 |
+
|
319 |
+
def forward(
|
320 |
+
self,
|
321 |
+
hidden_states: torch.Tensor,
|
322 |
+
attention_mask: Optional[torch.Tensor] = None,
|
323 |
+
position_ids: Optional[torch.LongTensor] = None,
|
324 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
325 |
+
output_attentions: bool = False,
|
326 |
+
use_cache: bool = False,
|
327 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
328 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
329 |
+
bsz, q_len, _ = hidden_states.size()
|
330 |
+
|
331 |
+
if self.config.pretraining_tp > 1:
|
332 |
+
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
333 |
+
query_slices = self.q_proj.weight.split(
|
334 |
+
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
335 |
+
)
|
336 |
+
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
337 |
+
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
338 |
+
|
339 |
+
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
340 |
+
query_states = torch.cat(query_states, dim=-1)
|
341 |
+
|
342 |
+
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
343 |
+
key_states = torch.cat(key_states, dim=-1)
|
344 |
+
|
345 |
+
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
346 |
+
value_states = torch.cat(value_states, dim=-1)
|
347 |
+
|
348 |
+
else:
|
349 |
+
query_states = self.q_proj(hidden_states)
|
350 |
+
key_states = self.k_proj(hidden_states)
|
351 |
+
value_states = self.v_proj(hidden_states)
|
352 |
+
|
353 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
354 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
355 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
356 |
+
|
357 |
+
kv_seq_len = key_states.shape[-2]
|
358 |
+
if past_key_value is not None:
|
359 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
360 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
361 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
362 |
+
|
363 |
+
if past_key_value is not None:
|
364 |
+
# reuse k, v, self_attention
|
365 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
366 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
367 |
+
|
368 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
369 |
+
|
370 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
371 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
372 |
+
|
373 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
374 |
+
|
375 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
376 |
+
raise ValueError(
|
377 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
378 |
+
f" {attn_weights.size()}"
|
379 |
+
)
|
380 |
+
|
381 |
+
if attention_mask is not None:
|
382 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
383 |
+
raise ValueError(
|
384 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
385 |
+
)
|
386 |
+
attn_weights = attn_weights + attention_mask
|
387 |
+
|
388 |
+
# upcast attention to fp32
|
389 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
390 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
391 |
+
|
392 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
393 |
+
raise ValueError(
|
394 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
395 |
+
f" {attn_output.size()}"
|
396 |
+
)
|
397 |
+
|
398 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
399 |
+
|
400 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
401 |
+
|
402 |
+
if self.config.pretraining_tp > 1:
|
403 |
+
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
404 |
+
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
405 |
+
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
406 |
+
else:
|
407 |
+
attn_output = self.o_proj(attn_output)
|
408 |
+
|
409 |
+
if not output_attentions:
|
410 |
+
attn_weights = None
|
411 |
+
|
412 |
+
return attn_output, attn_weights, past_key_value
|
413 |
+
|
414 |
+
|
415 |
+
class LlamaFlashAttention2(LlamaAttention):
|
416 |
+
"""
|
417 |
+
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
|
418 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
419 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
420 |
+
"""
|
421 |
+
|
422 |
+
def forward(
|
423 |
+
self,
|
424 |
+
hidden_states: torch.Tensor,
|
425 |
+
attention_mask: Optional[torch.Tensor] = None,
|
426 |
+
position_ids: Optional[torch.LongTensor] = None,
|
427 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
428 |
+
output_attentions: bool = False,
|
429 |
+
use_cache: bool = False,
|
430 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
431 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
432 |
+
# LlamaFlashAttention2 attention does not support output_attentions
|
433 |
+
output_attentions = False
|
434 |
+
|
435 |
+
bsz, q_len, _ = hidden_states.size()
|
436 |
+
|
437 |
+
query_states = self.q_proj(hidden_states)
|
438 |
+
key_states = self.k_proj(hidden_states)
|
439 |
+
value_states = self.v_proj(hidden_states)
|
440 |
+
|
441 |
+
# Flash attention requires the input to have the shape
|
442 |
+
# batch_size x seq_length x head_dime x hidden_dim
|
443 |
+
# therefore we just need to keep the original shape
|
444 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
445 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
446 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
447 |
+
|
448 |
+
kv_seq_len = key_states.shape[-2]
|
449 |
+
if past_key_value is not None:
|
450 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
451 |
+
|
452 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
453 |
+
|
454 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
455 |
+
|
456 |
+
if past_key_value is not None:
|
457 |
+
# reuse k, v, self_attention
|
458 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
459 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
460 |
+
|
461 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
462 |
+
|
463 |
+
query_states = query_states.transpose(1, 2)
|
464 |
+
key_states = key_states.transpose(1, 2)
|
465 |
+
value_states = value_states.transpose(1, 2)
|
466 |
+
|
467 |
+
# TODO: llama does not have dropout in the config??
|
468 |
+
# It is recommended to use dropout with FA according to the docs
|
469 |
+
# when training.
|
470 |
+
dropout_rate = 0.0 # if not self.training else self.attn_dropout
|
471 |
+
|
472 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
473 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
474 |
+
# cast them back in float16 just to be sure everything works as expected.
|
475 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
476 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
477 |
+
input_dtype = query_states.dtype
|
478 |
+
if input_dtype == torch.float32:
|
479 |
+
logger.warning_once(
|
480 |
+
"The input hidden states seems to be silently casted in float32, this might be related to"
|
481 |
+
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
482 |
+
" float16."
|
483 |
+
)
|
484 |
+
|
485 |
+
query_states = query_states.to(torch.float16)
|
486 |
+
key_states = key_states.to(torch.float16)
|
487 |
+
value_states = value_states.to(torch.float16)
|
488 |
+
|
489 |
+
attn_output = self._flash_attention_forward(
|
490 |
+
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
|
491 |
+
)
|
492 |
+
|
493 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
494 |
+
attn_output = self.o_proj(attn_output)
|
495 |
+
|
496 |
+
if not output_attentions:
|
497 |
+
attn_weights = None
|
498 |
+
|
499 |
+
return attn_output, attn_weights, past_key_value
|
500 |
+
|
501 |
+
def _flash_attention_forward(
|
502 |
+
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
|
503 |
+
):
|
504 |
+
"""
|
505 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
506 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
507 |
+
|
508 |
+
Args:
|
509 |
+
query_states (`torch.Tensor`):
|
510 |
+
Input query states to be passed to Flash Attention API
|
511 |
+
key_states (`torch.Tensor`):
|
512 |
+
Input key states to be passed to Flash Attention API
|
513 |
+
value_states (`torch.Tensor`):
|
514 |
+
Input value states to be passed to Flash Attention API
|
515 |
+
padding_mask (`torch.Tensor`):
|
516 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
517 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
518 |
+
dropout (`int`, *optional*):
|
519 |
+
Attention dropout
|
520 |
+
softmax_scale (`float`, *optional*):
|
521 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
522 |
+
"""
|
523 |
+
# Contains at least one padding token in the sequence
|
524 |
+
if padding_mask is not None:
|
525 |
+
batch_size = query_states.shape[0]
|
526 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
527 |
+
query_states, key_states, value_states, padding_mask, query_length
|
528 |
+
)
|
529 |
+
|
530 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
531 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
532 |
+
|
533 |
+
attn_output_unpad = flash_attn_varlen_func(
|
534 |
+
query_states,
|
535 |
+
key_states,
|
536 |
+
value_states,
|
537 |
+
cu_seqlens_q=cu_seqlens_q,
|
538 |
+
cu_seqlens_k=cu_seqlens_k,
|
539 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
540 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
541 |
+
dropout_p=dropout,
|
542 |
+
softmax_scale=softmax_scale,
|
543 |
+
causal=True,
|
544 |
+
)
|
545 |
+
|
546 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
547 |
+
else:
|
548 |
+
attn_output = flash_attn_func(
|
549 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
|
550 |
+
)
|
551 |
+
|
552 |
+
return attn_output
|
553 |
+
|
554 |
+
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
|
555 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
|
556 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
557 |
+
|
558 |
+
key_layer = index_first_axis(
|
559 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
560 |
+
)
|
561 |
+
value_layer = index_first_axis(
|
562 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
563 |
+
)
|
564 |
+
if query_length == kv_seq_len:
|
565 |
+
query_layer = index_first_axis(
|
566 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
567 |
+
)
|
568 |
+
cu_seqlens_q = cu_seqlens_k
|
569 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
570 |
+
indices_q = indices_k
|
571 |
+
elif query_length == 1:
|
572 |
+
max_seqlen_in_batch_q = 1
|
573 |
+
cu_seqlens_q = torch.arange(
|
574 |
+
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
575 |
+
) # There is a memcpy here, that is very bad.
|
576 |
+
indices_q = cu_seqlens_q[:-1]
|
577 |
+
query_layer = query_layer.squeeze(1)
|
578 |
+
else:
|
579 |
+
# The -q_len: slice assumes left padding.
|
580 |
+
padding_mask = padding_mask[:, -query_length:]
|
581 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
|
582 |
+
|
583 |
+
return (
|
584 |
+
query_layer,
|
585 |
+
key_layer,
|
586 |
+
value_layer,
|
587 |
+
indices_q,
|
588 |
+
(cu_seqlens_q, cu_seqlens_k),
|
589 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
590 |
+
)
|
591 |
+
|
592 |
+
|
593 |
+
class LlamaDecoderLayer(nn.Module):
|
594 |
+
def __init__(self, config: LlamaConfig):
|
595 |
+
super().__init__()
|
596 |
+
self.hidden_size = config.hidden_size
|
597 |
+
self.self_attn = (
|
598 |
+
LlamaAttention(config=config)
|
599 |
+
if not getattr(config, "_flash_attn_2_enabled", False)
|
600 |
+
else LlamaFlashAttention2(config=config)
|
601 |
+
)
|
602 |
+
self.mlp = LlamaMLP(config)
|
603 |
+
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
604 |
+
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
605 |
+
|
606 |
+
def forward(
|
607 |
+
self,
|
608 |
+
hidden_states: torch.Tensor,
|
609 |
+
attention_mask: Optional[torch.Tensor] = None,
|
610 |
+
position_ids: Optional[torch.LongTensor] = None,
|
611 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
612 |
+
output_attentions: Optional[bool] = False,
|
613 |
+
use_cache: Optional[bool] = False,
|
614 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
615 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
616 |
+
"""
|
617 |
+
Args:
|
618 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
619 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
620 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
621 |
+
output_attentions (`bool`, *optional*):
|
622 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
623 |
+
returned tensors for more detail.
|
624 |
+
use_cache (`bool`, *optional*):
|
625 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
626 |
+
(see `past_key_values`).
|
627 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
628 |
+
"""
|
629 |
+
|
630 |
+
residual = hidden_states
|
631 |
+
|
632 |
+
hidden_states = self.input_layernorm(hidden_states)
|
633 |
+
|
634 |
+
# Self Attention
|
635 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
636 |
+
hidden_states=hidden_states,
|
637 |
+
attention_mask=attention_mask,
|
638 |
+
position_ids=position_ids,
|
639 |
+
past_key_value=past_key_value,
|
640 |
+
output_attentions=output_attentions,
|
641 |
+
use_cache=use_cache,
|
642 |
+
padding_mask=padding_mask,
|
643 |
+
)
|
644 |
+
hidden_states = residual + hidden_states
|
645 |
+
|
646 |
+
# Fully Connected
|
647 |
+
residual = hidden_states
|
648 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
649 |
+
hidden_states = self.mlp(hidden_states)
|
650 |
+
hidden_states = residual + hidden_states
|
651 |
+
|
652 |
+
outputs = (hidden_states,)
|
653 |
+
|
654 |
+
if output_attentions:
|
655 |
+
outputs += (self_attn_weights,)
|
656 |
+
|
657 |
+
if use_cache:
|
658 |
+
outputs += (present_key_value,)
|
659 |
+
|
660 |
+
return outputs
|
661 |
+
|
662 |
+
|
663 |
+
LLAMA_START_DOCSTRING = r"""
|
664 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
665 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
666 |
+
etc.)
|
667 |
+
|
668 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
669 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
670 |
+
and behavior.
|
671 |
+
|
672 |
+
Parameters:
|
673 |
+
config ([`LlamaConfig`]):
|
674 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
675 |
+
load the weights associated with the model, only the configuration. Check out the
|
676 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
677 |
+
"""
|
678 |
+
|
679 |
+
|
680 |
+
@add_start_docstrings(
|
681 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
682 |
+
LLAMA_START_DOCSTRING,
|
683 |
+
)
|
684 |
+
class LlamaPreTrainedModel(PreTrainedModel):
|
685 |
+
config_class = LlamaConfig
|
686 |
+
base_model_prefix = "model"
|
687 |
+
supports_gradient_checkpointing = True
|
688 |
+
_no_split_modules = ["LlamaDecoderLayer"]
|
689 |
+
_skip_keys_device_placement = "past_key_values"
|
690 |
+
_supports_flash_attn_2 = True
|
691 |
+
|
692 |
+
def _init_weights(self, module):
|
693 |
+
std = self.config.initializer_range
|
694 |
+
if isinstance(module, nn.Linear):
|
695 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
696 |
+
if module.bias is not None:
|
697 |
+
module.bias.data.zero_()
|
698 |
+
elif isinstance(module, nn.Embedding):
|
699 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
700 |
+
if module.padding_idx is not None:
|
701 |
+
module.weight.data[module.padding_idx].zero_()
|
702 |
+
|
703 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
704 |
+
if isinstance(module, LlamaModel):
|
705 |
+
module.gradient_checkpointing = value
|
706 |
+
|
707 |
+
|
708 |
+
LLAMA_INPUTS_DOCSTRING = r"""
|
709 |
+
Args:
|
710 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
711 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
712 |
+
it.
|
713 |
+
|
714 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
715 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
716 |
+
|
717 |
+
[What are input IDs?](../glossary#input-ids)
|
718 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
719 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
720 |
+
|
721 |
+
- 1 for tokens that are **not masked**,
|
722 |
+
- 0 for tokens that are **masked**.
|
723 |
+
|
724 |
+
[What are attention masks?](../glossary#attention-mask)
|
725 |
+
|
726 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
727 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
728 |
+
|
729 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
730 |
+
`past_key_values`).
|
731 |
+
|
732 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
733 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
734 |
+
information on the default strategy.
|
735 |
+
|
736 |
+
- 1 indicates the head is **not masked**,
|
737 |
+
- 0 indicates the head is **masked**.
|
738 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
739 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
740 |
+
config.n_positions - 1]`.
|
741 |
+
|
742 |
+
[What are position IDs?](../glossary#position-ids)
|
743 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
744 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
745 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
746 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
747 |
+
|
748 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
749 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
750 |
+
|
751 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
752 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
753 |
+
of shape `(batch_size, sequence_length)`.
|
754 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
755 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
756 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
757 |
+
model's internal embedding lookup matrix.
|
758 |
+
use_cache (`bool`, *optional*):
|
759 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
760 |
+
`past_key_values`).
|
761 |
+
output_attentions (`bool`, *optional*):
|
762 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
763 |
+
tensors for more detail.
|
764 |
+
output_hidden_states (`bool`, *optional*):
|
765 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
766 |
+
more detail.
|
767 |
+
return_dict (`bool`, *optional*):
|
768 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
769 |
+
"""
|
770 |
+
|
771 |
+
|
772 |
+
@add_start_docstrings(
|
773 |
+
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
774 |
+
LLAMA_START_DOCSTRING,
|
775 |
+
)
|
776 |
+
class LlamaModel(LlamaPreTrainedModel):
|
777 |
+
"""
|
778 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
779 |
+
|
780 |
+
Args:
|
781 |
+
config: LlamaConfig
|
782 |
+
"""
|
783 |
+
|
784 |
+
def __init__(self, config: LlamaConfig):
|
785 |
+
super().__init__(config)
|
786 |
+
self.padding_idx = config.pad_token_id
|
787 |
+
self.vocab_size = config.vocab_size
|
788 |
+
|
789 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
790 |
+
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
791 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
792 |
+
|
793 |
+
self.gradient_checkpointing = False
|
794 |
+
# Initialize weights and apply final processing
|
795 |
+
self.post_init()
|
796 |
+
|
797 |
+
def get_input_embeddings(self):
|
798 |
+
return self.embed_tokens
|
799 |
+
|
800 |
+
def set_input_embeddings(self, value):
|
801 |
+
self.embed_tokens = value
|
802 |
+
|
803 |
+
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
804 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
805 |
+
# create causal mask
|
806 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
807 |
+
combined_attention_mask = None
|
808 |
+
if input_shape[-1] > 1:
|
809 |
+
combined_attention_mask = _make_causal_mask(
|
810 |
+
input_shape,
|
811 |
+
inputs_embeds.dtype,
|
812 |
+
device=inputs_embeds.device,
|
813 |
+
past_key_values_length=past_key_values_length,
|
814 |
+
)
|
815 |
+
|
816 |
+
if attention_mask is not None:
|
817 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
818 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
819 |
+
inputs_embeds.device
|
820 |
+
)
|
821 |
+
combined_attention_mask = (
|
822 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
823 |
+
)
|
824 |
+
|
825 |
+
return combined_attention_mask
|
826 |
+
|
827 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
828 |
+
def forward(
|
829 |
+
self,
|
830 |
+
input_ids: torch.LongTensor = None,
|
831 |
+
attention_mask: Optional[torch.Tensor] = None,
|
832 |
+
position_ids: Optional[torch.LongTensor] = None,
|
833 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
834 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
835 |
+
use_cache: Optional[bool] = None,
|
836 |
+
output_attentions: Optional[bool] = None,
|
837 |
+
output_hidden_states: Optional[bool] = None,
|
838 |
+
return_dict: Optional[bool] = None,
|
839 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
840 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
841 |
+
output_hidden_states = (
|
842 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
843 |
+
)
|
844 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
845 |
+
|
846 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
847 |
+
|
848 |
+
# retrieve input_ids and inputs_embeds
|
849 |
+
if input_ids is not None and inputs_embeds is not None:
|
850 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
851 |
+
elif input_ids is not None:
|
852 |
+
batch_size, seq_length = input_ids.shape
|
853 |
+
elif inputs_embeds is not None:
|
854 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
855 |
+
else:
|
856 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
857 |
+
|
858 |
+
seq_length_with_past = seq_length
|
859 |
+
past_key_values_length = 0
|
860 |
+
|
861 |
+
if past_key_values is not None:
|
862 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
863 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
864 |
+
|
865 |
+
if position_ids is None:
|
866 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
867 |
+
position_ids = torch.arange(
|
868 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
869 |
+
)
|
870 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
871 |
+
else:
|
872 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
873 |
+
|
874 |
+
if inputs_embeds is None:
|
875 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
876 |
+
# embed positions
|
877 |
+
if attention_mask is None:
|
878 |
+
attention_mask = torch.ones(
|
879 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
880 |
+
)
|
881 |
+
padding_mask = None
|
882 |
+
else:
|
883 |
+
if 0 in attention_mask:
|
884 |
+
padding_mask = attention_mask
|
885 |
+
else:
|
886 |
+
padding_mask = None
|
887 |
+
|
888 |
+
attention_mask = self._prepare_decoder_attention_mask(
|
889 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
890 |
+
)
|
891 |
+
|
892 |
+
hidden_states = inputs_embeds
|
893 |
+
|
894 |
+
if self.gradient_checkpointing and self.training:
|
895 |
+
if use_cache:
|
896 |
+
logger.warning_once(
|
897 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
898 |
+
)
|
899 |
+
use_cache = False
|
900 |
+
|
901 |
+
# decoder layers
|
902 |
+
all_hidden_states = () if output_hidden_states else None
|
903 |
+
all_self_attns = () if output_attentions else None
|
904 |
+
next_decoder_cache = () if use_cache else None
|
905 |
+
|
906 |
+
for idx, decoder_layer in enumerate(self.layers):
|
907 |
+
if output_hidden_states:
|
908 |
+
all_hidden_states += (hidden_states,)
|
909 |
+
|
910 |
+
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
911 |
+
|
912 |
+
if self.gradient_checkpointing and self.training:
|
913 |
+
|
914 |
+
def create_custom_forward(module):
|
915 |
+
def custom_forward(*inputs):
|
916 |
+
# None for past_key_value
|
917 |
+
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
|
918 |
+
|
919 |
+
return custom_forward
|
920 |
+
|
921 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
922 |
+
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
|
923 |
+
)
|
924 |
+
else:
|
925 |
+
layer_outputs = decoder_layer(
|
926 |
+
hidden_states,
|
927 |
+
attention_mask=attention_mask,
|
928 |
+
position_ids=position_ids,
|
929 |
+
past_key_value=past_key_value,
|
930 |
+
output_attentions=output_attentions,
|
931 |
+
use_cache=use_cache,
|
932 |
+
padding_mask=padding_mask,
|
933 |
+
)
|
934 |
+
|
935 |
+
hidden_states = layer_outputs[0]
|
936 |
+
|
937 |
+
if use_cache:
|
938 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
939 |
+
|
940 |
+
if output_attentions:
|
941 |
+
all_self_attns += (layer_outputs[1],)
|
942 |
+
|
943 |
+
hidden_states = self.norm(hidden_states)
|
944 |
+
|
945 |
+
# add hidden states from the last decoder layer
|
946 |
+
if output_hidden_states:
|
947 |
+
all_hidden_states += (hidden_states,)
|
948 |
+
|
949 |
+
next_cache = next_decoder_cache if use_cache else None
|
950 |
+
if not return_dict:
|
951 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
952 |
+
return BaseModelOutputWithPast(
|
953 |
+
last_hidden_state=hidden_states,
|
954 |
+
past_key_values=next_cache,
|
955 |
+
hidden_states=all_hidden_states,
|
956 |
+
attentions=all_self_attns,
|
957 |
+
)
|
958 |
+
|
959 |
+
|
960 |
+
class LlamaForCausalLM(LlamaPreTrainedModel):
|
961 |
+
_tied_weights_keys = ["lm_head.weight"]
|
962 |
+
|
963 |
+
def __init__(self, config):
|
964 |
+
super().__init__(config)
|
965 |
+
self.model = LlamaModel(config)
|
966 |
+
self.vocab_size = config.vocab_size
|
967 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
968 |
+
|
969 |
+
# Initialize weights and apply final processing
|
970 |
+
self.post_init()
|
971 |
+
|
972 |
+
def get_input_embeddings(self):
|
973 |
+
return self.model.embed_tokens
|
974 |
+
|
975 |
+
def set_input_embeddings(self, value):
|
976 |
+
self.model.embed_tokens = value
|
977 |
+
|
978 |
+
def get_output_embeddings(self):
|
979 |
+
return self.lm_head
|
980 |
+
|
981 |
+
def set_output_embeddings(self, new_embeddings):
|
982 |
+
self.lm_head = new_embeddings
|
983 |
+
|
984 |
+
def set_decoder(self, decoder):
|
985 |
+
self.model = decoder
|
986 |
+
|
987 |
+
def get_decoder(self):
|
988 |
+
return self.model
|
989 |
+
|
990 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
991 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
992 |
+
def forward(
|
993 |
+
self,
|
994 |
+
input_ids: torch.LongTensor = None,
|
995 |
+
attention_mask: Optional[torch.Tensor] = None,
|
996 |
+
position_ids: Optional[torch.LongTensor] = None,
|
997 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
998 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
999 |
+
labels: Optional[torch.LongTensor] = None,
|
1000 |
+
use_cache: Optional[bool] = None,
|
1001 |
+
output_attentions: Optional[bool] = None,
|
1002 |
+
output_hidden_states: Optional[bool] = None,
|
1003 |
+
return_dict: Optional[bool] = None,
|
1004 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
1005 |
+
r"""
|
1006 |
+
Args:
|
1007 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1008 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1009 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1010 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1011 |
+
|
1012 |
+
Returns:
|
1013 |
+
|
1014 |
+
Example:
|
1015 |
+
|
1016 |
+
```python
|
1017 |
+
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
1018 |
+
|
1019 |
+
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
1020 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
1021 |
+
|
1022 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
1023 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
1024 |
+
|
1025 |
+
>>> # Generate
|
1026 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
1027 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
1028 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
1029 |
+
```"""
|
1030 |
+
|
1031 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1032 |
+
output_hidden_states = (
|
1033 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1034 |
+
)
|
1035 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1036 |
+
|
1037 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1038 |
+
outputs = self.model(
|
1039 |
+
input_ids=input_ids,
|
1040 |
+
attention_mask=attention_mask,
|
1041 |
+
position_ids=position_ids,
|
1042 |
+
past_key_values=past_key_values,
|
1043 |
+
inputs_embeds=inputs_embeds,
|
1044 |
+
use_cache=use_cache,
|
1045 |
+
output_attentions=output_attentions,
|
1046 |
+
output_hidden_states=output_hidden_states,
|
1047 |
+
return_dict=return_dict,
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
hidden_states = outputs[0]
|
1051 |
+
if self.config.pretraining_tp > 1:
|
1052 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
1053 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
1054 |
+
logits = torch.cat(logits, dim=-1)
|
1055 |
+
else:
|
1056 |
+
logits = self.lm_head(hidden_states)
|
1057 |
+
logits = logits.float()
|
1058 |
+
|
1059 |
+
loss = None
|
1060 |
+
if labels is not None:
|
1061 |
+
# Shift so that tokens < n predict n
|
1062 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1063 |
+
shift_labels = labels[..., 1:].contiguous()
|
1064 |
+
# Flatten the tokens
|
1065 |
+
loss_fct = CrossEntropyLoss()
|
1066 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1067 |
+
shift_labels = shift_labels.view(-1)
|
1068 |
+
# Enable model parallelism
|
1069 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
1070 |
+
loss = loss_fct(shift_logits, shift_labels)
|
1071 |
+
|
1072 |
+
if not return_dict:
|
1073 |
+
output = (logits,) + outputs[1:]
|
1074 |
+
return (loss,) + output if loss is not None else output
|
1075 |
+
|
1076 |
+
return CausalLMOutputWithPast(
|
1077 |
+
loss=loss,
|
1078 |
+
logits=logits,
|
1079 |
+
past_key_values=outputs.past_key_values,
|
1080 |
+
hidden_states=outputs.hidden_states,
|
1081 |
+
attentions=outputs.attentions,
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
def prepare_inputs_for_generation(
|
1085 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
1086 |
+
):
|
1087 |
+
if past_key_values:
|
1088 |
+
input_ids = input_ids[:, -1:]
|
1089 |
+
|
1090 |
+
position_ids = kwargs.get("position_ids", None)
|
1091 |
+
if attention_mask is not None and position_ids is None:
|
1092 |
+
# create position_ids on the fly for batch generation
|
1093 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1094 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1095 |
+
if past_key_values:
|
1096 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
1097 |
+
|
1098 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1099 |
+
if inputs_embeds is not None and past_key_values is None:
|
1100 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1101 |
+
else:
|
1102 |
+
model_inputs = {"input_ids": input_ids}
|
1103 |
+
|
1104 |
+
model_inputs.update(
|
1105 |
+
{
|
1106 |
+
"position_ids": position_ids,
|
1107 |
+
"past_key_values": past_key_values,
|
1108 |
+
"use_cache": kwargs.get("use_cache"),
|
1109 |
+
"attention_mask": attention_mask,
|
1110 |
+
}
|
1111 |
+
)
|
1112 |
+
return model_inputs
|
1113 |
+
|
1114 |
+
@staticmethod
|
1115 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1116 |
+
reordered_past = ()
|
1117 |
+
for layer_past in past_key_values:
|
1118 |
+
reordered_past += (
|
1119 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
1120 |
+
)
|
1121 |
+
return reordered_past
|
1122 |
+
|
1123 |
+
|
1124 |
+
@add_start_docstrings(
|
1125 |
+
"""
|
1126 |
+
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
1127 |
+
|
1128 |
+
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1129 |
+
(e.g. GPT-2) do.
|
1130 |
+
|
1131 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1132 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1133 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1134 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1135 |
+
each row of the batch).
|
1136 |
+
""",
|
1137 |
+
LLAMA_START_DOCSTRING,
|
1138 |
+
)
|
1139 |
+
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
1140 |
+
def __init__(self, config):
|
1141 |
+
super().__init__(config)
|
1142 |
+
self.num_labels = config.num_labels
|
1143 |
+
self.model = LlamaModel(config)
|
1144 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1145 |
+
|
1146 |
+
# Initialize weights and apply final processing
|
1147 |
+
self.post_init()
|
1148 |
+
|
1149 |
+
def get_input_embeddings(self):
|
1150 |
+
return self.model.embed_tokens
|
1151 |
+
|
1152 |
+
def set_input_embeddings(self, value):
|
1153 |
+
self.model.embed_tokens = value
|
1154 |
+
|
1155 |
+
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1156 |
+
def forward(
|
1157 |
+
self,
|
1158 |
+
input_ids: torch.LongTensor = None,
|
1159 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1160 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1161 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1162 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1163 |
+
labels: Optional[torch.LongTensor] = None,
|
1164 |
+
use_cache: Optional[bool] = None,
|
1165 |
+
output_attentions: Optional[bool] = None,
|
1166 |
+
output_hidden_states: Optional[bool] = None,
|
1167 |
+
return_dict: Optional[bool] = None,
|
1168 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1169 |
+
r"""
|
1170 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1171 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1172 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1173 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1174 |
+
"""
|
1175 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1176 |
+
|
1177 |
+
transformer_outputs = self.model(
|
1178 |
+
input_ids,
|
1179 |
+
attention_mask=attention_mask,
|
1180 |
+
position_ids=position_ids,
|
1181 |
+
past_key_values=past_key_values,
|
1182 |
+
inputs_embeds=inputs_embeds,
|
1183 |
+
use_cache=use_cache,
|
1184 |
+
output_attentions=output_attentions,
|
1185 |
+
output_hidden_states=output_hidden_states,
|
1186 |
+
return_dict=return_dict,
|
1187 |
+
)
|
1188 |
+
hidden_states = transformer_outputs[0]
|
1189 |
+
logits = self.score(hidden_states)
|
1190 |
+
|
1191 |
+
if input_ids is not None:
|
1192 |
+
batch_size = input_ids.shape[0]
|
1193 |
+
else:
|
1194 |
+
batch_size = inputs_embeds.shape[0]
|
1195 |
+
|
1196 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1197 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1198 |
+
if self.config.pad_token_id is None:
|
1199 |
+
sequence_lengths = -1
|
1200 |
+
else:
|
1201 |
+
if input_ids is not None:
|
1202 |
+
sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
|
1203 |
+
logits.device
|
1204 |
+
)
|
1205 |
+
else:
|
1206 |
+
sequence_lengths = -1
|
1207 |
+
|
1208 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1209 |
+
|
1210 |
+
loss = None
|
1211 |
+
if labels is not None:
|
1212 |
+
labels = labels.to(logits.device)
|
1213 |
+
if self.config.problem_type is None:
|
1214 |
+
if self.num_labels == 1:
|
1215 |
+
self.config.problem_type = "regression"
|
1216 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1217 |
+
self.config.problem_type = "single_label_classification"
|
1218 |
+
else:
|
1219 |
+
self.config.problem_type = "multi_label_classification"
|
1220 |
+
|
1221 |
+
if self.config.problem_type == "regression":
|
1222 |
+
loss_fct = MSELoss()
|
1223 |
+
if self.num_labels == 1:
|
1224 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1225 |
+
else:
|
1226 |
+
loss = loss_fct(pooled_logits, labels)
|
1227 |
+
elif self.config.problem_type == "single_label_classification":
|
1228 |
+
loss_fct = CrossEntropyLoss()
|
1229 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1230 |
+
elif self.config.problem_type == "multi_label_classification":
|
1231 |
+
loss_fct = BCEWithLogitsLoss()
|
1232 |
+
loss = loss_fct(pooled_logits, labels)
|
1233 |
+
if not return_dict:
|
1234 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1235 |
+
return ((loss,) + output) if loss is not None else output
|
1236 |
+
|
1237 |
+
return SequenceClassifierOutputWithPast(
|
1238 |
+
loss=loss,
|
1239 |
+
logits=pooled_logits,
|
1240 |
+
past_key_values=transformer_outputs.past_key_values,
|
1241 |
+
hidden_states=transformer_outputs.hidden_states,
|
1242 |
+
attentions=transformer_outputs.attentions,
|
1243 |
+
)
|
codeclm/models/llama/tokenization_llama.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
|
21 |
+
"""Tokenization classes for LLaMA."""
|
22 |
+
import os
|
23 |
+
from shutil import copyfile
|
24 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
25 |
+
|
26 |
+
import sentencepiece as spm
|
27 |
+
|
28 |
+
from transformers.convert_slow_tokenizer import import_protobuf
|
29 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
30 |
+
from transformers.utils import logging
|
31 |
+
|
32 |
+
|
33 |
+
if TYPE_CHECKING:
|
34 |
+
from transformers.tokenization_utils_base import TextInput
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__)
|
37 |
+
|
38 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
39 |
+
|
40 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
41 |
+
"vocab_file": {
|
42 |
+
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
|
43 |
+
},
|
44 |
+
"tokenizer_file": {
|
45 |
+
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
|
46 |
+
},
|
47 |
+
}
|
48 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
49 |
+
"hf-internal-testing/llama-tokenizer": 2048,
|
50 |
+
}
|
51 |
+
SPIECE_UNDERLINE = "▁"
|
52 |
+
|
53 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
54 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
55 |
+
|
56 |
+
# fmt: off
|
57 |
+
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
|
58 |
+
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
59 |
+
that your responses are socially unbiased and positive in nature.
|
60 |
+
|
61 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
62 |
+
correct. If you don't know the answer to a question, please don't share false information."""
|
63 |
+
# fmt: on
|
64 |
+
|
65 |
+
|
66 |
+
class LlamaTokenizer(PreTrainedTokenizer):
|
67 |
+
"""
|
68 |
+
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
|
69 |
+
no padding token in the original model.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
vocab_file (`str`):
|
73 |
+
Path to the vocabulary file.
|
74 |
+
legacy (`bool`, *optional*):
|
75 |
+
Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
|
76 |
+
and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
|
77 |
+
example:
|
78 |
+
|
79 |
+
- `legacy=True`:
|
80 |
+
```python
|
81 |
+
>>> from transformers import T5Tokenizer
|
82 |
+
|
83 |
+
>>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
|
84 |
+
>>> tokenizer.encode("Hello <extra_id_0>.")
|
85 |
+
[8774, 32099, 3, 5, 1]
|
86 |
+
```
|
87 |
+
- `legacy=False`:
|
88 |
+
```python
|
89 |
+
>>> from transformers import T5Tokenizer
|
90 |
+
|
91 |
+
>>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
|
92 |
+
>>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
|
93 |
+
[8774, 32099, 5, 1]
|
94 |
+
```
|
95 |
+
Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
|
96 |
+
|
97 |
+
"""
|
98 |
+
|
99 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
100 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
101 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
102 |
+
model_input_names = ["input_ids", "attention_mask"]
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
vocab_file,
|
107 |
+
unk_token="<unk>",
|
108 |
+
bos_token="<s>",
|
109 |
+
eos_token="</s>",
|
110 |
+
pad_token=None,
|
111 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
112 |
+
add_bos_token=True,
|
113 |
+
add_eos_token=False,
|
114 |
+
clean_up_tokenization_spaces=False,
|
115 |
+
use_default_system_prompt=True,
|
116 |
+
spaces_between_special_tokens=False,
|
117 |
+
legacy=None,
|
118 |
+
**kwargs,
|
119 |
+
):
|
120 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
121 |
+
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
122 |
+
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
123 |
+
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
124 |
+
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
125 |
+
|
126 |
+
if legacy is None:
|
127 |
+
logger.warning_once(
|
128 |
+
f"You are using the default legacy behaviour of the {self.__class__}. This is"
|
129 |
+
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
|
130 |
+
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
|
131 |
+
" means, and thouroughly read the reason why this was added as explained in"
|
132 |
+
" https://github.com/huggingface/transformers/pull/24565"
|
133 |
+
)
|
134 |
+
legacy = True
|
135 |
+
|
136 |
+
self.legacy = legacy
|
137 |
+
self.vocab_file = vocab_file
|
138 |
+
self.add_bos_token = add_bos_token
|
139 |
+
self.add_eos_token = add_eos_token
|
140 |
+
self.use_default_system_prompt = use_default_system_prompt
|
141 |
+
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
|
142 |
+
|
143 |
+
super().__init__(
|
144 |
+
bos_token=bos_token,
|
145 |
+
eos_token=eos_token,
|
146 |
+
unk_token=unk_token,
|
147 |
+
pad_token=pad_token,
|
148 |
+
add_bos_token=add_bos_token,
|
149 |
+
add_eos_token=add_eos_token,
|
150 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
151 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
152 |
+
use_default_system_prompt=use_default_system_prompt,
|
153 |
+
spaces_between_special_tokens=spaces_between_special_tokens,
|
154 |
+
legacy=legacy,
|
155 |
+
**kwargs,
|
156 |
+
)
|
157 |
+
|
158 |
+
@property
|
159 |
+
def unk_token_length(self):
|
160 |
+
return len(self.sp_model.encode(str(self.unk_token)))
|
161 |
+
|
162 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
|
163 |
+
def get_spm_processor(self, from_slow=False):
|
164 |
+
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
165 |
+
if self.legacy or from_slow: # no dependency on protobuf
|
166 |
+
tokenizer.Load(self.vocab_file)
|
167 |
+
return tokenizer
|
168 |
+
|
169 |
+
with open(self.vocab_file, "rb") as f:
|
170 |
+
sp_model = f.read()
|
171 |
+
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
|
172 |
+
model = model_pb2.ModelProto.FromString(sp_model)
|
173 |
+
normalizer_spec = model_pb2.NormalizerSpec()
|
174 |
+
normalizer_spec.add_dummy_prefix = False
|
175 |
+
model.normalizer_spec.MergeFrom(normalizer_spec)
|
176 |
+
sp_model = model.SerializeToString()
|
177 |
+
tokenizer.LoadFromSerializedProto(sp_model)
|
178 |
+
return tokenizer
|
179 |
+
|
180 |
+
def __getstate__(self):
|
181 |
+
state = self.__dict__.copy()
|
182 |
+
state["sp_model"] = None
|
183 |
+
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
184 |
+
return state
|
185 |
+
|
186 |
+
def __setstate__(self, d):
|
187 |
+
self.__dict__ = d
|
188 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
189 |
+
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
190 |
+
|
191 |
+
@property
|
192 |
+
def vocab_size(self):
|
193 |
+
"""Returns vocab size"""
|
194 |
+
return self.sp_model.get_piece_size()
|
195 |
+
|
196 |
+
def get_vocab(self):
|
197 |
+
"""Returns vocab as a dict"""
|
198 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
199 |
+
vocab.update(self.added_tokens_encoder)
|
200 |
+
return vocab
|
201 |
+
|
202 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
|
203 |
+
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
|
204 |
+
"""
|
205 |
+
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
|
206 |
+
first token is special.
|
207 |
+
"""
|
208 |
+
if self.legacy or len(text) == 0:
|
209 |
+
return super().tokenize(text, **kwargs)
|
210 |
+
|
211 |
+
tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
|
212 |
+
|
213 |
+
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
|
214 |
+
tokens = tokens[1:]
|
215 |
+
return tokens
|
216 |
+
|
217 |
+
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
|
218 |
+
def _tokenize(self, text, **kwargs):
|
219 |
+
"""
|
220 |
+
Returns a tokenized string.
|
221 |
+
|
222 |
+
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
|
223 |
+
SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
|
224 |
+
`['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
|
225 |
+
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
226 |
+
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
227 |
+
"""
|
228 |
+
tokens = self.sp_model.encode(text, out_type=str)
|
229 |
+
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
|
230 |
+
return tokens
|
231 |
+
|
232 |
+
# 1. Encode string + prefix ex: "<unk> Hey"
|
233 |
+
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
234 |
+
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
235 |
+
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
236 |
+
|
237 |
+
def _convert_token_to_id(self, token):
|
238 |
+
"""Converts a token (str) in an id using the vocab."""
|
239 |
+
return self.sp_model.piece_to_id(token)
|
240 |
+
|
241 |
+
def _convert_id_to_token(self, index):
|
242 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
243 |
+
token = self.sp_model.IdToPiece(index)
|
244 |
+
return token
|
245 |
+
|
246 |
+
def convert_tokens_to_string(self, tokens):
|
247 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
248 |
+
# since we manually add the prefix space, we have to remove it when decoding
|
249 |
+
if tokens[0].startswith(SPIECE_UNDERLINE):
|
250 |
+
tokens[0] = tokens[0][1:]
|
251 |
+
|
252 |
+
current_sub_tokens = []
|
253 |
+
out_string = ""
|
254 |
+
prev_is_special = False
|
255 |
+
for i, token in enumerate(tokens):
|
256 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
257 |
+
if token in self.all_special_tokens:
|
258 |
+
if not prev_is_special and i != 0 and self.legacy:
|
259 |
+
out_string += " "
|
260 |
+
out_string += self.sp_model.decode(current_sub_tokens) + token
|
261 |
+
prev_is_special = True
|
262 |
+
current_sub_tokens = []
|
263 |
+
else:
|
264 |
+
current_sub_tokens.append(token)
|
265 |
+
prev_is_special = False
|
266 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
267 |
+
return out_string
|
268 |
+
|
269 |
+
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
270 |
+
"""
|
271 |
+
Save the vocabulary and special tokens file to a directory.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
save_directory (`str`):
|
275 |
+
The directory in which to save the vocabulary.
|
276 |
+
|
277 |
+
Returns:
|
278 |
+
`Tuple(str)`: Paths to the files saved.
|
279 |
+
"""
|
280 |
+
if not os.path.isdir(save_directory):
|
281 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
282 |
+
return
|
283 |
+
out_vocab_file = os.path.join(
|
284 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
285 |
+
)
|
286 |
+
|
287 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
288 |
+
copyfile(self.vocab_file, out_vocab_file)
|
289 |
+
elif not os.path.isfile(self.vocab_file):
|
290 |
+
with open(out_vocab_file, "wb") as fi:
|
291 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
292 |
+
fi.write(content_spiece_model)
|
293 |
+
|
294 |
+
return (out_vocab_file,)
|
295 |
+
|
296 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
297 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
298 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
299 |
+
|
300 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
301 |
+
|
302 |
+
if token_ids_1 is not None:
|
303 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
304 |
+
|
305 |
+
return output
|
306 |
+
|
307 |
+
def get_special_tokens_mask(
|
308 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
309 |
+
) -> List[int]:
|
310 |
+
"""
|
311 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
312 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
313 |
+
|
314 |
+
Args:
|
315 |
+
token_ids_0 (`List[int]`):
|
316 |
+
List of IDs.
|
317 |
+
token_ids_1 (`List[int]`, *optional*):
|
318 |
+
Optional second list of IDs for sequence pairs.
|
319 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
320 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
324 |
+
"""
|
325 |
+
if already_has_special_tokens:
|
326 |
+
return super().get_special_tokens_mask(
|
327 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
328 |
+
)
|
329 |
+
|
330 |
+
bos_token_id = [1] if self.add_bos_token else []
|
331 |
+
eos_token_id = [1] if self.add_eos_token else []
|
332 |
+
|
333 |
+
if token_ids_1 is None:
|
334 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
335 |
+
return (
|
336 |
+
bos_token_id
|
337 |
+
+ ([0] * len(token_ids_0))
|
338 |
+
+ eos_token_id
|
339 |
+
+ bos_token_id
|
340 |
+
+ ([0] * len(token_ids_1))
|
341 |
+
+ eos_token_id
|
342 |
+
)
|
343 |
+
|
344 |
+
def create_token_type_ids_from_sequences(
|
345 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
346 |
+
) -> List[int]:
|
347 |
+
"""
|
348 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
349 |
+
sequence pair mask has the following format:
|
350 |
+
|
351 |
+
```
|
352 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
353 |
+
| first sequence | second sequence |
|
354 |
+
```
|
355 |
+
|
356 |
+
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
357 |
+
|
358 |
+
Args:
|
359 |
+
token_ids_0 (`List[int]`):
|
360 |
+
List of ids.
|
361 |
+
token_ids_1 (`List[int]`, *optional*):
|
362 |
+
Optional second list of IDs for sequence pairs.
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
366 |
+
"""
|
367 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
368 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
369 |
+
|
370 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
371 |
+
|
372 |
+
if token_ids_1 is not None:
|
373 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
374 |
+
|
375 |
+
return output
|
376 |
+
|
377 |
+
@property
|
378 |
+
def default_chat_template(self):
|
379 |
+
"""
|
380 |
+
LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
|
381 |
+
Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
|
382 |
+
user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
|
383 |
+
rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
|
384 |
+
results in an unusual token ordering when it is present. This template should definitely be changed if you wish
|
385 |
+
to fine-tune a model with more flexible role ordering!
|
386 |
+
|
387 |
+
The output should look something like:
|
388 |
+
|
389 |
+
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
|
390 |
+
<bos>[INST] Prompt [/INST]
|
391 |
+
"""
|
392 |
+
|
393 |
+
template = (
|
394 |
+
"{% if messages[0]['role'] == 'system' %}"
|
395 |
+
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
|
396 |
+
"{% set system_message = messages[0]['content'] %}"
|
397 |
+
"{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
|
398 |
+
"{% set loop_messages = messages %}" # Or use the default system message if the flag is set
|
399 |
+
"{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
|
400 |
+
"{% else %}"
|
401 |
+
"{% set loop_messages = messages %}"
|
402 |
+
"{% set system_message = false %}"
|
403 |
+
"{% endif %}"
|
404 |
+
"{% for message in loop_messages %}" # Loop over all non-system messages
|
405 |
+
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
|
406 |
+
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
|
407 |
+
"{% endif %}"
|
408 |
+
"{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
|
409 |
+
"{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
|
410 |
+
"{% else %}"
|
411 |
+
"{% set content = message['content'] %}"
|
412 |
+
"{% endif %}"
|
413 |
+
"{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
|
414 |
+
"{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
|
415 |
+
"{% elif message['role'] == 'system' %}"
|
416 |
+
"{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
|
417 |
+
"{% elif message['role'] == 'assistant' %}"
|
418 |
+
"{{ ' ' + content.strip() + ' ' + eos_token }}"
|
419 |
+
"{% endif %}"
|
420 |
+
"{% endfor %}"
|
421 |
+
)
|
422 |
+
template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
|
423 |
+
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
|
424 |
+
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
|
425 |
+
|
426 |
+
return template
|
codeclm/models/llama/tokenization_llama_fast.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
import os
|
16 |
+
from shutil import copyfile
|
17 |
+
from typing import Optional, Tuple
|
18 |
+
|
19 |
+
from tokenizers import processors
|
20 |
+
|
21 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
22 |
+
from transformers.utils import is_sentencepiece_available, logging
|
23 |
+
from transformers.utils.versions import require_version
|
24 |
+
|
25 |
+
|
26 |
+
require_version("tokenizers>=0.13.3")
|
27 |
+
|
28 |
+
if is_sentencepiece_available():
|
29 |
+
from .tokenization_llama import LlamaTokenizer
|
30 |
+
else:
|
31 |
+
LlamaTokenizer = None
|
32 |
+
|
33 |
+
logger = logging.get_logger(__name__)
|
34 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
|
35 |
+
|
36 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
37 |
+
"vocab_file": {
|
38 |
+
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
|
39 |
+
},
|
40 |
+
"tokenizer_file": {
|
41 |
+
"hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
|
42 |
+
},
|
43 |
+
}
|
44 |
+
B_INST, E_INST = "[INST]", "[/INST]"
|
45 |
+
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
46 |
+
|
47 |
+
# fmt: off
|
48 |
+
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
|
49 |
+
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
|
50 |
+
that your responses are socially unbiased and positive in nature.
|
51 |
+
|
52 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
|
53 |
+
correct. If you don't know the answer to a question, please don't share false information."""
|
54 |
+
# fmt: on
|
55 |
+
|
56 |
+
|
57 |
+
class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
58 |
+
"""
|
59 |
+
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
|
60 |
+
|
61 |
+
This uses notably ByteFallback and no normalization.
|
62 |
+
|
63 |
+
```
|
64 |
+
from transformers import LlamaTokenizerFast
|
65 |
+
|
66 |
+
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
|
67 |
+
tokenizer.encode("Hello this is a test")
|
68 |
+
>>> [1, 15043, 445, 338, 263, 1243]
|
69 |
+
```
|
70 |
+
|
71 |
+
If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
|
72 |
+
call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
|
73 |
+
values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
|
74 |
+
[post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
|
75 |
+
|
76 |
+
|
77 |
+
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
78 |
+
refer to this superclass for more information regarding those methods.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
vocab_file (`str`):
|
82 |
+
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
|
83 |
+
contains the vocabulary necessary to instantiate a tokenizer.
|
84 |
+
tokenizer_file (`str`):
|
85 |
+
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
|
86 |
+
contains everything needed to load the tokenizer.
|
87 |
+
|
88 |
+
clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
|
89 |
+
Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
|
90 |
+
spaces.
|
91 |
+
|
92 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
93 |
+
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
94 |
+
|
95 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
96 |
+
The end of sequence token.
|
97 |
+
|
98 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
99 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
100 |
+
token instead.
|
101 |
+
"""
|
102 |
+
|
103 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
104 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
105 |
+
slow_tokenizer_class = LlamaTokenizer
|
106 |
+
padding_side = "left"
|
107 |
+
model_input_names = ["input_ids", "attention_mask"]
|
108 |
+
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
vocab_file=None,
|
112 |
+
tokenizer_file=None,
|
113 |
+
clean_up_tokenization_spaces=False,
|
114 |
+
unk_token="<unk>",
|
115 |
+
bos_token="<s>",
|
116 |
+
eos_token="</s>",
|
117 |
+
add_bos_token=True,
|
118 |
+
add_eos_token=False,
|
119 |
+
use_default_system_prompt=True,
|
120 |
+
**kwargs,
|
121 |
+
):
|
122 |
+
super().__init__(
|
123 |
+
vocab_file=vocab_file,
|
124 |
+
tokenizer_file=tokenizer_file,
|
125 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
126 |
+
unk_token=unk_token,
|
127 |
+
bos_token=bos_token,
|
128 |
+
eos_token=eos_token,
|
129 |
+
use_default_system_prompt=use_default_system_prompt,
|
130 |
+
**kwargs,
|
131 |
+
)
|
132 |
+
self._add_bos_token = add_bos_token
|
133 |
+
self._add_eos_token = add_eos_token
|
134 |
+
self.update_post_processor()
|
135 |
+
self.use_default_system_prompt = use_default_system_prompt
|
136 |
+
self.vocab_file = vocab_file
|
137 |
+
|
138 |
+
@property
|
139 |
+
def can_save_slow_tokenizer(self) -> bool:
|
140 |
+
return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
141 |
+
|
142 |
+
def update_post_processor(self):
|
143 |
+
"""
|
144 |
+
Updates the underlying post processor with the current `bos_token` and `eos_token`.
|
145 |
+
"""
|
146 |
+
bos = self.bos_token
|
147 |
+
bos_token_id = self.bos_token_id
|
148 |
+
|
149 |
+
eos = self.eos_token
|
150 |
+
eos_token_id = self.eos_token_id
|
151 |
+
|
152 |
+
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
|
153 |
+
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
|
154 |
+
|
155 |
+
special_tokens = []
|
156 |
+
if self.add_bos_token:
|
157 |
+
special_tokens.append((bos, bos_token_id))
|
158 |
+
if self.add_eos_token:
|
159 |
+
special_tokens.append((eos, eos_token_id))
|
160 |
+
self._tokenizer.post_processor = processors.TemplateProcessing(
|
161 |
+
single=single, pair=pair, special_tokens=special_tokens
|
162 |
+
)
|
163 |
+
|
164 |
+
@property
|
165 |
+
def add_eos_token(self):
|
166 |
+
return self._add_eos_token
|
167 |
+
|
168 |
+
@property
|
169 |
+
def add_bos_token(self):
|
170 |
+
return self._add_bos_token
|
171 |
+
|
172 |
+
@add_eos_token.setter
|
173 |
+
def add_eos_token(self, value):
|
174 |
+
self._add_eos_token = value
|
175 |
+
self.update_post_processor()
|
176 |
+
|
177 |
+
@add_bos_token.setter
|
178 |
+
def add_bos_token(self, value):
|
179 |
+
self._add_bos_token = value
|
180 |
+
self.update_post_processor()
|
181 |
+
|
182 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
183 |
+
if not self.can_save_slow_tokenizer:
|
184 |
+
raise ValueError(
|
185 |
+
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
186 |
+
"tokenizer."
|
187 |
+
)
|
188 |
+
|
189 |
+
if not os.path.isdir(save_directory):
|
190 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
191 |
+
return
|
192 |
+
out_vocab_file = os.path.join(
|
193 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
194 |
+
)
|
195 |
+
|
196 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
197 |
+
copyfile(self.vocab_file, out_vocab_file)
|
198 |
+
|
199 |
+
return (out_vocab_file,)
|
200 |
+
|
201 |
+
@property
|
202 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
|
203 |
+
def default_chat_template(self):
|
204 |
+
"""
|
205 |
+
LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
|
206 |
+
Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
|
207 |
+
user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
|
208 |
+
rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
|
209 |
+
results in an unusual token ordering when it is present. This template should definitely be changed if you wish
|
210 |
+
to fine-tune a model with more flexible role ordering!
|
211 |
+
|
212 |
+
The output should look something like:
|
213 |
+
|
214 |
+
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos> <bos>[INST] Prompt [/INST] Answer <eos>
|
215 |
+
<bos>[INST] Prompt [/INST]
|
216 |
+
"""
|
217 |
+
|
218 |
+
template = (
|
219 |
+
"{% if messages[0]['role'] == 'system' %}"
|
220 |
+
"{% set loop_messages = messages[1:] %}" # Extract system message if it's present
|
221 |
+
"{% set system_message = messages[0]['content'] %}"
|
222 |
+
"{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
|
223 |
+
"{% set loop_messages = messages %}" # Or use the default system message if the flag is set
|
224 |
+
"{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
|
225 |
+
"{% else %}"
|
226 |
+
"{% set loop_messages = messages %}"
|
227 |
+
"{% set system_message = false %}"
|
228 |
+
"{% endif %}"
|
229 |
+
"{% for message in loop_messages %}" # Loop over all non-system messages
|
230 |
+
"{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
|
231 |
+
"{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
|
232 |
+
"{% endif %}"
|
233 |
+
"{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message
|
234 |
+
"{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
|
235 |
+
"{% else %}"
|
236 |
+
"{% set content = message['content'] %}"
|
237 |
+
"{% endif %}"
|
238 |
+
"{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way
|
239 |
+
"{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
|
240 |
+
"{% elif message['role'] == 'system' %}"
|
241 |
+
"{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
|
242 |
+
"{% elif message['role'] == 'assistant' %}"
|
243 |
+
"{{ ' ' + content.strip() + ' ' + eos_token }}"
|
244 |
+
"{% endif %}"
|
245 |
+
"{% endfor %}"
|
246 |
+
)
|
247 |
+
template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
|
248 |
+
default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
|
249 |
+
template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
|
250 |
+
|
251 |
+
return template
|
252 |
+
|
253 |
+
# TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
|
254 |
+
# Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
|
255 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
256 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
257 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
258 |
+
|
259 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
260 |
+
|
261 |
+
if token_ids_1 is not None:
|
262 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
263 |
+
|
264 |
+
return output
|
codeclm/models/lm_levo.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import torch.nn as nn
|
6 |
+
import typing as tp
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from codeclm.models.levo import CausalLM, LlamaConfig
|
10 |
+
from codeclm.modules.streaming import StreamingModule
|
11 |
+
from codeclm.modules.conditioners import (
|
12 |
+
ConditioningAttributes,
|
13 |
+
AudioCondition,
|
14 |
+
ConditionType,
|
15 |
+
ConditionerProvider,
|
16 |
+
ConditionFuser,
|
17 |
+
ClassifierFreeGuidanceDropoutInference,
|
18 |
+
ClassifierFreeGuidanceDropout,
|
19 |
+
AttributeDropout,
|
20 |
+
)
|
21 |
+
from codeclm.utils.utils import create_norm_fn, init_layer, sample_top_k, sample_top_p, multinomial
|
22 |
+
from codeclm.modules.pattern import CodebooksPatternProvider
|
23 |
+
ConditionTensors = tp.Dict[str, ConditionType]
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class LMOutput:
|
27 |
+
# The logits are already re-aligned with the input codes
|
28 |
+
# hence no extra shift is required, e.g. when computing CE
|
29 |
+
logits: torch.Tensor # [B, K, T, card]
|
30 |
+
mask: torch.Tensor # [B, K, T]
|
31 |
+
|
32 |
+
|
33 |
+
class LmModel(StreamingModule):
|
34 |
+
"""Transformer-based language model on multiple streams of codes.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
|
38 |
+
condition_provider (ConditioningProvider): Conditioning provider from metadata.
|
39 |
+
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
|
40 |
+
code_depth (int): Number of parallel streams to model.
|
41 |
+
code_size (int): Cardinality, vocabulary size.
|
42 |
+
dim (int): Dimension of the transformer encoder.
|
43 |
+
num_heads (int): Number of heads for the transformer encoder.
|
44 |
+
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
|
45 |
+
norm (str): Normalization method.
|
46 |
+
norm_first (bool): Use pre-norm instead of post-norm.
|
47 |
+
emb_lr (float, optional): Embedding-specific learning rate.
|
48 |
+
bias_proj (bool): Use bias for output projections.
|
49 |
+
weight_init (str, optional): Method for weight initialization.
|
50 |
+
depthwise_init (str, optional): Method for depthwise weight initialization.
|
51 |
+
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
|
52 |
+
cfg_dropout (float): Classifier-free guidance dropout.
|
53 |
+
cfg_coef (float): Classifier-free guidance coefficient.
|
54 |
+
attribute_dropout (dict): Attribute dropout probabilities.
|
55 |
+
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
56 |
+
**kwargs: Additional parameters for the transformer encoder.
|
57 |
+
"""
|
58 |
+
def __init__(self,
|
59 |
+
pattern_provider: CodebooksPatternProvider,
|
60 |
+
condition_provider: ConditionerProvider,
|
61 |
+
fuser: ConditionFuser,
|
62 |
+
code_depth: int = 8,
|
63 |
+
code_size: int = 1024,
|
64 |
+
dim: int = 128,
|
65 |
+
intermediate_size: int = 4096,
|
66 |
+
num_heads: int = 8,
|
67 |
+
norm: str = 'layer_norm', norm_first: bool = False,
|
68 |
+
bias_proj: bool = True,
|
69 |
+
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
|
70 |
+
zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
|
71 |
+
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
|
72 |
+
lm_type = 'Llama',
|
73 |
+
num_layers=16,
|
74 |
+
cfg = None,
|
75 |
+
**kwargs):
|
76 |
+
super().__init__()
|
77 |
+
|
78 |
+
self.cfg_coef = cfg_coef
|
79 |
+
|
80 |
+
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout,seed=random.randint(0, 9999))
|
81 |
+
self.att_dropout = AttributeDropout(p=attribute_dropout,seed=random.randint(0, 9999))
|
82 |
+
self.condition_provider = condition_provider
|
83 |
+
self.fuser = fuser
|
84 |
+
self.code_size = code_size + 1 # + EOS
|
85 |
+
input_emb_dim = code_size + 2 # EOP
|
86 |
+
self.code_depth = code_depth
|
87 |
+
self.dim = dim
|
88 |
+
self.cfg = cfg
|
89 |
+
self.pattern_provider = pattern_provider
|
90 |
+
self.emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)])
|
91 |
+
# if 'activation' in kwargs:
|
92 |
+
# kwargs['activation'] = get_activation_fn(kwargs['activation'])
|
93 |
+
|
94 |
+
model_cfg = LlamaConfig(
|
95 |
+
hidden_size=dim,
|
96 |
+
intermediate_size = intermediate_size,
|
97 |
+
num_attention_heads = num_heads,
|
98 |
+
num_hidden_layers = num_layers,
|
99 |
+
num_key_value_heads = num_heads,
|
100 |
+
vocab_size = self.code_size,
|
101 |
+
use_cache=False,
|
102 |
+
max_position_embeddings=8196,
|
103 |
+
_flash_attn_2_enabled=True,
|
104 |
+
rms_norm_eps= 1e-5,
|
105 |
+
rope_theta= 100000.0,
|
106 |
+
use_flash_attn_2=True,
|
107 |
+
attn_implementation="flash_attention_2"
|
108 |
+
)
|
109 |
+
|
110 |
+
self.transformer = CausalLM(model_cfg)
|
111 |
+
self.mlp = nn.Sequential(
|
112 |
+
nn.Linear(dim * 2, dim),
|
113 |
+
nn.GELU(),
|
114 |
+
nn.Linear(dim, dim)
|
115 |
+
)
|
116 |
+
self.layer2_emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim) #, lr=emb_lr)
|
117 |
+
for _ in range(self.code_depth)])
|
118 |
+
sub_model_cfg = LlamaConfig(
|
119 |
+
hidden_size=dim,
|
120 |
+
intermediate_size = intermediate_size,
|
121 |
+
num_attention_heads = num_heads,
|
122 |
+
num_hidden_layers = 12,
|
123 |
+
num_key_value_heads = num_heads,
|
124 |
+
vocab_size = self.code_size,
|
125 |
+
use_cache=False,
|
126 |
+
max_position_embeddings=10000,
|
127 |
+
rms_norm_eps= 1e-5,
|
128 |
+
rope_theta= 500000.0,
|
129 |
+
_flash_attn_2_enabled=True,
|
130 |
+
use_flash_attn_2=True,
|
131 |
+
attn_implementation="flash_attention_2"
|
132 |
+
)
|
133 |
+
self.transformer2 = CausalLM(sub_model_cfg)
|
134 |
+
self.out_norm: tp.Optional[nn.Module] = None
|
135 |
+
if norm_first:
|
136 |
+
self.out_norm = create_norm_fn(norm, dim)
|
137 |
+
# enable EOS prediction
|
138 |
+
if code_depth > 1:
|
139 |
+
self.linears = nn.ModuleList([nn.Linear(dim, self.code_size, bias=False)
|
140 |
+
for _ in range(code_depth - 1)])
|
141 |
+
|
142 |
+
self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
143 |
+
self._fsdp: tp.Optional[nn.Module]
|
144 |
+
self.__dict__['_fsdp'] = None
|
145 |
+
|
146 |
+
self.reset_streaming()
|
147 |
+
|
148 |
+
def _init_weights(self, weight_init: tp.Optional[str],
|
149 |
+
depthwise_init: tp.Optional[str], zero_bias_init: bool):
|
150 |
+
"""Initialization of the transformer module weights.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
|
154 |
+
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
|
155 |
+
'current' where the depth corresponds to the current layer index or 'global' where the total number
|
156 |
+
of layer is used as depth. If not set, no depthwise initialization strategy is used.
|
157 |
+
zero_bias_init (bool): Whether to initialize bias to zero or not.
|
158 |
+
"""
|
159 |
+
assert depthwise_init is None or depthwise_init in ['current', 'global']
|
160 |
+
assert depthwise_init is None or weight_init is not None, \
|
161 |
+
"If 'depthwise_init' is defined, a 'weight_init' method should be provided."
|
162 |
+
assert not zero_bias_init or weight_init is not None, \
|
163 |
+
"If 'zero_bias_init', a 'weight_init' method should be provided"
|
164 |
+
|
165 |
+
if weight_init is None:
|
166 |
+
return
|
167 |
+
|
168 |
+
for emb_layer in self.emb:
|
169 |
+
init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
170 |
+
|
171 |
+
|
172 |
+
@property
|
173 |
+
def special_token_id(self) -> int:
|
174 |
+
return self.code_size # 10001
|
175 |
+
|
176 |
+
@property
|
177 |
+
def eos_token_id(self) -> int:
|
178 |
+
return self.code_size-1 # 10000
|
179 |
+
|
180 |
+
@torch.no_grad()
|
181 |
+
def prepare_condition_tensors(self,
|
182 |
+
batch_size = 1,
|
183 |
+
text: tp.Optional[tp.List[str]] = None,
|
184 |
+
descriptions: tp.Optional[tp.List[str]] = None,
|
185 |
+
audio_qt_emb: tp.Optional[tp.List[torch.Tensor]] = None,
|
186 |
+
prepare_null_condition = False,
|
187 |
+
):
|
188 |
+
if self.training:
|
189 |
+
attributes = []
|
190 |
+
for i in range(batch_size):
|
191 |
+
attr = ConditioningAttributes()
|
192 |
+
if 'description' in self.condition_provider.conditioners:
|
193 |
+
attr["text"]["description"] = ""
|
194 |
+
if text is not None:
|
195 |
+
attr["text"]["description"] = text[i]
|
196 |
+
if 'prompt_audio' in self.condition_provider.conditioners:
|
197 |
+
mask = (audio_qt_emb[[i], :, 0] == 16385).bool().unsqueeze(-1)
|
198 |
+
audio_qt_seq = torch.cat([torch.full_like(audio_qt_emb[i][None][:,:,0], self.eos_token_id).unsqueeze(-1), audio_qt_emb[i][None]], dim=-1)
|
199 |
+
mask = mask.repeat(1, 1, audio_qt_seq.shape[-1])
|
200 |
+
audio_qt_seq[mask] = 16385
|
201 |
+
attr["audio"]['prompt_audio'] = AudioCondition(
|
202 |
+
wav=audio_qt_seq.long(),
|
203 |
+
length=torch.Tensor([audio_qt_seq.shape[-1]]).long(),
|
204 |
+
sample_rate=[self.cfg.sample_rate],)
|
205 |
+
if 'type_info' in self.condition_provider.conditioners:
|
206 |
+
attr["text"]["type_info"] = ""
|
207 |
+
if descriptions is not None:
|
208 |
+
attr["text"]["type_info"] = descriptions[i]
|
209 |
+
attributes.append(attr)
|
210 |
+
# print("before cfg dropout", attributes)
|
211 |
+
attributes = self.cfg_dropout(attributes) # drop ALL conditions
|
212 |
+
# print("after cfg dropout", attributes)
|
213 |
+
attributes = self.att_dropout(attributes) # selectively drop some attributes (text, wav, or more fine-grained)
|
214 |
+
# print("after attribute dropout", attributes)
|
215 |
+
# attribute to discrete tokenized ids
|
216 |
+
tokenized = self.condition_provider.tokenize(attributes)
|
217 |
+
# print("after tokenize", attributes)
|
218 |
+
# discrete tokenized ids to continuous embeddings
|
219 |
+
condition_tensors = self.condition_provider(tokenized)
|
220 |
+
else:
|
221 |
+
conditions = []
|
222 |
+
for i in range(batch_size):
|
223 |
+
attr = ConditioningAttributes()
|
224 |
+
if 'description' in self.condition_provider.conditioners:
|
225 |
+
attr["text"]["description"] = ""
|
226 |
+
if text is not None:
|
227 |
+
attr["text"]["description"] = text[i]
|
228 |
+
if 'prompt_audio' in self.condition_provider.conditioners:
|
229 |
+
mask = (audio_qt_emb[[i], :, 0] == 16385).bool().unsqueeze(-1)
|
230 |
+
audio_qt_seq = torch.cat([torch.full_like(audio_qt_emb[i][None][:,:,0], self.eos_token_id).unsqueeze(-1), audio_qt_emb[i][None]], dim=-1)
|
231 |
+
mask = mask.repeat(1, 1, audio_qt_seq.shape[-1])
|
232 |
+
audio_qt_seq[mask] = 16385
|
233 |
+
attr["audio"]['prompt_audio'] = AudioCondition(
|
234 |
+
wav=audio_qt_seq.long().cuda(),
|
235 |
+
length=torch.Tensor([audio_qt_seq.shape[-1]]).long(),
|
236 |
+
sample_rate=[self.cfg.sample_rate],)
|
237 |
+
if 'type_info' in self.condition_provider.conditioners:
|
238 |
+
attr["text"]["type_info"] = ""
|
239 |
+
if descriptions is not None:
|
240 |
+
attr["text"]["type_info"] = descriptions[i]
|
241 |
+
conditions.append(attr)
|
242 |
+
print("conditions", conditions)
|
243 |
+
if prepare_null_condition:
|
244 |
+
cfg_inference = ClassifierFreeGuidanceDropoutInference()
|
245 |
+
null_conditions = cfg_inference(conditions, condition_types=["audio", "text"],
|
246 |
+
customized=None)
|
247 |
+
conditions = conditions + null_conditions
|
248 |
+
tokenized_conditions = self.condition_provider.tokenize(conditions)
|
249 |
+
condition_tensors = self.condition_provider(tokenized_conditions)
|
250 |
+
return condition_tensors
|
251 |
+
|
252 |
+
def forward(self,
|
253 |
+
sequence: torch.Tensor,
|
254 |
+
condition_tensors: ConditionTensors) -> torch.Tensor:
|
255 |
+
"""Apply language model on sequence and conditions.
|
256 |
+
Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
|
257 |
+
S the sequence steps, return the logits with shape [B, card, K, S].
|
258 |
+
|
259 |
+
Args:
|
260 |
+
indices (torch.Tensor): Indices of the codes to model.
|
261 |
+
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
262 |
+
tensors, see `conditions`.
|
263 |
+
Returns:
|
264 |
+
torch.Tensor: Logits.
|
265 |
+
"""
|
266 |
+
|
267 |
+
# import pdb; pdb.set_trace()
|
268 |
+
B, K, S = sequence.shape
|
269 |
+
assert K == self.code_depth, "Sequence shape must match the specified number of codebooks"
|
270 |
+
input_1 = self.emb[0](sequence[:, 0])
|
271 |
+
input_2 = sum([self.layer2_emb[k](sequence[:, k]) for k in range(1, K)])
|
272 |
+
fused_input1, fused_input2 = self.fuser(input_1, input_2, condition_tensors)
|
273 |
+
output = self.transformer(inputs_embeds=fused_input1,
|
274 |
+
use_cache=self._is_streaming,
|
275 |
+
past_key_values=self._streaming_state.get('past_key_values_1', None))
|
276 |
+
if self._is_streaming:
|
277 |
+
self._streaming_state['past_key_values_1'] = output.past_key_values
|
278 |
+
logits = output.logits # [B, S, card]
|
279 |
+
logits = logits.unsqueeze(1) # [B, 1, S, card]
|
280 |
+
|
281 |
+
# if self.out_norm:
|
282 |
+
# out = self.out_norm(out.to(self.out_norm.weight.data.dtype))
|
283 |
+
if K > 1:
|
284 |
+
fused_input2 = torch.cat([fused_input2, output.hidden_states], dim=-1)
|
285 |
+
fused_input2 = self.mlp(fused_input2)
|
286 |
+
output2 = self.transformer2(inputs_embeds=fused_input2,
|
287 |
+
use_cache=self._is_streaming,
|
288 |
+
past_key_values=self._streaming_state.get('past_key_values_2', None))
|
289 |
+
if self._is_streaming:
|
290 |
+
self._streaming_state['past_key_values_2'] = output2.past_key_values
|
291 |
+
|
292 |
+
res_logits = torch.stack([self.linears[k](output2.hidden_states) for k in range(K - 1)], dim=1) # [B, K, S, card] # [B, K, S, card]
|
293 |
+
logits = torch.cat([logits, res_logits], dim=1) # [B, K, S, card]
|
294 |
+
|
295 |
+
# remove the prefix from the model outputs
|
296 |
+
if len(self.fuser.fuse2cond['prepend']) > 0:
|
297 |
+
logits = logits[:, :, -S:, :]
|
298 |
+
|
299 |
+
return logits # [B, K, S, card]
|
300 |
+
|
301 |
+
def compute_predictions(self,
|
302 |
+
codes: torch.Tensor,
|
303 |
+
condition_tensors: tp.Optional[ConditionTensors] = None,
|
304 |
+
**kwargs,
|
305 |
+
): # this function is called during training
|
306 |
+
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model
|
307 |
+
forward using the specified codes interleaving pattern.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
|
311 |
+
K the number of codebooks and T the number of timesteps.
|
312 |
+
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
|
313 |
+
tensors, see `conditions`.
|
314 |
+
Returns:
|
315 |
+
LMOutput: Language model outputs
|
316 |
+
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
|
317 |
+
i.e. the first item corresponds to logits to predict the first code, meaning that
|
318 |
+
no additional shifting of codes and logits is required.
|
319 |
+
mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
|
320 |
+
Given the specified interleaving strategies, parts of the logits and codes should
|
321 |
+
not be considered as valid predictions because of invalid context.
|
322 |
+
"""
|
323 |
+
B, K, T = codes.shape
|
324 |
+
codes = codes.contiguous()
|
325 |
+
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
|
326 |
+
pattern = self.pattern_provider.get_pattern(T)
|
327 |
+
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
|
328 |
+
codes, self.special_token_id, keep_only_valid_steps=False
|
329 |
+
)
|
330 |
+
model = self if self._fsdp is None else self._fsdp
|
331 |
+
logits = model(sequence_codes, condition_tensors) # [B, K, S, card]
|
332 |
+
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
|
333 |
+
# and provide the corresponding mask over invalid positions of tokens
|
334 |
+
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
|
335 |
+
# note: we use nans as special token to make it obvious if we feed unexpected logits
|
336 |
+
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
|
337 |
+
logits, float('nan'), keep_only_valid_steps=False
|
338 |
+
)
|
339 |
+
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
|
340 |
+
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
|
341 |
+
|
342 |
+
return LMOutput(logits, logits_mask)
|
343 |
+
|
344 |
+
@torch.no_grad()
|
345 |
+
def generate(self, #
|
346 |
+
# conditions: tp.List[ConditioningAttributes] = [],
|
347 |
+
texts = None,
|
348 |
+
descriptions = None,
|
349 |
+
audio_qt_embs = None,
|
350 |
+
num_samples: tp.Optional[int] = None,
|
351 |
+
max_gen_len: int = 256,
|
352 |
+
use_sampling: bool = True,
|
353 |
+
temp: float = 1.0,
|
354 |
+
top_k: int = 250,
|
355 |
+
top_p: float = 0.0,
|
356 |
+
cfg_coef: tp.Optional[float] = None,
|
357 |
+
check: bool = False,
|
358 |
+
record_tokens: bool = True,
|
359 |
+
record_window: int = 150
|
360 |
+
) -> torch.Tensor:
|
361 |
+
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
|
362 |
+
be perform in a greedy fashion or using sampling with top K and top P strategies.
|
363 |
+
|
364 |
+
Args:
|
365 |
+
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
|
366 |
+
conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
|
367 |
+
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
|
368 |
+
max_gen_len (int): Maximum generation length.
|
369 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
370 |
+
temp (float): Sampling temperature.
|
371 |
+
top_k (int): K for "top-k" sampling.
|
372 |
+
top_p (float): P for "top-p" sampling.
|
373 |
+
cfg_coeff (float, optional): Classifier-free guidance coefficient.
|
374 |
+
check (bool): Whether to apply further checks on generated sequence.
|
375 |
+
callback (Callback, optional): Callback function to report generation progress.
|
376 |
+
Returns:
|
377 |
+
torch.Tensor: Generated tokens.
|
378 |
+
"""
|
379 |
+
assert not self.training, "generation shouldn't be used in training mode."
|
380 |
+
first_param = next(iter(self.parameters()))
|
381 |
+
device = first_param.device
|
382 |
+
|
383 |
+
# 1) Check input shapes are consistent
|
384 |
+
possible_num_samples = []
|
385 |
+
if num_samples is not None:
|
386 |
+
possible_num_samples.append(num_samples)
|
387 |
+
elif texts:
|
388 |
+
possible_num_samples.append(len(texts))
|
389 |
+
elif audio_qt_embs:
|
390 |
+
possible_num_samples.append(len(audio_qt_embs))
|
391 |
+
else:
|
392 |
+
possible_num_samples.append(1)
|
393 |
+
assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
|
394 |
+
num_samples = possible_num_samples[0]
|
395 |
+
condition_tensors = self.prepare_condition_tensors(batch_size=1, text=texts, descriptions=descriptions, audio_qt_emb=audio_qt_embs, prepare_null_condition=True)
|
396 |
+
# 3) Prepare token pool
|
397 |
+
record_token_pool = None
|
398 |
+
if record_tokens:
|
399 |
+
record_token_pool = []
|
400 |
+
|
401 |
+
# 4) set up startoff patterns
|
402 |
+
start_offset = 0
|
403 |
+
assert start_offset < max_gen_len, f"{start_offset}, {max_gen_len}"
|
404 |
+
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
405 |
+
# this token is used as default value for codes that are not generated yet
|
406 |
+
unknown_token = -1
|
407 |
+
# we generate codes up to the max_gen_len that will be mapped to the pattern sequence
|
408 |
+
B = num_samples
|
409 |
+
gen_codes = torch.full((B, self.code_depth, max_gen_len),
|
410 |
+
unknown_token, dtype=torch.long, device=device)
|
411 |
+
# create the gen_sequence with proper interleaving from the pattern: [B, K, S]
|
412 |
+
gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
|
413 |
+
output_codes = torch.full_like(gen_sequence, self.code_size)
|
414 |
+
# retrieve the start_offset in the sequence:
|
415 |
+
# it is the first sequence step that contains the `start_offset` timestep
|
416 |
+
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
|
417 |
+
assert start_offset_sequence is not None
|
418 |
+
is_end = torch.zeros((B, self.code_depth, 1)).bool().to(device)
|
419 |
+
ignore_tokens = audio_qt_embs[0][0]
|
420 |
+
# 5) auto-regressive sampling
|
421 |
+
with self.streaming():
|
422 |
+
gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
|
423 |
+
prev_offset = 0
|
424 |
+
for offset in range(start_offset_sequence, gen_sequence_len):
|
425 |
+
# get current sequence (note that the streaming API is providing the caching over previous offsets)
|
426 |
+
curr_sequence = gen_sequence[..., prev_offset:offset]
|
427 |
+
curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
|
428 |
+
if check:
|
429 |
+
# check coherence between mask and sequence
|
430 |
+
assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
|
431 |
+
# should never happen as gen_sequence is filled progressively
|
432 |
+
assert not (curr_sequence == unknown_token).any()
|
433 |
+
# sample next token from the model, next token shape is [B, K, 1]
|
434 |
+
next_token = self._sample_next_token(
|
435 |
+
curr_sequence, condition_tensors, use_sampling, temp, top_k, top_p,
|
436 |
+
cfg_coef=cfg_coef,
|
437 |
+
sampled_token_pool=record_token_pool[-record_window:] if record_tokens else None,
|
438 |
+
ignore_tokens = ignore_tokens
|
439 |
+
)
|
440 |
+
# ensure the tokens that should be masked are properly set to special_token_id
|
441 |
+
# as the model never output special_token_id
|
442 |
+
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
|
443 |
+
next_token[~valid_mask] = self.special_token_id
|
444 |
+
# 检查eos id
|
445 |
+
next_token[is_end] = self.special_token_id
|
446 |
+
is_end = is_end | (next_token == self.eos_token_id)
|
447 |
+
# ensure we don't overwrite prompt tokens, we only write over unknown tokens
|
448 |
+
# (then mask tokens should be left as is as well, which is correct)
|
449 |
+
gen_sequence[..., offset:offset+1] = torch.where(
|
450 |
+
gen_sequence[..., offset:offset+1] == unknown_token,
|
451 |
+
next_token, gen_sequence[..., offset:offset+1])
|
452 |
+
|
453 |
+
# record sampled tokens in a window
|
454 |
+
if record_tokens:
|
455 |
+
record_token_pool.append(next_token.squeeze())
|
456 |
+
if torch.all(is_end):
|
457 |
+
gen_sequence = gen_sequence[..., :offset+1]
|
458 |
+
break
|
459 |
+
|
460 |
+
prev_offset = offset
|
461 |
+
|
462 |
+
# ensure sequence has been entirely filled
|
463 |
+
assert not (gen_sequence == unknown_token).any()
|
464 |
+
max_gen_len = gen_sequence.shape[-1]
|
465 |
+
output_codes[..., :max_gen_len] = gen_sequence
|
466 |
+
# ensure gen_sequence pattern and mask are matching
|
467 |
+
# which means the gen_sequence is valid according to the pattern
|
468 |
+
# assert (gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence,
|
469 |
+
# self.special_token_id)
|
470 |
+
# ).all()
|
471 |
+
# get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
|
472 |
+
out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(output_codes, special_token=unknown_token)
|
473 |
+
# sanity checks over the returned codes and corresponding masks
|
474 |
+
assert (out_codes != unknown_token).all()
|
475 |
+
assert (out_mask == 1).all()
|
476 |
+
# ensure the returned codes are all valid
|
477 |
+
assert (out_codes >= 0).all() and (out_codes <= self.code_size).all()
|
478 |
+
return out_codes
|
479 |
+
|
480 |
+
def _sample_next_token(self,
|
481 |
+
sequence: torch.Tensor,
|
482 |
+
condition_tensors: ConditionTensors,
|
483 |
+
use_sampling: bool = False,
|
484 |
+
temp: float = 1.0,
|
485 |
+
top_k: int = 0,
|
486 |
+
top_p: float = 0.0,
|
487 |
+
cfg_coef: tp.Optional[float] = None,
|
488 |
+
sampled_token_pool: tp.Optional[list] = None,
|
489 |
+
ignore_tokens: tp.Optional[torch.tensor] = torch.tensor([])) -> torch.Tensor:
|
490 |
+
"""Sample next token from the model given a sequence and a set of conditions. The model supports
|
491 |
+
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
|
492 |
+
|
493 |
+
Args:
|
494 |
+
sequence (torch.Tensor): Current sequence of shape [B, K, S]
|
495 |
+
with K corresponding to the number of codebooks and S the number of sequence steps.
|
496 |
+
S = 1 in streaming mode, except for the first step that contains a bigger prompt.
|
497 |
+
condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
|
498 |
+
should be twice the batch size, being the concatenation of the conditions + null conditions.
|
499 |
+
use_sampling (bool): Whether to use a sampling strategy or not.
|
500 |
+
temp (float): Sampling temperature.
|
501 |
+
top_k (int): K for "top-k" sampling.
|
502 |
+
top_p (float): P for "top-p" sampling.
|
503 |
+
cfg_coef (float, optional): classifier free guidance coefficient
|
504 |
+
Returns:
|
505 |
+
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
|
506 |
+
"""
|
507 |
+
# import pdb; pdb.set_trace()
|
508 |
+
B = sequence.shape[0]
|
509 |
+
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
|
510 |
+
model = self if self._fsdp is None else self._fsdp
|
511 |
+
|
512 |
+
# Preparing for CFG, predicting both conditional and unconditional logits.
|
513 |
+
sequence = torch.cat([sequence, sequence], dim=0)
|
514 |
+
all_logits = model(sequence, condition_tensors=condition_tensors)
|
515 |
+
cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
|
516 |
+
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
|
517 |
+
|
518 |
+
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
|
519 |
+
logits = logits[..., -1] # [B x K x card]
|
520 |
+
|
521 |
+
# add punishment to pre-sampled tokens
|
522 |
+
if sampled_token_pool is not None and len(sampled_token_pool) > 0:
|
523 |
+
sampled_token_pool = torch.stack(sampled_token_pool, -1) # [K, T]
|
524 |
+
for q in range(self.code_depth):
|
525 |
+
# q_count = torch.bincount(sampled_token_pool)
|
526 |
+
q_count = torch.bincount(torch.unique(sampled_token_pool[q]))
|
527 |
+
tmp = min(q_count.shape[-1], self.code_size - 1)
|
528 |
+
logits[:, q, :tmp] /= (1.1 ** q_count[:tmp])
|
529 |
+
|
530 |
+
# Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
|
531 |
+
if(ignore_tokens is not None):
|
532 |
+
logits[0][0][ignore_tokens.to(torch.int)] = float('-inf')
|
533 |
+
if use_sampling and temp > 0.0:
|
534 |
+
probs = torch.softmax(logits / temp, dim=-1)
|
535 |
+
if top_p > 0.0:
|
536 |
+
next_token = sample_top_p(probs, p=top_p)
|
537 |
+
elif top_k > 0:
|
538 |
+
next_token_first = sample_top_k(probs[:,[0],:], k=top_k)
|
539 |
+
next_token_res = sample_top_k(probs[:,1:,:], k=1)
|
540 |
+
next_token = torch.cat([next_token_first,next_token_res], dim = 1)
|
541 |
+
else:
|
542 |
+
next_token = multinomial(probs, num_samples=1)
|
543 |
+
else:
|
544 |
+
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
545 |
+
|
546 |
+
return next_token
|
codeclm/modules/conditioners.py
ADDED
@@ -0,0 +1,883 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import typing as tp
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from dataclasses import dataclass, field, fields
|
6 |
+
from itertools import chain
|
7 |
+
import warnings
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
+
from codeclm.utils.utils import length_to_mask, collate
|
11 |
+
from codeclm.modules.streaming import StreamingModule
|
12 |
+
from collections import defaultdict
|
13 |
+
from copy import deepcopy
|
14 |
+
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
|
15 |
+
|
16 |
+
# ================================================================
|
17 |
+
# Condition and Condition attributes definitions
|
18 |
+
# ================================================================
|
19 |
+
class AudioCondition(tp.NamedTuple):
|
20 |
+
wav: torch.Tensor
|
21 |
+
length: torch.Tensor
|
22 |
+
sample_rate: tp.List[int]
|
23 |
+
path: tp.List[tp.Optional[str]] = []
|
24 |
+
seek_time: tp.List[tp.Optional[float]] = []
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class ConditioningAttributes:
|
28 |
+
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
29 |
+
audio: tp.Dict[str, AudioCondition] = field(default_factory=dict)
|
30 |
+
|
31 |
+
def __getitem__(self, item):
|
32 |
+
return getattr(self, item)
|
33 |
+
|
34 |
+
@property
|
35 |
+
def text_attributes(self):
|
36 |
+
return self.text.keys()
|
37 |
+
|
38 |
+
@property
|
39 |
+
def audio_attributes(self):
|
40 |
+
return self.audio.keys()
|
41 |
+
|
42 |
+
@property
|
43 |
+
def attributes(self):
|
44 |
+
return {
|
45 |
+
"text": self.text_attributes,
|
46 |
+
"audio": self.audio_attributes,
|
47 |
+
}
|
48 |
+
|
49 |
+
def to_flat_dict(self):
|
50 |
+
return {
|
51 |
+
**{f"text.{k}": v for k, v in self.text.items()},
|
52 |
+
**{f"audio.{k}": v for k, v in self.audio.items()},
|
53 |
+
}
|
54 |
+
|
55 |
+
@classmethod
|
56 |
+
def from_flat_dict(cls, x):
|
57 |
+
out = cls()
|
58 |
+
for k, v in x.items():
|
59 |
+
kind, att = k.split(".")
|
60 |
+
out[kind][att] = v
|
61 |
+
return out
|
62 |
+
|
63 |
+
# ================================================================
|
64 |
+
# Conditioner (tokenize and encode raw conditions) definitions
|
65 |
+
# ================================================================
|
66 |
+
|
67 |
+
class BaseConditioner(nn.Module):
|
68 |
+
"""Base model for all conditioner modules.
|
69 |
+
We allow the output dim to be different than the hidden dim for two reasons:
|
70 |
+
1) keep our LUTs small when the vocab is large;
|
71 |
+
2) make all condition dims consistent.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
dim (int): Hidden dim of the model.
|
75 |
+
output_dim (int): Output dim of the conditioner.
|
76 |
+
"""
|
77 |
+
def __init__(self, dim: int, output_dim: int, input_token = False, padding_idx=0):
|
78 |
+
super().__init__()
|
79 |
+
self.dim = dim
|
80 |
+
self.output_dim = output_dim
|
81 |
+
if input_token:
|
82 |
+
self.output_proj = nn.Embedding(dim, output_dim, padding_idx)
|
83 |
+
else:
|
84 |
+
self.output_proj = nn.Linear(dim, output_dim)
|
85 |
+
|
86 |
+
def tokenize(self, *args, **kwargs) -> tp.Any:
|
87 |
+
"""Should be any part of the processing that will lead to a synchronization
|
88 |
+
point, e.g. BPE tokenization with transfer to the GPU.
|
89 |
+
|
90 |
+
The returned value will be saved and return later when calling forward().
|
91 |
+
"""
|
92 |
+
raise NotImplementedError()
|
93 |
+
|
94 |
+
def forward(self, inputs: tp.Any) -> ConditionType:
|
95 |
+
"""Gets input that should be used as conditioning (e.g, genre, description or a waveform).
|
96 |
+
Outputs a ConditionType, after the input data was embedded as a dense vector.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
ConditionType:
|
100 |
+
- A tensor of size [B, T, D] where B is the batch size, T is the length of the
|
101 |
+
output embedding and D is the dimension of the embedding.
|
102 |
+
- And a mask indicating where the padding tokens.
|
103 |
+
"""
|
104 |
+
raise NotImplementedError()
|
105 |
+
|
106 |
+
class TextConditioner(BaseConditioner):
|
107 |
+
...
|
108 |
+
|
109 |
+
|
110 |
+
class PhonemeTokenizerConditioner(TextConditioner):
|
111 |
+
def __init__(self,
|
112 |
+
output_dim: int,
|
113 |
+
vocab_list,
|
114 |
+
max_len = 600,
|
115 |
+
max_sentence_per_structure = 50,
|
116 |
+
structure_tokens=None,
|
117 |
+
structure_split_tokens=[','],
|
118 |
+
sentence_split_tokens=['.'],
|
119 |
+
mode='sum',
|
120 |
+
structure_output_dim = 64,
|
121 |
+
sentence_output_dim = 64,
|
122 |
+
max_duration = 120,
|
123 |
+
):
|
124 |
+
|
125 |
+
self.vocab_list = vocab_list
|
126 |
+
self.max_len = max_len
|
127 |
+
self.mode = mode
|
128 |
+
self.max_sentence_per_structure = max_sentence_per_structure
|
129 |
+
voc_size = len(self.vocab_list)
|
130 |
+
|
131 |
+
if structure_tokens is None:
|
132 |
+
structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
|
133 |
+
self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
|
134 |
+
self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
|
135 |
+
self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]
|
136 |
+
|
137 |
+
# here initialize a output_proj (nn.Embedding) layer
|
138 |
+
# By default the first vocab is "" (null)
|
139 |
+
if mode == 'sum':
|
140 |
+
content_output_dim = output_dim
|
141 |
+
sentence_output_dim = output_dim
|
142 |
+
structure_output_dim = output_dim
|
143 |
+
else: # concat'
|
144 |
+
raise NotImplementedError("concat 模式还未实现")
|
145 |
+
# content_output_dim = output_dim - sentence_output_dim - structure_output_dim # by default
|
146 |
+
|
147 |
+
super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
|
148 |
+
self.special_emb = nn.Embedding(voc_size, structure_output_dim, padding_idx=0)
|
149 |
+
|
150 |
+
self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)
|
151 |
+
|
152 |
+
# the first index is "empty structure" token
|
153 |
+
self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
|
154 |
+
self.sentence_reidx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
|
155 |
+
|
156 |
+
print("max_len", self.max_len)
|
157 |
+
print(self.structure_token_ids)
|
158 |
+
|
159 |
+
self.resolution = max_duration / max_len # e.g., 120 / 600 = 0.2s
|
160 |
+
print(self.__class__, f"resolution = {self.resolution}")
|
161 |
+
|
162 |
+
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
|
163 |
+
inputs = []
|
164 |
+
for xx in x:
|
165 |
+
xx = '' if xx is None else xx
|
166 |
+
vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
|
167 |
+
inputs.append(torch.tensor(vocab_id).long()) # [T]
|
168 |
+
return inputs
|
169 |
+
|
170 |
+
|
171 |
+
def forward(self, batch_tokens: tp.List, structure_dur = None) -> ConditionType:
|
172 |
+
"""
|
173 |
+
Encode token_id into three types of embeddings:
|
174 |
+
1) content embedding: phoneme only (or meaningful contents to be sung out)
|
175 |
+
2) structure embedding: structure / separation embeddings, including structures (verse/chorus/...), separators (. / ,)
|
176 |
+
The two above share the same embedding layer, can be changed to separate embedding layers.
|
177 |
+
3) sentence_idx embedding (per structure):
|
178 |
+
"""
|
179 |
+
embeds_batch = []
|
180 |
+
for b in range(len(batch_tokens)):
|
181 |
+
tokens = batch_tokens[b]
|
182 |
+
content_tokens = torch.zeros_like(tokens)
|
183 |
+
special_tokens = torch.zeros_like(tokens)
|
184 |
+
sentence_idx_in_structure_tokens = torch.zeros_like(tokens)
|
185 |
+
sentence_reidx_in_structure_tokens = torch.zeros_like(tokens)
|
186 |
+
|
187 |
+
current_sentence_in_structure_idx = 1
|
188 |
+
current_structure = 0
|
189 |
+
for i in range(tokens.shape[-1]):
|
190 |
+
token = tokens[i]
|
191 |
+
if token in self.structure_token_ids: # structure token
|
192 |
+
# only update structure token, leave content and sentence index token null (default 0)
|
193 |
+
special_tokens[i] = token
|
194 |
+
content_tokens[i] = token
|
195 |
+
current_structure = token
|
196 |
+
current_sentence_in_structure_idx = 1
|
197 |
+
sentence_idx_in_structure_tokens[i] = 0
|
198 |
+
|
199 |
+
elif token in self.sentence_split_token_ids: # utterance split token
|
200 |
+
# only update structure token, leave content and sentence index token null (default 0)
|
201 |
+
# add up sentence index
|
202 |
+
special_tokens[i] = current_structure
|
203 |
+
content_tokens[i] = token
|
204 |
+
sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
|
205 |
+
current_sentence_in_structure_idx += 1
|
206 |
+
|
207 |
+
elif token in self.structure_split_token_ids: # structure split token
|
208 |
+
# update structure token (current structure), content token (current token),
|
209 |
+
# blank index token
|
210 |
+
content_tokens[i] = token
|
211 |
+
special_tokens[i] = current_structure
|
212 |
+
sentence_idx_in_structure_tokens[i] = sentence_idx_in_structure_tokens[i-1]
|
213 |
+
else: # content tokens
|
214 |
+
content_tokens[i] = token
|
215 |
+
special_tokens[i] = current_structure
|
216 |
+
sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
|
217 |
+
# 反推
|
218 |
+
current_sentence_num = sentence_idx_in_structure_tokens[-1]
|
219 |
+
for i in range(tokens.shape[-1]-1,-1,-1):
|
220 |
+
if current_sentence_num != 0:
|
221 |
+
sentence_reidx_in_structure_tokens[i] = min(current_sentence_num + 1 - sentence_idx_in_structure_tokens[i], self.max_sentence_per_structure - 1)
|
222 |
+
if sentence_idx_in_structure_tokens[i] == 0 and i > 0:
|
223 |
+
current_sentence_num = sentence_idx_in_structure_tokens[i-1]
|
224 |
+
|
225 |
+
# print("tokens", tokens.max(), tokens.min())
|
226 |
+
# print("special tokens", special_tokens.max(), special_tokens.min())
|
227 |
+
# print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min())
|
228 |
+
device = self.output_proj.weight.device
|
229 |
+
|
230 |
+
# import pdb; pdb.set_trace()
|
231 |
+
content_embeds = self.output_proj(content_tokens.to(device)) # [T, N]
|
232 |
+
structure_embeds = self.output_proj(special_tokens.to(device))
|
233 |
+
# sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))
|
234 |
+
sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device)) + self.sentence_reidx_in_structure_emb(sentence_reidx_in_structure_tokens.to(device))
|
235 |
+
|
236 |
+
if self.mode == 'sum':
|
237 |
+
embeds = content_embeds + structure_embeds + sentence_idx_embeds
|
238 |
+
else:
|
239 |
+
embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1) # [T, N]
|
240 |
+
embeds_batch.append(embeds)
|
241 |
+
|
242 |
+
# set batch_size = 1, [B, T, N]
|
243 |
+
if self.max_len is not None:
|
244 |
+
max_len = self.max_len
|
245 |
+
else:
|
246 |
+
max_len = max([e.shape[0] for e in embeds_batch])
|
247 |
+
embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)
|
248 |
+
|
249 |
+
return embeds, embeds, mask
|
250 |
+
|
251 |
+
|
252 |
+
def pad_2d_tensor(self, xs, max_len):
|
253 |
+
new_tensor = []
|
254 |
+
new_mask = []
|
255 |
+
for x in xs:
|
256 |
+
seq_len, dim = x.size()
|
257 |
+
pad_len = max_len - seq_len
|
258 |
+
|
259 |
+
if pad_len > 0:
|
260 |
+
pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device) # T, D
|
261 |
+
padded_tensor = torch.cat([x, pad_tensor], dim=0)
|
262 |
+
mask = torch.cat((torch.ones_like(x[:, 0]),
|
263 |
+
torch.zeros_like(pad_tensor[:, 0])), 0) # T
|
264 |
+
elif pad_len < 0:
|
265 |
+
padded_tensor = x[:max_len]
|
266 |
+
mask = torch.ones_like(padded_tensor[:, 0])
|
267 |
+
else:
|
268 |
+
padded_tensor = x
|
269 |
+
mask = torch.ones_like(x[:, 0])
|
270 |
+
|
271 |
+
new_tensor.append(padded_tensor)
|
272 |
+
new_mask.append(mask)
|
273 |
+
# [B, T, D] & [B, T]
|
274 |
+
return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
|
275 |
+
|
276 |
+
|
277 |
+
class QwTokenizerConditioner(TextConditioner):
|
278 |
+
def __init__(self, output_dim: int,
|
279 |
+
token_path = "",
|
280 |
+
max_len = 300,
|
281 |
+
add_token_list=[]): #""
|
282 |
+
from transformers import Qwen2Tokenizer
|
283 |
+
self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
|
284 |
+
if add_token_list != []:
|
285 |
+
self.text_tokenizer.add_tokens(add_token_list, special_tokens=True)
|
286 |
+
voc_size = len(self.text_tokenizer.get_vocab())
|
287 |
+
# here initialize a output_proj (nn.Embedding) layer
|
288 |
+
super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
|
289 |
+
self.max_len = max_len
|
290 |
+
self.padding_idx =' <|endoftext|>'
|
291 |
+
|
292 |
+
vocab = self.text_tokenizer.get_vocab()
|
293 |
+
# struct是全部的结构
|
294 |
+
struct_tokens = [i for i in add_token_list if i[0]=='[' and i[-1]==']']
|
295 |
+
self.struct_token_ids = [vocab[i] for i in struct_tokens]
|
296 |
+
self.pad_token_idx = 151643
|
297 |
+
|
298 |
+
self.structure_emb = nn.Embedding(200, output_dim, padding_idx=0)
|
299 |
+
# self.split_token_id = vocab["."]
|
300 |
+
print("all structure tokens: ", {self.text_tokenizer.convert_ids_to_tokens(i):i for i in self.struct_token_ids})
|
301 |
+
|
302 |
+
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
|
303 |
+
x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
|
304 |
+
# x = [xi if xi is not None else "" for xi in x]
|
305 |
+
inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
|
306 |
+
return inputs
|
307 |
+
|
308 |
+
def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
|
309 |
+
"""
|
310 |
+
Add structure embeddings of {verse, chorus, bridge} to text/lyric tokens that
|
311 |
+
belong to these structures accordingly,
|
312 |
+
Then delete or keep these structure embeddings.
|
313 |
+
"""
|
314 |
+
mask = inputs['attention_mask']
|
315 |
+
tokens = inputs['input_ids']
|
316 |
+
B = tokens.shape[0]
|
317 |
+
is_sp_embed = torch.any(torch.stack([tokens == i for i in self.struct_token_ids], dim=-1),dim=-1)
|
318 |
+
|
319 |
+
tp_cover_range = torch.zeros_like(tokens)
|
320 |
+
for b, is_sp in enumerate(is_sp_embed):
|
321 |
+
sp_list = torch.where(is_sp)[0].tolist()
|
322 |
+
sp_list.append(mask[b].sum())
|
323 |
+
for i, st in enumerate(sp_list[:-1]):
|
324 |
+
tp_cover_range[b, st: sp_list[i+1]] = tokens[b, st] - 151645
|
325 |
+
|
326 |
+
if self.max_len is not None:
|
327 |
+
if inputs['input_ids'].shape[-1] > self.max_len:
|
328 |
+
warnings.warn(f"Max len limit ({self.max_len}) Exceed! \
|
329 |
+
{[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
|
330 |
+
tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device)
|
331 |
+
mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
|
332 |
+
tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device)
|
333 |
+
device = self.output_proj.weight.device
|
334 |
+
content_embeds = self.output_proj(tokens.to(device))
|
335 |
+
structure_embeds = self.structure_emb(tp_cover_range.to(device))
|
336 |
+
|
337 |
+
embeds = content_embeds + structure_embeds
|
338 |
+
return embeds, embeds, mask
|
339 |
+
|
340 |
+
def pad_2d_tensor(self, x, max_len, pad_id):
|
341 |
+
batch_size, seq_len = x.size()
|
342 |
+
pad_len = max_len - seq_len
|
343 |
+
|
344 |
+
if pad_len > 0:
|
345 |
+
pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
|
346 |
+
padded_tensor = torch.cat([x, pad_tensor], dim=1)
|
347 |
+
elif pad_len < 0:
|
348 |
+
padded_tensor = x[:, :max_len]
|
349 |
+
else:
|
350 |
+
padded_tensor = x
|
351 |
+
|
352 |
+
return padded_tensor
|
353 |
+
|
354 |
+
|
355 |
+
class QwTextConditioner(TextConditioner):
|
356 |
+
def __init__(self, output_dim: int,
|
357 |
+
token_path = "",
|
358 |
+
max_len = 300): #""
|
359 |
+
|
360 |
+
from transformers import Qwen2Tokenizer
|
361 |
+
self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
|
362 |
+
voc_size = len(self.text_tokenizer.get_vocab())
|
363 |
+
# here initialize a output_proj (nn.Embedding) layer
|
364 |
+
super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
|
365 |
+
|
366 |
+
self.max_len = max_len
|
367 |
+
|
368 |
+
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
|
369 |
+
x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x]
|
370 |
+
inputs = self.text_tokenizer(x, return_tensors="pt", padding=True)
|
371 |
+
return inputs
|
372 |
+
|
373 |
+
def forward(self, inputs: tp.Dict[str, torch.Tensor], structure_dur = None) -> ConditionType:
|
374 |
+
"""
|
375 |
+
Add structure embeddings of {verse, chorus, bridge} to text/lyric tokens that
|
376 |
+
belong to these structures accordingly,
|
377 |
+
Then delete or keep these structure embeddings.
|
378 |
+
"""
|
379 |
+
mask = inputs['attention_mask']
|
380 |
+
tokens = inputs['input_ids']
|
381 |
+
|
382 |
+
if self.max_len is not None:
|
383 |
+
if inputs['input_ids'].shape[-1] > self.max_len:
|
384 |
+
warnings.warn(f"Max len limit ({self.max_len}) Exceed! \
|
385 |
+
{[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!")
|
386 |
+
tokens = self.pad_2d_tensor(tokens, self.max_len, 151643).to(self.output_proj.weight.device)
|
387 |
+
mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device)
|
388 |
+
|
389 |
+
embeds = self.output_proj(tokens)
|
390 |
+
return embeds, embeds, mask
|
391 |
+
|
392 |
+
def pad_2d_tensor(self, x, max_len, pad_id):
|
393 |
+
batch_size, seq_len = x.size()
|
394 |
+
pad_len = max_len - seq_len
|
395 |
+
|
396 |
+
if pad_len > 0:
|
397 |
+
pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
|
398 |
+
padded_tensor = torch.cat([x, pad_tensor], dim=1)
|
399 |
+
elif pad_len < 0:
|
400 |
+
padded_tensor = x[:, :max_len]
|
401 |
+
else:
|
402 |
+
padded_tensor = x
|
403 |
+
|
404 |
+
return padded_tensor
|
405 |
+
|
406 |
+
|
407 |
+
class AudioConditioner(BaseConditioner):
|
408 |
+
...
|
409 |
+
|
410 |
+
class QuantizedEmbeddingConditioner(AudioConditioner):
|
411 |
+
def __init__(self, dim: int,
|
412 |
+
code_size: int,
|
413 |
+
code_depth: int,
|
414 |
+
max_len: int,
|
415 |
+
**kwargs):
|
416 |
+
super().__init__(dim, dim, input_token=True)
|
417 |
+
self.code_depth = code_depth
|
418 |
+
# add 1 for <s> token
|
419 |
+
self.emb = nn.ModuleList([nn.Embedding(code_size+2, dim, padding_idx=code_size+1) for _ in range(code_depth)])
|
420 |
+
# add End-Of-Text embedding
|
421 |
+
self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
|
422 |
+
self.layer2_EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True)
|
423 |
+
self.output_proj = None
|
424 |
+
self.max_len = max_len
|
425 |
+
self.vocab_size = code_size
|
426 |
+
|
427 |
+
def tokenize(self, x: AudioCondition) -> AudioCondition:
|
428 |
+
"""no extra ops"""
|
429 |
+
# wav, length, sample_rate, path, seek_time = x
|
430 |
+
# assert length is not None
|
431 |
+
return x #AudioCondition(wav, length, sample_rate, path, seek_time)
|
432 |
+
|
433 |
+
def forward(self, x: AudioCondition):
|
434 |
+
wav, lengths, *_ = x
|
435 |
+
B = wav.shape[0]
|
436 |
+
wav = wav.reshape(B, self.code_depth, -1).long()
|
437 |
+
if wav.shape[2] < self.max_len - 1:
|
438 |
+
wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size+1)
|
439 |
+
else:
|
440 |
+
wav = wav[:, :, :self.max_len-1]
|
441 |
+
embeds1 = self.emb[0](wav[:, 0])
|
442 |
+
embeds1 = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1),
|
443 |
+
embeds1), dim=1)
|
444 |
+
embeds2 = sum([self.emb[k](wav[:, k]) for k in range(1, self.code_depth)]) # B,T,D
|
445 |
+
embeds2 = torch.cat((self.layer2_EOT_emb.unsqueeze(0).repeat(B, 1, 1),
|
446 |
+
embeds2), dim=1)
|
447 |
+
lengths = lengths + 1
|
448 |
+
lengths = torch.clamp(lengths, max=self.max_len)
|
449 |
+
|
450 |
+
if lengths is not None:
|
451 |
+
mask = length_to_mask(lengths, max_len=embeds1.shape[1]).int() # type: ignore
|
452 |
+
else:
|
453 |
+
mask = torch.ones((B, self.code_depth), device=embeds1.device, dtype=torch.int)
|
454 |
+
return embeds1, embeds2, mask
|
455 |
+
|
456 |
+
|
457 |
+
# ================================================================
|
458 |
+
# Aggregate all conditions and corresponding conditioners
|
459 |
+
# ================================================================
|
460 |
+
class ConditionerProvider(nn.Module):
|
461 |
+
"""Prepare and provide conditions given all the supported conditioners.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
conditioners (dict): Dictionary of conditioners.
|
465 |
+
device (torch.device or str, optional): Device for conditioners and output condition types.
|
466 |
+
"""
|
467 |
+
def __init__(self, conditioners: tp.Dict[str, BaseConditioner]):
|
468 |
+
super().__init__()
|
469 |
+
self.conditioners = nn.ModuleDict(conditioners)
|
470 |
+
|
471 |
+
@property
|
472 |
+
def text_conditions(self):
|
473 |
+
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
|
474 |
+
|
475 |
+
@property
|
476 |
+
def audio_conditions(self):
|
477 |
+
return [k for k, v in self.conditioners.items() if isinstance(v, AudioConditioner)]
|
478 |
+
|
479 |
+
@property
|
480 |
+
def has_audio_condition(self):
|
481 |
+
return len(self.audio_conditions) > 0
|
482 |
+
|
483 |
+
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
|
484 |
+
"""Match attributes/audios with existing conditioners in self, and compute tokenize them accordingly.
|
485 |
+
This should be called before starting any real GPU work to avoid synchronization points.
|
486 |
+
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
|
490 |
+
text and audio conditions.
|
491 |
+
"""
|
492 |
+
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
|
493 |
+
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
|
494 |
+
f" but types were {set([type(x) for x in inputs])}")
|
495 |
+
|
496 |
+
output = {}
|
497 |
+
text = self._collate_text(inputs)
|
498 |
+
audios = self._collate_audios(inputs)
|
499 |
+
|
500 |
+
assert set(text.keys() | audios.keys()).issubset(set(self.conditioners.keys())), (
|
501 |
+
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
|
502 |
+
f"got {text.keys(), audios.keys()}")
|
503 |
+
|
504 |
+
for attribute, batch in chain(text.items(), audios.items()):
|
505 |
+
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
506 |
+
return output
|
507 |
+
|
508 |
+
def forward(self, tokenized: tp.Dict[str, tp.Any], structure_dur = None) -> tp.Dict[str, ConditionType]:
|
509 |
+
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
|
510 |
+
The output is for example:
|
511 |
+
{
|
512 |
+
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
|
513 |
+
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
|
514 |
+
...
|
515 |
+
}
|
516 |
+
|
517 |
+
Args:
|
518 |
+
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
|
519 |
+
"""
|
520 |
+
output = {}
|
521 |
+
for attribute, inputs in tokenized.items():
|
522 |
+
if attribute == 'description' and structure_dur is not None:
|
523 |
+
condition1, condition2, mask = self.conditioners[attribute](inputs, structure_dur = structure_dur)
|
524 |
+
else:
|
525 |
+
condition1, condition2, mask = self.conditioners[attribute](inputs)
|
526 |
+
output[attribute] = (condition1, condition2, mask)
|
527 |
+
return output
|
528 |
+
|
529 |
+
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
|
530 |
+
"""Given a list of ConditioningAttributes objects, compile a dictionary where the keys
|
531 |
+
are the attributes and the values are the aggregated input per attribute.
|
532 |
+
For example:
|
533 |
+
Input:
|
534 |
+
[
|
535 |
+
ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
|
536 |
+
ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, audio=...),
|
537 |
+
]
|
538 |
+
Output:
|
539 |
+
{
|
540 |
+
"genre": ["Rock", "Hip-hop"],
|
541 |
+
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
|
542 |
+
}
|
543 |
+
|
544 |
+
Args:
|
545 |
+
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
546 |
+
Returns:
|
547 |
+
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
|
548 |
+
"""
|
549 |
+
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
|
550 |
+
texts = [x.text for x in samples]
|
551 |
+
for text in texts:
|
552 |
+
for condition in self.text_conditions:
|
553 |
+
out[condition].append(text[condition])
|
554 |
+
return out
|
555 |
+
|
556 |
+
def _collate_audios(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, AudioCondition]:
|
557 |
+
"""Generate a dict where the keys are attributes by which we fetch similar audios,
|
558 |
+
and the values are Tensors of audios according to said attributes.
|
559 |
+
|
560 |
+
*Note*: by the time the samples reach this function, each sample should have some audios
|
561 |
+
inside the "audio" attribute. It should be either:
|
562 |
+
1. A real audio
|
563 |
+
2. A null audio due to the sample having no similar audios (nullified by the dataset)
|
564 |
+
3. A null audio due to it being dropped in a dropout module (nullified by dropout)
|
565 |
+
|
566 |
+
Args:
|
567 |
+
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
568 |
+
Returns:
|
569 |
+
dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
|
570 |
+
"""
|
571 |
+
# import pdb; pdb.set_trace()
|
572 |
+
wavs = defaultdict(list)
|
573 |
+
lengths = defaultdict(list)
|
574 |
+
sample_rates = defaultdict(list)
|
575 |
+
paths = defaultdict(list)
|
576 |
+
seek_times = defaultdict(list)
|
577 |
+
out: tp.Dict[str, AudioCondition] = {}
|
578 |
+
|
579 |
+
for sample in samples:
|
580 |
+
for attribute in self.audio_conditions:
|
581 |
+
wav, length, sample_rate, path, seek_time = sample.audio[attribute]
|
582 |
+
assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
|
583 |
+
assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
|
584 |
+
wavs[attribute].append(wav.flatten()) # [C*T]
|
585 |
+
lengths[attribute].append(length)
|
586 |
+
sample_rates[attribute].extend(sample_rate)
|
587 |
+
paths[attribute].extend(path)
|
588 |
+
seek_times[attribute].extend(seek_time)
|
589 |
+
|
590 |
+
# stack all wavs to a single tensor
|
591 |
+
for attribute in self.audio_conditions:
|
592 |
+
stacked_wav, _ = collate(wavs[attribute], dim=0)
|
593 |
+
out[attribute] = AudioCondition(
|
594 |
+
stacked_wav.unsqueeze(1),
|
595 |
+
torch.cat(lengths[attribute]), sample_rates[attribute],
|
596 |
+
paths[attribute], seek_times[attribute])
|
597 |
+
|
598 |
+
return out
|
599 |
+
|
600 |
+
|
601 |
+
class ConditionFuser(StreamingModule):
|
602 |
+
"""Condition fuser handles the logic to combine the different conditions
|
603 |
+
to the actual model input.
|
604 |
+
|
605 |
+
Args:
|
606 |
+
fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
|
607 |
+
each condition. For example:
|
608 |
+
{
|
609 |
+
"prepend": ["description"],
|
610 |
+
"sum": ["genre", "bpm"],
|
611 |
+
}
|
612 |
+
"""
|
613 |
+
FUSING_METHODS = ["sum", "prepend"] #, "cross", "input_interpolate"] (not support in this simplest version)
|
614 |
+
|
615 |
+
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]]):
|
616 |
+
super().__init__()
|
617 |
+
assert all([k in self.FUSING_METHODS for k in fuse2cond.keys()]
|
618 |
+
), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
|
619 |
+
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
|
620 |
+
self.cond2fuse: tp.Dict[str, str] = {}
|
621 |
+
for fuse_method, conditions in fuse2cond.items():
|
622 |
+
for condition in conditions:
|
623 |
+
self.cond2fuse[condition] = fuse_method
|
624 |
+
|
625 |
+
def forward(
|
626 |
+
self,
|
627 |
+
input1: torch.Tensor,
|
628 |
+
input2: torch.Tensor,
|
629 |
+
conditions: tp.Dict[str, ConditionType]
|
630 |
+
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
631 |
+
"""Fuse the conditions to the provided model input.
|
632 |
+
|
633 |
+
Args:
|
634 |
+
input (torch.Tensor): Transformer input.
|
635 |
+
conditions (dict[str, ConditionType]): Dict of conditions.
|
636 |
+
Returns:
|
637 |
+
tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
|
638 |
+
after the conditions have been fused. The second output tensor is the tensor
|
639 |
+
used for cross-attention or None if no cross attention inputs exist.
|
640 |
+
"""
|
641 |
+
#import pdb; pdb.set_trace()
|
642 |
+
B, T, _ = input1.shape
|
643 |
+
|
644 |
+
if 'offsets' in self._streaming_state:
|
645 |
+
first_step = False
|
646 |
+
offsets = self._streaming_state['offsets']
|
647 |
+
else:
|
648 |
+
first_step = True
|
649 |
+
offsets = torch.zeros(input1.shape[0], dtype=torch.long, device=input1.device)
|
650 |
+
|
651 |
+
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
|
652 |
+
f"given conditions contain unknown attributes for fuser, " \
|
653 |
+
f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
|
654 |
+
|
655 |
+
# if 'prepend' mode is used,
|
656 |
+
# the concatenation order will be the SAME with the conditions in config:
|
657 |
+
# prepend: ['description', 'prompt_audio'] (then goes the input)
|
658 |
+
fused_input_1 = input1
|
659 |
+
fused_input_2 = input2
|
660 |
+
for fuse_op in self.fuse2cond.keys():
|
661 |
+
fuse_op_conditions = self.fuse2cond[fuse_op]
|
662 |
+
if fuse_op == 'sum' and len(fuse_op_conditions) > 0:
|
663 |
+
for cond in fuse_op_conditions:
|
664 |
+
this_cond_1, this_cond_2, cond_mask = conditions[cond]
|
665 |
+
fused_input_1 += this_cond_1
|
666 |
+
fused_input_2 += this_cond_2
|
667 |
+
elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0:
|
668 |
+
if not first_step:
|
669 |
+
continue
|
670 |
+
reverse_list = deepcopy(fuse_op_conditions)
|
671 |
+
reverse_list.reverse()
|
672 |
+
for cond in reverse_list:
|
673 |
+
this_cond_1, this_cond_2, cond_mask = conditions[cond]
|
674 |
+
fused_input_1 = torch.cat((this_cond_1, fused_input_1), dim=1) # concat along T dim
|
675 |
+
fused_input_2 = torch.cat((this_cond_2, fused_input_2), dim=1) # concat along T dim
|
676 |
+
elif fuse_op not in self.FUSING_METHODS:
|
677 |
+
raise ValueError(f"unknown op ({fuse_op})")
|
678 |
+
|
679 |
+
if self._is_streaming:
|
680 |
+
self._streaming_state['offsets'] = offsets + T
|
681 |
+
|
682 |
+
return fused_input_1, fused_input_2
|
683 |
+
|
684 |
+
|
685 |
+
|
686 |
+
# ================================================================
|
687 |
+
# Condition Dropout
|
688 |
+
# ================================================================
|
689 |
+
class DropoutModule(nn.Module):
|
690 |
+
"""Base module for all dropout modules."""
|
691 |
+
def __init__(self, seed: int = 1234):
|
692 |
+
super().__init__()
|
693 |
+
self.rng = torch.Generator()
|
694 |
+
self.rng.manual_seed(seed)
|
695 |
+
|
696 |
+
|
697 |
+
|
698 |
+
class ClassifierFreeGuidanceDropout(DropoutModule):
|
699 |
+
"""Classifier Free Guidance dropout.
|
700 |
+
All attributes are dropped with the same probability.
|
701 |
+
|
702 |
+
Args:
|
703 |
+
p (float): Probability to apply condition dropout during training.
|
704 |
+
seed (int): Random seed.
|
705 |
+
"""
|
706 |
+
def __init__(self, p: float, seed: int = 1234):
|
707 |
+
super().__init__(seed=seed)
|
708 |
+
self.p = p
|
709 |
+
|
710 |
+
def check(self, sample, condition_type, condition):
|
711 |
+
|
712 |
+
if condition_type not in ['text', 'audio']:
|
713 |
+
raise ValueError("dropout_condition got an unexpected condition type!"
|
714 |
+
f" expected 'text', 'audio' but got '{condition_type}'")
|
715 |
+
|
716 |
+
if condition not in getattr(sample, condition_type):
|
717 |
+
raise ValueError(
|
718 |
+
"dropout_condition received an unexpected condition!"
|
719 |
+
f" expected audio={sample.audio.keys()} and text={sample.text.keys()}"
|
720 |
+
f" but got '{condition}' of type '{condition_type}'!")
|
721 |
+
|
722 |
+
|
723 |
+
def get_null_wav(self, wav, sr=48000) -> AudioCondition:
|
724 |
+
out = wav * 0 + 16385
|
725 |
+
return AudioCondition(
|
726 |
+
wav=out,
|
727 |
+
length=torch.Tensor([0]).long(),
|
728 |
+
sample_rate=[sr],)
|
729 |
+
|
730 |
+
def dropout_condition(self,
|
731 |
+
sample: ConditioningAttributes,
|
732 |
+
condition_type: str,
|
733 |
+
condition: str) -> ConditioningAttributes:
|
734 |
+
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
735 |
+
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
736 |
+
If the condition is of any other type, set its value to None.
|
737 |
+
Works in-place.
|
738 |
+
"""
|
739 |
+
self.check(sample, condition_type, condition)
|
740 |
+
|
741 |
+
if condition_type == 'audio':
|
742 |
+
audio_cond = sample.audio[condition]
|
743 |
+
depth = audio_cond.wav.shape[1]
|
744 |
+
sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
|
745 |
+
else:
|
746 |
+
sample.text[condition] = None
|
747 |
+
|
748 |
+
return sample
|
749 |
+
|
750 |
+
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
751 |
+
"""
|
752 |
+
Args:
|
753 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
754 |
+
Returns:
|
755 |
+
list[ConditioningAttributes]: List of conditions after all attributes were set to None.
|
756 |
+
"""
|
757 |
+
# decide on which attributes to drop in a batched fashion
|
758 |
+
# drop = torch.rand(1, generator=self.rng).item() < self.p
|
759 |
+
# if not drop:
|
760 |
+
# return samples
|
761 |
+
|
762 |
+
# nullify conditions of all attributes
|
763 |
+
samples = deepcopy(samples)
|
764 |
+
|
765 |
+
for sample in samples:
|
766 |
+
drop = torch.rand(1, generator=self.rng).item()
|
767 |
+
if drop<self.p:
|
768 |
+
for condition_type in ["audio", "text"]:
|
769 |
+
for condition in sample.attributes[condition_type]:
|
770 |
+
self.dropout_condition(sample, condition_type, condition)
|
771 |
+
return samples
|
772 |
+
|
773 |
+
def __repr__(self):
|
774 |
+
return f"ClassifierFreeGuidanceDropout(p={self.p})"
|
775 |
+
|
776 |
+
|
777 |
+
class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
|
778 |
+
"""Classifier Free Guidance dropout during inference.
|
779 |
+
All attributes are dropped with the same probability.
|
780 |
+
|
781 |
+
Args:
|
782 |
+
p (float): Probability to apply condition dropout during training.
|
783 |
+
seed (int): Random seed.
|
784 |
+
"""
|
785 |
+
def __init__(self, seed: int = 1234):
|
786 |
+
super().__init__(p=1, seed=seed)
|
787 |
+
|
788 |
+
def dropout_condition_customized(self,
|
789 |
+
sample: ConditioningAttributes,
|
790 |
+
condition_type: str,
|
791 |
+
condition: str,
|
792 |
+
customized: list = None) -> ConditioningAttributes:
|
793 |
+
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
794 |
+
If the condition is of type "audio", then nullify it using `nullify_condition` function.
|
795 |
+
If the condition is of any other type, set its value to None.
|
796 |
+
Works in-place.
|
797 |
+
"""
|
798 |
+
self.check(sample, condition_type, condition)
|
799 |
+
|
800 |
+
if condition_type == 'audio':
|
801 |
+
audio_cond = sample.audio[condition]
|
802 |
+
depth = audio_cond.wav.shape[1]
|
803 |
+
sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
|
804 |
+
else:
|
805 |
+
if customized is None:
|
806 |
+
sample.text[condition] = None
|
807 |
+
else:
|
808 |
+
text_cond = deepcopy(sample.text[condition])
|
809 |
+
if "structure" in customized:
|
810 |
+
for _s in ['[inst]', '[outro]', '[intro]', '[verse]', '[chorus]', '[bridge]']:
|
811 |
+
text_cond = text_cond.replace(_s, "")
|
812 |
+
text_cond = text_cond.replace(' , ', '')
|
813 |
+
text_cond = text_cond.replace(" ", " ")
|
814 |
+
if '.' in customized:
|
815 |
+
text_cond = text_cond.replace(" . ", " ")
|
816 |
+
text_cond = text_cond.replace(".", " ")
|
817 |
+
|
818 |
+
sample.text[condition] = text_cond
|
819 |
+
|
820 |
+
return sample
|
821 |
+
|
822 |
+
def forward(self, samples: tp.List[ConditioningAttributes],
|
823 |
+
condition_types=["wav", "text"],
|
824 |
+
customized=None,
|
825 |
+
) -> tp.List[ConditioningAttributes]:
|
826 |
+
"""
|
827 |
+
100% dropout some condition attributes (description, prompt_wav) or types (text, wav) of
|
828 |
+
samples during inference.
|
829 |
+
|
830 |
+
Args:
|
831 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
832 |
+
Returns:
|
833 |
+
list[ConditioningAttributes]: List of conditions after all attributes were set to None.
|
834 |
+
"""
|
835 |
+
new_samples = deepcopy(samples)
|
836 |
+
for condition_type in condition_types:
|
837 |
+
for sample in new_samples:
|
838 |
+
for condition in sample.attributes[condition_type]:
|
839 |
+
self.dropout_condition_customized(sample, condition_type, condition, customized)
|
840 |
+
return new_samples
|
841 |
+
|
842 |
+
class AttributeDropout(ClassifierFreeGuidanceDropout):
|
843 |
+
"""Dropout with a given probability per attribute.
|
844 |
+
This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
|
845 |
+
to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
|
846 |
+
This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
|
847 |
+
must also be dropped.
|
848 |
+
|
849 |
+
Args:
|
850 |
+
p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
|
851 |
+
...
|
852 |
+
"genre": 0.1,
|
853 |
+
"artist": 0.5,
|
854 |
+
"audio": 0.25,
|
855 |
+
...
|
856 |
+
active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
|
857 |
+
seed (int, optional): Random seed.
|
858 |
+
"""
|
859 |
+
def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
|
860 |
+
super().__init__(p=p, seed=seed)
|
861 |
+
self.active_on_eval = active_on_eval
|
862 |
+
# construct dict that return the values from p otherwise 0
|
863 |
+
self.p = {}
|
864 |
+
for condition_type, probs in p.items():
|
865 |
+
self.p[condition_type] = defaultdict(lambda: 0, probs)
|
866 |
+
|
867 |
+
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
868 |
+
"""
|
869 |
+
Args:
|
870 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
871 |
+
Returns:
|
872 |
+
list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
|
873 |
+
"""
|
874 |
+
if not self.training and not self.active_on_eval:
|
875 |
+
return samples
|
876 |
+
|
877 |
+
samples = deepcopy(samples)
|
878 |
+
for condition_type, ps in self.p.items(): # for condition types [text, wav]
|
879 |
+
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
880 |
+
if torch.rand(1, generator=self.rng).item() < p:
|
881 |
+
for sample in samples:
|
882 |
+
self.dropout_condition(sample, condition_type, condition)
|
883 |
+
return samples
|
codeclm/modules/pattern.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from functools import lru_cache
|
4 |
+
import logging
|
5 |
+
import typing as tp
|
6 |
+
|
7 |
+
from abc import ABC, abstractmethod
|
8 |
+
import torch
|
9 |
+
|
10 |
+
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
|
11 |
+
PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class Pattern:
|
17 |
+
"""Base implementation of a pattern over a sequence with multiple codebooks.
|
18 |
+
|
19 |
+
The codebook pattern consists in a layout, defining for each sequence step
|
20 |
+
the list of coordinates of each codebook timestep in the resulting interleaved sequence.
|
21 |
+
The first item of the pattern is always an empty list in order to properly insert a special token
|
22 |
+
to start with. For convenience, we also keep track of ``code_depth`` the number of codebooks used for the pattern
|
23 |
+
and ``timesteps`` the number of timesteps corresponding to the original sequence.
|
24 |
+
|
25 |
+
The pattern provides convenient methods to build and revert interleaved sequences from it:
|
26 |
+
``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
|
27 |
+
to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
|
28 |
+
K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
|
29 |
+
for the output sequence. The unfilled positions are replaced with a special token and the built sequence
|
30 |
+
is returned along with a mask indicating valid tokens.
|
31 |
+
``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
|
32 |
+
of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
|
33 |
+
to fill and specify invalid positions if needed.
|
34 |
+
See the dedicated methods for more details.
|
35 |
+
"""
|
36 |
+
# Pattern layout, for each sequence step, we have a list of coordinates
|
37 |
+
# corresponding to the original codebook timestep and position.
|
38 |
+
# The first list is always an empty list in order to properly insert
|
39 |
+
# a special token to start with.
|
40 |
+
layout: PatternLayout
|
41 |
+
timesteps: int
|
42 |
+
code_depth: int
|
43 |
+
|
44 |
+
def __post_init__(self):
|
45 |
+
assert len(self.layout) > 0
|
46 |
+
assert self.layout[0] == []
|
47 |
+
self._validate_layout()
|
48 |
+
self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
|
49 |
+
self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
|
50 |
+
logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
|
51 |
+
|
52 |
+
def _validate_layout(self):
|
53 |
+
"""Runs checks on the layout to ensure a valid pattern is defined.
|
54 |
+
A pattern is considered invalid if:
|
55 |
+
- Multiple timesteps for a same codebook are defined in the same sequence step
|
56 |
+
- The timesteps for a given codebook are not in ascending order as we advance in the sequence
|
57 |
+
(this would mean that we have future timesteps before past timesteps).
|
58 |
+
"""
|
59 |
+
q_timesteps = {q: 0 for q in range(self.code_depth)}
|
60 |
+
for s, seq_coords in enumerate(self.layout):
|
61 |
+
if len(seq_coords) > 0:
|
62 |
+
qs = set()
|
63 |
+
for coord in seq_coords:
|
64 |
+
qs.add(coord.q)
|
65 |
+
last_q_timestep = q_timesteps[coord.q]
|
66 |
+
# assert coord.t >= last_q_timestep, \
|
67 |
+
# f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
|
68 |
+
q_timesteps[coord.q] = coord.t
|
69 |
+
# each sequence step contains at max 1 coordinate per codebook
|
70 |
+
assert len(qs) == len(seq_coords), \
|
71 |
+
f"Multiple entries for a same codebook are found at step {s}"
|
72 |
+
|
73 |
+
@property
|
74 |
+
def num_sequence_steps(self):
|
75 |
+
return len(self.layout) - 1
|
76 |
+
|
77 |
+
@property
|
78 |
+
def max_delay(self):
|
79 |
+
max_t_in_seq_coords = 0
|
80 |
+
for seq_coords in self.layout[1:]:
|
81 |
+
for coords in seq_coords:
|
82 |
+
max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
|
83 |
+
return max_t_in_seq_coords - self.timesteps
|
84 |
+
|
85 |
+
@property
|
86 |
+
def valid_layout(self):
|
87 |
+
valid_step = len(self.layout) - self.max_delay
|
88 |
+
return self.layout[:valid_step]
|
89 |
+
|
90 |
+
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
|
91 |
+
"""Get codebook coordinates in the layout that corresponds to the specified timestep t
|
92 |
+
and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
|
93 |
+
and the actual codebook coordinates.
|
94 |
+
"""
|
95 |
+
assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
|
96 |
+
if q is not None:
|
97 |
+
assert q <= self.code_depth, "provided number of codebooks is greater than the pattern's number of codebooks"
|
98 |
+
coords = []
|
99 |
+
for s, seq_codes in enumerate(self.layout):
|
100 |
+
for code in seq_codes:
|
101 |
+
if code.t == t and (q is None or code.q == q):
|
102 |
+
coords.append((s, code))
|
103 |
+
return coords
|
104 |
+
|
105 |
+
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
|
106 |
+
return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
|
107 |
+
|
108 |
+
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
|
109 |
+
steps_with_timesteps = self.get_steps_with_timestep(t, q)
|
110 |
+
return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
|
111 |
+
|
112 |
+
def _build_pattern_sequence_scatter_indexes(self, timesteps: int,
|
113 |
+
code_depth: int,
|
114 |
+
keep_only_valid_steps: bool,
|
115 |
+
device: tp.Union[torch.device, str] = 'cpu'):
|
116 |
+
"""Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
timesteps (int): Maximum number of timesteps steps to consider.
|
120 |
+
keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
|
121 |
+
device (torch.device or str): Device for created tensors.
|
122 |
+
Returns:
|
123 |
+
indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
|
124 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
|
125 |
+
"""
|
126 |
+
assert code_depth == self.code_depth, f"invalid number of codebooks for the sequence and the pattern: {code_depth} != {self.code_depth}"
|
127 |
+
assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
|
128 |
+
# use the proper layout based on whether we limit ourselves to valid steps only or not,
|
129 |
+
# note that using the valid_layout will result in a truncated sequence up to the valid steps
|
130 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
131 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
132 |
+
indexes = torch.zeros(code_depth, len(ref_layout), dtype=torch.long).numpy()
|
133 |
+
mask = torch.zeros(code_depth, len(ref_layout), dtype=torch.bool).numpy()
|
134 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
135 |
+
# the last value is code_depth * timesteps as we have flattened z and append special token as the last token
|
136 |
+
# which will correspond to the index: code_depth * timesteps
|
137 |
+
indexes[:] = code_depth * timesteps
|
138 |
+
# iterate over the pattern and fill scattered indexes and mask
|
139 |
+
for s, sequence_coords in enumerate(ref_layout):
|
140 |
+
for coords in sequence_coords:
|
141 |
+
if coords.t < timesteps:
|
142 |
+
indexes[coords.q, s] = coords.t + coords.q * timesteps
|
143 |
+
mask[coords.q, s] = 1
|
144 |
+
indexes = torch.from_numpy(indexes).to(device)
|
145 |
+
mask = torch.from_numpy(mask).to(device)
|
146 |
+
return indexes, mask
|
147 |
+
|
148 |
+
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
149 |
+
"""Build sequence corresponding to the pattern from the input tensor z.
|
150 |
+
The sequence is built using up to sequence_steps if specified, and non-pattern
|
151 |
+
coordinates are filled with the special token.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
|
155 |
+
special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
|
156 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
157 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
158 |
+
Returns:
|
159 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
|
160 |
+
corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
|
161 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
|
162 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
|
163 |
+
"""
|
164 |
+
B, K, T = z.shape
|
165 |
+
indexes, mask = self._build_pattern_sequence_scatter_indexes(
|
166 |
+
T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
|
167 |
+
)
|
168 |
+
z = z.reshape(B, -1)
|
169 |
+
# we append the special token as the last index of our flattened z tensor
|
170 |
+
z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
|
171 |
+
values = z[:, indexes.view(-1)]
|
172 |
+
values = values.view(B, K, indexes.shape[-1])
|
173 |
+
# import pdb; pdb.set_trace()
|
174 |
+
return values, indexes, mask
|
175 |
+
|
176 |
+
def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, code_depth: int,
|
177 |
+
keep_only_valid_steps: bool = False,
|
178 |
+
is_model_output: bool = False,
|
179 |
+
device: tp.Union[torch.device, str] = 'cpu'):
|
180 |
+
"""Builds scatter indexes required to retrieve the original multi-codebook sequence
|
181 |
+
from interleaving pattern.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
sequence_steps (int): Sequence steps.
|
185 |
+
code_depth (int): Number of codebooks.
|
186 |
+
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
|
187 |
+
Steps that are beyond valid steps will be replaced by the special_token in that case.
|
188 |
+
is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
|
189 |
+
device (torch.device or str): Device for created tensors.
|
190 |
+
Returns:
|
191 |
+
indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
|
192 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
193 |
+
"""
|
194 |
+
ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
|
195 |
+
timesteps = self.timesteps
|
196 |
+
assert code_depth == self.code_depth, f"invalid number of codebooks for the sequence and the pattern: {code_depth} != {self.code_depth}"
|
197 |
+
assert sequence_steps <= len(ref_layout), \
|
198 |
+
f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
|
199 |
+
|
200 |
+
# ensure we take the appropriate indexes to keep the model output from the first special token as well
|
201 |
+
if is_model_output:
|
202 |
+
ref_layout = ref_layout[1:]
|
203 |
+
|
204 |
+
# single item indexing being super slow with pytorch vs. numpy, so we use numpy here
|
205 |
+
indexes = torch.zeros(code_depth, timesteps, dtype=torch.long).numpy()
|
206 |
+
mask = torch.zeros(code_depth, timesteps, dtype=torch.bool).numpy()
|
207 |
+
# fill indexes with last sequence step value that will correspond to our special token
|
208 |
+
indexes[:] = code_depth * sequence_steps
|
209 |
+
for s, sequence_codes in enumerate(ref_layout):
|
210 |
+
if s < sequence_steps:
|
211 |
+
for code in sequence_codes:
|
212 |
+
if code.t < timesteps:
|
213 |
+
indexes[code.q, code.t] = s + code.q * sequence_steps
|
214 |
+
mask[code.q, code.t] = 1
|
215 |
+
indexes = torch.from_numpy(indexes).to(device)
|
216 |
+
mask = torch.from_numpy(mask).to(device)
|
217 |
+
return indexes, mask
|
218 |
+
|
219 |
+
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
|
220 |
+
"""Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
|
221 |
+
The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
|
222 |
+
are filled with the special token.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
|
226 |
+
special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
|
227 |
+
Returns:
|
228 |
+
values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
|
229 |
+
corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
|
230 |
+
indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
|
231 |
+
mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
|
232 |
+
"""
|
233 |
+
B, K, S = s.shape
|
234 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
235 |
+
S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
|
236 |
+
)
|
237 |
+
s = s.view(B, -1)
|
238 |
+
# we append the special token as the last index of our flattened z tensor
|
239 |
+
s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
|
240 |
+
values = s[:, indexes.view(-1)]
|
241 |
+
values = values.view(B, K, indexes.shape[-1])
|
242 |
+
return values, indexes, mask
|
243 |
+
|
244 |
+
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
|
245 |
+
"""Revert model logits obtained on a sequence built from the pattern
|
246 |
+
back to a tensor matching the original sequence.
|
247 |
+
|
248 |
+
This method is similar to ``revert_pattern_sequence`` with the following specificities:
|
249 |
+
1. It is designed to work with the extra cardinality dimension
|
250 |
+
2. We return the logits for the first sequence item that matches the special_token and
|
251 |
+
which matching target in the original sequence is the first item of the sequence,
|
252 |
+
while we skip the last logits as there is no matching target
|
253 |
+
"""
|
254 |
+
B, card, K, S = logits.shape
|
255 |
+
indexes, mask = self._build_reverted_sequence_scatter_indexes(
|
256 |
+
S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
|
257 |
+
)
|
258 |
+
logits = logits.reshape(B, card, -1)
|
259 |
+
# we append the special token as the last index of our flattened z tensor
|
260 |
+
logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
|
261 |
+
values = logits[:, :, indexes.view(-1)]
|
262 |
+
|
263 |
+
values = values.view(B, card, K, indexes.shape[-1])
|
264 |
+
return values, indexes, mask
|
265 |
+
|
266 |
+
|
267 |
+
|
268 |
+
class CodebooksPatternProvider(ABC):
|
269 |
+
"""Abstraction around providing pattern for interleaving codebooks.
|
270 |
+
|
271 |
+
The CodebooksPatternProvider abstraction allows to implement various strategies to
|
272 |
+
define interleaving pattern of sequences composed of multiple codebooks. For a given
|
273 |
+
number of codebooks `code_depth`, the pattern provider can generate a specified pattern
|
274 |
+
corresponding to a sequence of `T` timesteps with `code_depth` parallel codebooks. This pattern
|
275 |
+
can be used to construct a new sequence from the original codes respecting the specified
|
276 |
+
pattern. The pattern is defined as a list of list of code coordinates, code coordinate
|
277 |
+
being a tuple with the original timestep and codebook to build the new sequence.
|
278 |
+
Note that all patterns must start with an empty list that is then used to insert a first
|
279 |
+
sequence step of special tokens in the newly generated sequence.
|
280 |
+
|
281 |
+
Args:
|
282 |
+
code_depth (int): number of codebooks.
|
283 |
+
cached (bool): if True, patterns for a given length are cached. In general
|
284 |
+
that should be true for efficiency reason to avoid synchronization points.
|
285 |
+
"""
|
286 |
+
def __init__(self, code_depth: int, cached: bool = True):
|
287 |
+
assert code_depth > 0
|
288 |
+
self.code_depth = code_depth
|
289 |
+
self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
|
290 |
+
|
291 |
+
@abstractmethod
|
292 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
293 |
+
"""Builds pattern with specific interleaving between codebooks.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
timesteps (int): Total number of timesteps.
|
297 |
+
"""
|
298 |
+
raise NotImplementedError()
|
299 |
+
|
300 |
+
|
301 |
+
class DelayedPatternProvider(CodebooksPatternProvider):
|
302 |
+
"""Provider for delayed pattern across delayed codebooks.
|
303 |
+
Codebooks are delayed in the sequence and sequence steps will contain codebooks
|
304 |
+
from different timesteps.
|
305 |
+
|
306 |
+
Example:
|
307 |
+
Taking timesteps=4 and code_depth=3, delays=None, the multi-codebook sequence:
|
308 |
+
[[1, 2, 3, 4],
|
309 |
+
[1, 2, 3, 4],
|
310 |
+
[1, 2, 3, 4]]
|
311 |
+
The resulting sequence obtained from the returned pattern is:
|
312 |
+
[[S, 1, 2, 3, 4],
|
313 |
+
[S, S, 1, 2, 3],
|
314 |
+
[S, S, S, 1, 2]]
|
315 |
+
(with S being a special token)
|
316 |
+
|
317 |
+
Args:
|
318 |
+
code_depth (int): Number of codebooks.
|
319 |
+
delays (list of int, optional): Delay for each of the codebooks.
|
320 |
+
If delays not defined, each codebook is delayed by 1 compared to the previous one.
|
321 |
+
flatten_first (int): Flatten the first N timesteps.
|
322 |
+
empty_initial (int): Prepend with N empty list of coordinates.
|
323 |
+
"""
|
324 |
+
def __init__(self, code_depth: int, delays: tp.Optional[tp.List[int]] = None,
|
325 |
+
flatten_first: int = 0, empty_initial: int = 0):
|
326 |
+
super().__init__(code_depth)
|
327 |
+
if delays is None:
|
328 |
+
delays = list(range(code_depth))
|
329 |
+
self.delays = delays
|
330 |
+
self.flatten_first = flatten_first
|
331 |
+
self.empty_initial = empty_initial
|
332 |
+
assert len(self.delays) == self.code_depth
|
333 |
+
assert sorted(self.delays) == self.delays
|
334 |
+
|
335 |
+
def get_pattern(self, timesteps: int) -> Pattern:
|
336 |
+
out: PatternLayout = [[]]
|
337 |
+
max_delay = max(self.delays)
|
338 |
+
if self.empty_initial:
|
339 |
+
out += [[] for _ in range(self.empty_initial)]
|
340 |
+
if self.flatten_first:
|
341 |
+
for t in range(min(timesteps, self.flatten_first)):
|
342 |
+
for q in range(self.code_depth):
|
343 |
+
out.append([LayoutCoord(t, q)])
|
344 |
+
for t in range(self.flatten_first, timesteps + max_delay):
|
345 |
+
v = []
|
346 |
+
for q, delay in enumerate(self.delays):
|
347 |
+
t_for_q = t - delay
|
348 |
+
if t_for_q >= self.flatten_first:
|
349 |
+
v.append(LayoutCoord(t_for_q, q))
|
350 |
+
out.append(v)
|
351 |
+
return Pattern(out, code_depth=self.code_depth, timesteps=timesteps)
|
codeclm/modules/streaming.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Streaming module API that should be implemented by all Streaming components,
|
3 |
+
"""
|
4 |
+
|
5 |
+
from contextlib import contextmanager
|
6 |
+
import typing as tp
|
7 |
+
from torch import nn
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
State = tp.Dict[str, torch.Tensor]
|
12 |
+
|
13 |
+
class StreamingModule(nn.Module):
|
14 |
+
"""Common API for streaming components.
|
15 |
+
|
16 |
+
Each streaming component has a streaming state, which is just a dict[str, Tensor].
|
17 |
+
By convention, the first dim of each tensor must be the batch size.
|
18 |
+
Don't use dots in the key names, as this would clash with submodules
|
19 |
+
(like in state_dict).
|
20 |
+
|
21 |
+
If `self._is_streaming` is True, the component should use and remember
|
22 |
+
the proper state inside `self._streaming_state`.
|
23 |
+
|
24 |
+
To set a streaming component in streaming state, use
|
25 |
+
|
26 |
+
with module.streaming():
|
27 |
+
...
|
28 |
+
|
29 |
+
This will automatically reset the streaming state when exiting the context manager.
|
30 |
+
This also automatically propagates to all streaming children module.
|
31 |
+
|
32 |
+
Some module might also implement the `StreamingModule.flush` method, although
|
33 |
+
this one is trickier, as all parents module must be StreamingModule and implement
|
34 |
+
it as well for it to work properly. See `StreamingSequential` after.
|
35 |
+
"""
|
36 |
+
def __init__(self) -> None:
|
37 |
+
super().__init__()
|
38 |
+
self._streaming_state: State = {}
|
39 |
+
self._is_streaming = False
|
40 |
+
|
41 |
+
def _apply_named_streaming(self, fn: tp.Any):
|
42 |
+
for name, module in self.named_modules():
|
43 |
+
if isinstance(module, StreamingModule):
|
44 |
+
fn(name, module)
|
45 |
+
|
46 |
+
def _set_streaming(self, streaming: bool):
|
47 |
+
def _set_streaming(name, module):
|
48 |
+
module._is_streaming = streaming
|
49 |
+
self._apply_named_streaming(_set_streaming)
|
50 |
+
|
51 |
+
@contextmanager
|
52 |
+
def streaming(self):
|
53 |
+
"""Context manager to enter streaming mode. Reset streaming state on exit."""
|
54 |
+
self._set_streaming(True)
|
55 |
+
try:
|
56 |
+
yield
|
57 |
+
finally:
|
58 |
+
self._set_streaming(False)
|
59 |
+
self.reset_streaming()
|
60 |
+
|
61 |
+
def reset_streaming(self):
|
62 |
+
"""Reset the streaming state."""
|
63 |
+
def _reset(name: str, module: StreamingModule):
|
64 |
+
module._streaming_state.clear()
|
65 |
+
|
66 |
+
self._apply_named_streaming(_reset)
|
67 |
+
|
68 |
+
def get_streaming_state(self) -> State:
|
69 |
+
"""Return the streaming state, including that of sub-modules."""
|
70 |
+
state: State = {}
|
71 |
+
|
72 |
+
def _add(name: str, module: StreamingModule):
|
73 |
+
if name:
|
74 |
+
name += "."
|
75 |
+
for key, value in module._streaming_state.items():
|
76 |
+
state[name + key] = value
|
77 |
+
|
78 |
+
self._apply_named_streaming(_add)
|
79 |
+
return state
|
80 |
+
|
81 |
+
def set_streaming_state(self, state: State):
|
82 |
+
"""Set the streaming state, including that of sub-modules."""
|
83 |
+
state = dict(state)
|
84 |
+
|
85 |
+
def _set(name: str, module: StreamingModule):
|
86 |
+
if name:
|
87 |
+
name += "."
|
88 |
+
module._streaming_state.clear()
|
89 |
+
for key, value in list(state.items()):
|
90 |
+
# complexity is not ideal here, but probably fine.
|
91 |
+
if key.startswith(name):
|
92 |
+
local_key = key[len(name):]
|
93 |
+
if '.' not in local_key:
|
94 |
+
module._streaming_state[local_key] = value
|
95 |
+
del state[key]
|
96 |
+
|
97 |
+
self._apply_named_streaming(_set)
|
98 |
+
assert len(state) == 0, list(state.keys())
|
99 |
+
|
100 |
+
def flush(self, x: tp.Optional[torch.Tensor] = None):
|
101 |
+
"""Flush any remaining outputs that were waiting for completion.
|
102 |
+
Typically, for convolutions, this will add the final padding
|
103 |
+
and process the last buffer.
|
104 |
+
|
105 |
+
This should take an optional argument `x`, which will be provided
|
106 |
+
if a module before this one in the streaming pipeline has already
|
107 |
+
spitted out a flushed out buffer.
|
108 |
+
"""
|
109 |
+
if x is None:
|
110 |
+
return None
|
111 |
+
else:
|
112 |
+
return self(x)
|
codeclm/tokenizer/Flow1dVAE/audio.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@File : audio.py
|
5 |
+
@Time : 2023/8/8 下午7:18
|
6 |
+
@Author : waytan
|
7 |
+
@Contact : [email protected]
|
8 |
+
@License : (C)Copyright 2023, Tencent
|
9 |
+
@Desc : Audio
|
10 |
+
"""
|
11 |
+
import json
|
12 |
+
import subprocess as sp
|
13 |
+
import typing as tp
|
14 |
+
from pathlib import Path
|
15 |
+
|
16 |
+
import lameenc
|
17 |
+
import julius
|
18 |
+
import torch
|
19 |
+
import numpy as np
|
20 |
+
import torchaudio as ta
|
21 |
+
from contextlib import contextmanager
|
22 |
+
import tempfile
|
23 |
+
import os
|
24 |
+
|
25 |
+
@contextmanager
|
26 |
+
def temp_filenames(count: int, delete=True):
|
27 |
+
names = []
|
28 |
+
try:
|
29 |
+
for _ in range(count):
|
30 |
+
names.append(tempfile.NamedTemporaryFile(delete=False).name)
|
31 |
+
yield names
|
32 |
+
finally:
|
33 |
+
if delete:
|
34 |
+
for name in names:
|
35 |
+
os.unlink(name)
|
36 |
+
|
37 |
+
|
38 |
+
def _read_info(path):
|
39 |
+
stdout_data = sp.check_output([
|
40 |
+
'ffprobe', "-loglevel", "panic",
|
41 |
+
str(path), '-print_format', 'json', '-show_format', '-show_streams'
|
42 |
+
])
|
43 |
+
return json.loads(stdout_data.decode('utf-8'))
|
44 |
+
|
45 |
+
|
46 |
+
class AudioFile:
|
47 |
+
"""
|
48 |
+
Allows to read audio from any format supported by ffmpeg, as well as resampling or
|
49 |
+
converting to mono on the fly. See :method:`read` for more details.
|
50 |
+
"""
|
51 |
+
def __init__(self, path: Path):
|
52 |
+
self.path = Path(path)
|
53 |
+
self._info = None
|
54 |
+
|
55 |
+
def __repr__(self):
|
56 |
+
features = [("path", self.path)]
|
57 |
+
features.append(("samplerate", self.samplerate()))
|
58 |
+
features.append(("channels", self.channels()))
|
59 |
+
features.append(("streams", len(self)))
|
60 |
+
features_str = ", ".join(f"{name}={value}" for name, value in features)
|
61 |
+
return f"AudioFile({features_str})"
|
62 |
+
|
63 |
+
@property
|
64 |
+
def info(self):
|
65 |
+
if self._info is None:
|
66 |
+
self._info = _read_info(self.path)
|
67 |
+
return self._info
|
68 |
+
|
69 |
+
@property
|
70 |
+
def duration(self):
|
71 |
+
return float(self.info['format']['duration'])
|
72 |
+
|
73 |
+
@property
|
74 |
+
def _audio_streams(self):
|
75 |
+
return [
|
76 |
+
index for index, stream in enumerate(self.info["streams"])
|
77 |
+
if stream["codec_type"] == "audio"
|
78 |
+
]
|
79 |
+
|
80 |
+
def __len__(self):
|
81 |
+
return len(self._audio_streams)
|
82 |
+
|
83 |
+
def channels(self, stream=0):
|
84 |
+
return int(self.info['streams'][self._audio_streams[stream]]['channels'])
|
85 |
+
|
86 |
+
def samplerate(self, stream=0):
|
87 |
+
return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
|
88 |
+
|
89 |
+
def read(self,
|
90 |
+
seek_time=None,
|
91 |
+
duration=None,
|
92 |
+
streams=slice(None),
|
93 |
+
samplerate=None,
|
94 |
+
channels=None):
|
95 |
+
"""
|
96 |
+
Slightly more efficient implementation than stempeg,
|
97 |
+
in particular, this will extract all stems at once
|
98 |
+
rather than having to loop over one file multiple times
|
99 |
+
for each stream.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
seek_time (float): seek time in seconds or None if no seeking is needed.
|
103 |
+
duration (float): duration in seconds to extract or None to extract until the end.
|
104 |
+
streams (slice, int or list): streams to extract, can be a single int, a list or
|
105 |
+
a slice. If it is a slice or list, the output will be of size [S, C, T]
|
106 |
+
with S the number of streams, C the number of channels and T the number of samples.
|
107 |
+
If it is an int, the output will be [C, T].
|
108 |
+
samplerate (int): if provided, will resample on the fly. If None, no resampling will
|
109 |
+
be done. Original sampling rate can be obtained with :method:`samplerate`.
|
110 |
+
channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
|
111 |
+
as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
|
112 |
+
See https://sound.stackexchange.com/a/42710.
|
113 |
+
Our definition of mono is simply the average of the two channels. Any other
|
114 |
+
value will be ignored.
|
115 |
+
"""
|
116 |
+
streams = np.array(range(len(self)))[streams]
|
117 |
+
single = not isinstance(streams, np.ndarray)
|
118 |
+
if single:
|
119 |
+
streams = [streams]
|
120 |
+
|
121 |
+
if duration is None:
|
122 |
+
target_size = None
|
123 |
+
query_duration = None
|
124 |
+
else:
|
125 |
+
target_size = int((samplerate or self.samplerate()) * duration)
|
126 |
+
query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
|
127 |
+
|
128 |
+
with temp_filenames(len(streams)) as filenames:
|
129 |
+
command = ['ffmpeg', '-y']
|
130 |
+
command += ['-loglevel', 'panic']
|
131 |
+
if seek_time:
|
132 |
+
command += ['-ss', str(seek_time)]
|
133 |
+
command += ['-i', str(self.path)]
|
134 |
+
for stream, filename in zip(streams, filenames):
|
135 |
+
command += ['-map', f'0:{self._audio_streams[stream]}']
|
136 |
+
if query_duration is not None:
|
137 |
+
command += ['-t', str(query_duration)]
|
138 |
+
command += ['-threads', '1']
|
139 |
+
command += ['-f', 'f32le']
|
140 |
+
if samplerate is not None:
|
141 |
+
command += ['-ar', str(samplerate)]
|
142 |
+
command += [filename]
|
143 |
+
|
144 |
+
sp.run(command, check=True)
|
145 |
+
wavs = []
|
146 |
+
for filename in filenames:
|
147 |
+
wav = np.fromfile(filename, dtype=np.float32)
|
148 |
+
wav = torch.from_numpy(wav)
|
149 |
+
wav = wav.view(-1, self.channels()).t()
|
150 |
+
if channels is not None:
|
151 |
+
wav = convert_audio_channels(wav, channels)
|
152 |
+
if target_size is not None:
|
153 |
+
wav = wav[..., :target_size]
|
154 |
+
wavs.append(wav)
|
155 |
+
wav = torch.stack(wavs, dim=0)
|
156 |
+
if single:
|
157 |
+
wav = wav[0]
|
158 |
+
return wav
|
159 |
+
|
160 |
+
|
161 |
+
def convert_audio_channels(wav, channels=2):
|
162 |
+
"""Convert audio to the given number of channels."""
|
163 |
+
*shape, src_channels, length = wav.shape
|
164 |
+
if src_channels == channels:
|
165 |
+
pass
|
166 |
+
elif channels == 1:
|
167 |
+
# Case 1:
|
168 |
+
# The caller asked 1-channel audio, but the stream have multiple
|
169 |
+
# channels, downmix all channels.
|
170 |
+
wav = wav.mean(dim=-2, keepdim=True)
|
171 |
+
elif src_channels == 1:
|
172 |
+
# Case 2:
|
173 |
+
# The caller asked for multiple channels, but the input file have
|
174 |
+
# one single channel, replicate the audio over all channels.
|
175 |
+
wav = wav.expand(*shape, channels, length)
|
176 |
+
elif src_channels >= channels:
|
177 |
+
# Case 3:
|
178 |
+
# The caller asked for multiple channels, and the input file have
|
179 |
+
# more channels than requested. In that case return the first channels.
|
180 |
+
wav = wav[..., :channels, :]
|
181 |
+
else:
|
182 |
+
# Case 4: What is a reasonable choice here?
|
183 |
+
raise ValueError('The audio file has less channels than requested but is not mono.')
|
184 |
+
return wav
|
185 |
+
|
186 |
+
|
187 |
+
def convert_audio(wav, from_samplerate, to_samplerate, channels):
|
188 |
+
"""Convert audio from a given samplerate to a target one and target number of channels."""
|
189 |
+
wav = convert_audio_channels(wav, channels)
|
190 |
+
return julius.resample_frac(wav, from_samplerate, to_samplerate)
|
191 |
+
|
192 |
+
|
193 |
+
def i16_pcm(wav):
|
194 |
+
"""Convert audio to 16 bits integer PCM format."""
|
195 |
+
if wav.dtype.is_floating_point:
|
196 |
+
return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
|
197 |
+
else:
|
198 |
+
return wav
|
199 |
+
|
200 |
+
|
201 |
+
def f32_pcm(wav):
|
202 |
+
"""Convert audio to float 32 bits PCM format."""
|
203 |
+
if wav.dtype.is_floating_point:
|
204 |
+
return wav
|
205 |
+
else:
|
206 |
+
return wav.float() / (2**15 - 1)
|
207 |
+
|
208 |
+
|
209 |
+
def as_dtype_pcm(wav):
|
210 |
+
"""Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
|
211 |
+
if wav.dtype.is_floating_point:
|
212 |
+
return f32_pcm(wav)
|
213 |
+
else:
|
214 |
+
return i16_pcm(wav)
|
215 |
+
|
216 |
+
|
217 |
+
def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
|
218 |
+
"""Save given audio as mp3. This should work on all OSes."""
|
219 |
+
c, _ = wav.shape
|
220 |
+
wav = i16_pcm(wav)
|
221 |
+
encoder = lameenc.Encoder()
|
222 |
+
encoder.set_bit_rate(bitrate)
|
223 |
+
encoder.set_in_sample_rate(samplerate)
|
224 |
+
encoder.set_channels(c)
|
225 |
+
encoder.set_quality(2) # 2-highest, 7-fastest
|
226 |
+
if not verbose:
|
227 |
+
encoder.silence()
|
228 |
+
wav = wav.data.cpu()
|
229 |
+
wav = wav.transpose(0, 1).numpy()
|
230 |
+
mp3_data = encoder.encode(wav.tobytes())
|
231 |
+
mp3_data += encoder.flush()
|
232 |
+
with open(path, "wb") as f:
|
233 |
+
f.write(mp3_data)
|
234 |
+
|
235 |
+
|
236 |
+
def prevent_clip(wav, mode='rescale'):
|
237 |
+
"""
|
238 |
+
different strategies for avoiding raw clipping.
|
239 |
+
"""
|
240 |
+
if mode is None or mode == 'none':
|
241 |
+
return wav
|
242 |
+
assert wav.dtype.is_floating_point, "too late for clipping"
|
243 |
+
if mode == 'rescale':
|
244 |
+
wav = wav / max(1.01 * wav.abs().max(), 1)
|
245 |
+
elif mode == 'clamp':
|
246 |
+
wav = wav.clamp(-0.99, 0.99)
|
247 |
+
elif mode == 'tanh':
|
248 |
+
wav = torch.tanh(wav)
|
249 |
+
else:
|
250 |
+
raise ValueError(f"Invalid mode {mode}")
|
251 |
+
return wav
|
252 |
+
|
253 |
+
|
254 |
+
def save_audio(wav: torch.Tensor,
|
255 |
+
path: tp.Union[str, Path],
|
256 |
+
samplerate: int,
|
257 |
+
bitrate: int = 320,
|
258 |
+
clip: tp.Union[str] = 'rescale',
|
259 |
+
bits_per_sample: tp.Union[int] = 16,
|
260 |
+
as_float: bool = False):
|
261 |
+
"""Save audio file, automatically preventing clipping if necessary
|
262 |
+
based on the given `clip` strategy. If the path ends in `.mp3`, this
|
263 |
+
will save as mp3 with the given `bitrate`.
|
264 |
+
"""
|
265 |
+
wav = prevent_clip(wav, mode=clip)
|
266 |
+
path = Path(path)
|
267 |
+
suffix = path.suffix.lower()
|
268 |
+
if suffix == ".mp3":
|
269 |
+
encode_mp3(wav, path, samplerate, bitrate, verbose=True)
|
270 |
+
elif suffix == ".wav":
|
271 |
+
if as_float:
|
272 |
+
bits_per_sample = 32
|
273 |
+
encoding = 'PCM_F'
|
274 |
+
else:
|
275 |
+
encoding = 'PCM_S'
|
276 |
+
ta.save(str(path), wav, sample_rate=samplerate,
|
277 |
+
encoding=encoding, bits_per_sample=bits_per_sample)
|
278 |
+
elif suffix == ".flac":
|
279 |
+
ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
|
280 |
+
else:
|
281 |
+
raise ValueError(f"Invalid suffix for path: {suffix}")
|
282 |
+
|
283 |
+
|
284 |
+
def load_track(track, audio_channels, samplerate):
|
285 |
+
errors = {}
|
286 |
+
wav = None
|
287 |
+
|
288 |
+
try:
|
289 |
+
wav = AudioFile(track).read(
|
290 |
+
streams=0,
|
291 |
+
samplerate=samplerate,
|
292 |
+
channels=audio_channels)
|
293 |
+
except sp.CalledProcessError:
|
294 |
+
errors['ffmpeg'] = 'FFmpeg could not read the file.'
|
295 |
+
|
296 |
+
if wav is None:
|
297 |
+
try:
|
298 |
+
wav, sr = ta.load(str(track))
|
299 |
+
except RuntimeError as err:
|
300 |
+
errors['torchaudio'] = err.args[0]
|
301 |
+
else:
|
302 |
+
wav = convert_audio(wav, sr, samplerate, audio_channels)
|
303 |
+
|
304 |
+
return wav, errors
|
codeclm/tokenizer/Flow1dVAE/cal_token_stat.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import kaldiio
|
2 |
+
from tqdm import tqdm
|
3 |
+
import torch
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
bar = torch.zeros(1, 16384)
|
7 |
+
with open('token.scp', 'r') as f:
|
8 |
+
for item_idx, line in tqdm(enumerate(f)):
|
9 |
+
idx, pos = line.strip().split()
|
10 |
+
codes = kaldiio.load_mat(pos)
|
11 |
+
for i0 in range(codes.shape[-1]):
|
12 |
+
bar[0, codes[0, 0, i0]] += 1
|
13 |
+
if(item_idx % 1000 == 0):
|
14 |
+
print("=========")
|
15 |
+
print(1 - (bar[0]==0).sum() / bar.shape[-1])
|
16 |
+
print("=========")
|
17 |
+
print("=========")
|
18 |
+
print(1 - (bar[0]==0).sum() / bar.shape[-1])
|
19 |
+
print("=========")
|
codeclm/tokenizer/Flow1dVAE/compare_model_weight.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import sys
|
3 |
+
from safetensors.torch import load_file
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
m0, m1 = sys.argv[1], sys.argv[2]
|
7 |
+
m0 = load_file(m0)
|
8 |
+
m1 = load_file(m1)
|
9 |
+
|
10 |
+
ks = [k for k in m0.keys() if 'bestrq' in k]
|
11 |
+
for k in ks:
|
12 |
+
print(k, (m0[k] - m1[k]).abs().sum())
|
13 |
+
|
codeclm/tokenizer/Flow1dVAE/configs/models/transformer2D_wocross_inch112_1x4_multi_large.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "Transformer2DModel",
|
3 |
+
"_diffusers_version": "0.22.0.dev0",
|
4 |
+
"activation_fn": "gelu-approximate",
|
5 |
+
"attention_bias": true,
|
6 |
+
"attention_head_dim": 72,
|
7 |
+
"attention_type": "default",
|
8 |
+
"cross_attention_dim": null,
|
9 |
+
"double_self_attention": false,
|
10 |
+
"dropout": 0.0,
|
11 |
+
"in_channels": 96,
|
12 |
+
"norm_elementwise_affine": false,
|
13 |
+
"norm_eps": 1e-06,
|
14 |
+
"norm_num_groups": 32,
|
15 |
+
"norm_type": "ada_norm_single",
|
16 |
+
"num_attention_heads": 22,
|
17 |
+
"num_embeds_ada_norm": 1000,
|
18 |
+
"num_layers": 24,
|
19 |
+
"num_vector_embeds": null,
|
20 |
+
"only_cross_attention": false,
|
21 |
+
"out_channels": 32,
|
22 |
+
"patch_size": 2,
|
23 |
+
"sample_size": 384,
|
24 |
+
"upcast_attention": false,
|
25 |
+
"use_linear_projection": false
|
26 |
+
}
|
codeclm/tokenizer/Flow1dVAE/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "DDIMScheduler",
|
3 |
+
"_diffusers_version": "0.8.0",
|
4 |
+
"beta_end": 0.02,
|
5 |
+
"beta_schedule": "scaled_linear",
|
6 |
+
"beta_start": 0.0015,
|
7 |
+
"clip_sample": false,
|
8 |
+
"num_train_timesteps": 1000,
|
9 |
+
"prediction_type": "sample",
|
10 |
+
"set_alpha_to_one": false,
|
11 |
+
"skip_prk_steps": true,
|
12 |
+
"steps_offset": 1,
|
13 |
+
"trained_betas": null
|
14 |
+
}
|
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch,torchaudio
|
2 |
+
import os,sys,json
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
|
7 |
+
from generate_septoken import Tango as Tango_sep
|
8 |
+
from generate_2rvq import Tango as Tango_1x2
|
9 |
+
import kaldiio
|
10 |
+
from kaldiio import WriteHelper
|
11 |
+
from audio import AudioFile
|
12 |
+
|
13 |
+
from demucs.models.pretrained import get_model_from_yaml
|
14 |
+
from filelock import FileLock
|
15 |
+
|
16 |
+
# os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
|
17 |
+
class Separator:
|
18 |
+
def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
|
19 |
+
if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
|
20 |
+
self.device = torch.device(f"cuda:{gpu_id}")
|
21 |
+
else:
|
22 |
+
self.device = torch.device("cpu")
|
23 |
+
self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
|
24 |
+
|
25 |
+
def init_demucs_model(self, model_path, config_path):
|
26 |
+
model = get_model_from_yaml(config_path, model_path)
|
27 |
+
model.to(self.device)
|
28 |
+
model.eval()
|
29 |
+
return model
|
30 |
+
|
31 |
+
def load_audio(self, f):
|
32 |
+
a, fs = torchaudio.load(f)
|
33 |
+
if (fs != 48000):
|
34 |
+
a = torchaudio.functional.resample(a, fs, 48000)
|
35 |
+
# if a.shape[-1] >= 48000*10:
|
36 |
+
# a = a[..., :48000*10]
|
37 |
+
# else:
|
38 |
+
# a = torch.cat([a, a], -1)
|
39 |
+
# return a[:, 0:48000*10]
|
40 |
+
return a
|
41 |
+
|
42 |
+
def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"):
|
43 |
+
name, _ = os.path.splitext(os.path.split(audio_path)[-1])
|
44 |
+
output_paths = []
|
45 |
+
# lock_path = os.path.join(output_dir, f"{name}.lock")
|
46 |
+
# with FileLock(lock_path): # 加一个避免多卡访问时死锁
|
47 |
+
for stem in self.demucs_model.sources:
|
48 |
+
output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
|
49 |
+
if os.path.exists(output_path):
|
50 |
+
output_paths.append(output_path)
|
51 |
+
if len(output_paths) == 1: # 4
|
52 |
+
# drums_path, bass_path, other_path, vocal_path = output_paths
|
53 |
+
vocal_path = output_paths[0]
|
54 |
+
else:
|
55 |
+
lock_path = os.path.join(output_dir, f"{name}_separate.lock")
|
56 |
+
with FileLock(lock_path):
|
57 |
+
drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
|
58 |
+
full_audio = self.load_audio(audio_path)
|
59 |
+
vocal_audio = self.load_audio(vocal_path)
|
60 |
+
minlen = min(full_audio.shape[-1], vocal_audio.shape[-1])
|
61 |
+
# bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen]
|
62 |
+
bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path)
|
63 |
+
for path in [drums_path, bass_path, other_path, vocal_path]:
|
64 |
+
os.remove(path)
|
65 |
+
return full_audio, vocal_audio, bgm_audio
|
66 |
+
|
67 |
+
def read_wav(fname, sample_rate=48_000):
|
68 |
+
try:
|
69 |
+
orig_samples, fs = torchaudio.load(fname)
|
70 |
+
except:
|
71 |
+
af = AudioFile(fname)
|
72 |
+
orig_samples = af.read()
|
73 |
+
fs = af.samplerate()
|
74 |
+
orig_samples = orig_samples[0]
|
75 |
+
if(fs!=sample_rate):
|
76 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
|
77 |
+
fs = sample_rate
|
78 |
+
if orig_samples.shape[0] == 1:
|
79 |
+
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
80 |
+
return orig_samples
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
# Define Model
|
84 |
+
json_path = sys.argv[1]
|
85 |
+
|
86 |
+
mus_infos = []
|
87 |
+
with open(json_path) as f:
|
88 |
+
for line in f:
|
89 |
+
item = json.loads(line)
|
90 |
+
mus_infos.append(item)
|
91 |
+
|
92 |
+
tango_sep = Tango_sep(model_path="./saved/model_septoken/model_2.safetensors")
|
93 |
+
tango_1x2 = Tango_1x2(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
|
94 |
+
separator = Separator()
|
95 |
+
|
96 |
+
# Feature extraction loop
|
97 |
+
# for i in tqdm(range(2000)):
|
98 |
+
first_time = True
|
99 |
+
for item in tqdm(mus_infos):
|
100 |
+
if(os.path.exists(item['path'])):
|
101 |
+
full_path = item['path']
|
102 |
+
else:
|
103 |
+
full_path = '/mnt/share/' + item['path']
|
104 |
+
|
105 |
+
full_tensor, vocal_tensor, bgm_tensor = separator.run(full_path)
|
106 |
+
|
107 |
+
# full_tensor = read_wav(full_path)
|
108 |
+
# vocal_tensor = read_wav(vocal_path)
|
109 |
+
# length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
|
110 |
+
# full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
|
111 |
+
# bgm_tensor = full_tensor - vocal_tensor
|
112 |
+
codes_1x2 = tango_1x2.sound2code(full_tensor)
|
113 |
+
codes_vocal, codes_bgm = tango_sep.sound2code(vocal_tensor, bgm_tensor)
|
114 |
+
codes = torch.cat([codes_1x2[:,[0],:], codes_vocal, codes_bgm], 1).cpu().numpy()
|
115 |
+
save_path = full_path.replace('.wav', '.1x1_and_sep.npy').replace('.mp3', '.1x1_and_sep.npy').replace('.flac', '.1x1_and_sep.npy').replace('.ogg', '.1x1_and_sep.npy')
|
116 |
+
assert save_path != full_path, (save_path, full_path)
|
117 |
+
np.save(save_path, codes)
|
118 |
+
|
119 |
+
if(first_time):
|
120 |
+
first_time = False
|
121 |
+
print(codes_vocal.shape, codes_bgm.shape)
|
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch,torchaudio
|
2 |
+
import os,sys,json
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
|
6 |
+
from generate_septoken import Tango
|
7 |
+
import kaldiio
|
8 |
+
from kaldiio import WriteHelper
|
9 |
+
from audio import AudioFile
|
10 |
+
|
11 |
+
def read_wav(fname, sample_rate=48_000):
|
12 |
+
try:
|
13 |
+
orig_samples, fs = torchaudio.load(fname)
|
14 |
+
except:
|
15 |
+
af = AudioFile(fname)
|
16 |
+
orig_samples = af.read()
|
17 |
+
fs = af.samplerate()
|
18 |
+
orig_samples = orig_samples[0]
|
19 |
+
if(fs!=sample_rate):
|
20 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
|
21 |
+
fs = sample_rate
|
22 |
+
if orig_samples.shape[0] == 1:
|
23 |
+
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
24 |
+
return orig_samples
|
25 |
+
|
26 |
+
if __name__ == "__main__":
|
27 |
+
# Define Model
|
28 |
+
json_path = sys.argv[1]
|
29 |
+
outdir = sys.argv[2]
|
30 |
+
|
31 |
+
mus_infos = []
|
32 |
+
with open(json_path) as f:
|
33 |
+
for line in f:
|
34 |
+
item = json.loads(line)
|
35 |
+
mus_infos.append(item)
|
36 |
+
|
37 |
+
tango = Tango(model_path="./saved/model_septoken/model_2.safetensors")
|
38 |
+
|
39 |
+
|
40 |
+
# Feature extraction loop
|
41 |
+
# for i in tqdm(range(2000)):
|
42 |
+
first_time = True
|
43 |
+
with WriteHelper('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir), write_function="pickle") as writer_vocal, WriteHelper('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir), write_function="pickle") as writer_bgm:
|
44 |
+
print('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir))
|
45 |
+
print('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir))
|
46 |
+
for item in tqdm(mus_infos):
|
47 |
+
try:
|
48 |
+
# if True:
|
49 |
+
idx = item['idx']
|
50 |
+
# print(idx)
|
51 |
+
if(os.path.exists(item['path'])):
|
52 |
+
full_path = item['path']
|
53 |
+
else:
|
54 |
+
full_path = '/mnt/share/' + item['path']
|
55 |
+
if(os.path.exists(item['vocal_path'])):
|
56 |
+
vocal_path = item['vocal_path']
|
57 |
+
bgm_paths = item['bgm_path']
|
58 |
+
else:
|
59 |
+
vocal_path = '/mnt/share/' + item['vocal_path']
|
60 |
+
bgm_paths = ['/mnt/share/' + p for p in item['bgm_path']]
|
61 |
+
vocal_tensor = read_wav(vocal_path)
|
62 |
+
# full_tensor = read_wav(full_path)
|
63 |
+
# length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
|
64 |
+
# full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
|
65 |
+
# bgm_tensor = full_tensor - vocal_tensor
|
66 |
+
bgm_tensor = sum([read_wav(p) for p in bgm_paths])
|
67 |
+
codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
|
68 |
+
writer_vocal(str(idx), codes_vocal.cpu())
|
69 |
+
writer_bgm(str(idx), codes_bgm.cpu())
|
70 |
+
if(first_time):
|
71 |
+
first_time = False
|
72 |
+
print(codes_vocal.shape, codes_bgm.shape)
|
73 |
+
except:
|
74 |
+
print(item['vocal_path'])
|
75 |
+
print(item['bgm_path'])
|
76 |
+
continue
|
77 |
+
|
78 |
+
# idx = item['idx']
|
79 |
+
# # print(idx)
|
80 |
+
# full_path = item['path']
|
81 |
+
# vocal_path = item['vocal_path']
|
82 |
+
# bgm_paths = item['bgm_path']
|
83 |
+
# full_tensor = read_wav(full_path)
|
84 |
+
# vocal_tensor = read_wav(vocal_path)
|
85 |
+
# length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
|
86 |
+
# full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
|
87 |
+
# bgm_tensor = full_tensor - vocal_tensor
|
88 |
+
# codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
|
89 |
+
# writer_vocal(str(idx), codes_vocal.cpu())
|
90 |
+
# writer_bgm(str(idx), codes_bgm.cpu())
|
91 |
+
# if(first_time):
|
92 |
+
# first_time = False
|
93 |
+
# print(codes_vocal.shape, codes_bgm.shape)
|
94 |
+
|
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch,torchaudio
|
2 |
+
import os,sys,json
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
|
6 |
+
from generate_2rvq import Tango
|
7 |
+
import kaldiio
|
8 |
+
from kaldiio import WriteHelper
|
9 |
+
import torch
|
10 |
+
import subprocess
|
11 |
+
import time
|
12 |
+
import sys
|
13 |
+
|
14 |
+
def get_gpu_memory():
|
15 |
+
_output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
|
16 |
+
|
17 |
+
ACCEPTABLE_AVAILABLE_MEMORY = 1024
|
18 |
+
COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
|
19 |
+
memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
|
20 |
+
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
|
21 |
+
return memory_free_values
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
# Define Model
|
25 |
+
json_path = sys.argv[1]
|
26 |
+
outdir = sys.argv[2]
|
27 |
+
|
28 |
+
gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
|
29 |
+
while True:
|
30 |
+
free_mem = get_gpu_memory()
|
31 |
+
free_mem = free_mem[gpu_idx]
|
32 |
+
if(free_mem > 25_000):
|
33 |
+
print("GPU memory {}, run matrix cal".format(free_mem))
|
34 |
+
break
|
35 |
+
else:
|
36 |
+
print("GPU memory {}, sleep 1min".format(free_mem))
|
37 |
+
time.sleep(60)
|
38 |
+
|
39 |
+
mus_infos = []
|
40 |
+
with open(json_path) as f:
|
41 |
+
for line in f:
|
42 |
+
item = json.loads(line)
|
43 |
+
mus_infos.append(item)
|
44 |
+
|
45 |
+
tango = Tango(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
|
46 |
+
|
47 |
+
|
48 |
+
# Feature extraction loop
|
49 |
+
# for i in tqdm(range(2000)):
|
50 |
+
with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
|
51 |
+
print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
|
52 |
+
for item in tqdm(mus_infos):
|
53 |
+
try:
|
54 |
+
# if True:
|
55 |
+
idx = item['idx']
|
56 |
+
# print(idx)
|
57 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
58 |
+
if(os.path.exists(item['path'])):
|
59 |
+
codes = tango.file2code(item['path'])
|
60 |
+
else:
|
61 |
+
codes = tango.file2code('/mnt/share/' + item['path'])
|
62 |
+
writer(str(idx), codes.cpu())
|
63 |
+
except:
|
64 |
+
print(item['path'])
|
65 |
+
continue
|
66 |
+
# idx = item['idx']
|
67 |
+
# # print(idx)
|
68 |
+
# with torch.autocast(device_type="cuda", dtype=torch.float16):
|
69 |
+
# codes = tango.file2code(item['path'])
|
70 |
+
# writer(str(idx), codes.cpu())
|
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch,torchaudio
|
2 |
+
import os,sys,json
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
|
6 |
+
from generate_4rvq import Tango
|
7 |
+
import kaldiio
|
8 |
+
from kaldiio import WriteHelper
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
# Define Model
|
12 |
+
json_path = sys.argv[1]
|
13 |
+
outdir = sys.argv[2]
|
14 |
+
|
15 |
+
mus_infos = []
|
16 |
+
with open(json_path) as f:
|
17 |
+
for line in f:
|
18 |
+
item = json.loads(line)
|
19 |
+
mus_infos.append(item)
|
20 |
+
|
21 |
+
tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
|
22 |
+
|
23 |
+
|
24 |
+
# Feature extraction loop
|
25 |
+
# for i in tqdm(range(2000)):
|
26 |
+
with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
|
27 |
+
print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
|
28 |
+
for item in tqdm(mus_infos):
|
29 |
+
try:
|
30 |
+
# if True:
|
31 |
+
idx = item['idx']
|
32 |
+
# print(idx)
|
33 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
34 |
+
if(os.path.exists(item['path'])):
|
35 |
+
codes = tango.file2code(item['path'])
|
36 |
+
else:
|
37 |
+
codes = tango.file2code('/mnt/share/' + item['path'])
|
38 |
+
writer(str(idx), codes.cpu())
|
39 |
+
except:
|
40 |
+
print(item['path'])
|
41 |
+
continue
|
42 |
+
# idx = item['idx']
|
43 |
+
# # print(idx)
|
44 |
+
# with torch.autocast(device_type="cuda", dtype=torch.float16):
|
45 |
+
# codes = tango.file2code(item['path'])
|
46 |
+
# writer(str(idx), codes.cpu())
|
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch,torchaudio
|
2 |
+
import os,sys,json
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
|
6 |
+
from generate_4rvq import Tango
|
7 |
+
import kaldiio
|
8 |
+
from kaldiio import WriteHelper
|
9 |
+
import torch
|
10 |
+
import subprocess
|
11 |
+
import time
|
12 |
+
import sys
|
13 |
+
|
14 |
+
def get_gpu_memory():
|
15 |
+
_output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
|
16 |
+
|
17 |
+
ACCEPTABLE_AVAILABLE_MEMORY = 1024
|
18 |
+
COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
|
19 |
+
memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
|
20 |
+
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
|
21 |
+
return memory_free_values
|
22 |
+
|
23 |
+
if __name__ == "__main__":
|
24 |
+
# Define Model
|
25 |
+
json_path = sys.argv[1]
|
26 |
+
outdir = sys.argv[2]
|
27 |
+
ds = int(sys.argv[3])
|
28 |
+
|
29 |
+
gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
|
30 |
+
while True:
|
31 |
+
free_mem = get_gpu_memory()
|
32 |
+
free_mem = free_mem[gpu_idx]
|
33 |
+
if(free_mem > 25_000):
|
34 |
+
print("GPU memory {}, run matrix cal".format(free_mem))
|
35 |
+
break
|
36 |
+
else:
|
37 |
+
print("GPU memory {}, sleep 1min".format(free_mem))
|
38 |
+
time.sleep(60)
|
39 |
+
|
40 |
+
mus_infos = []
|
41 |
+
with open(json_path) as f:
|
42 |
+
for line in f:
|
43 |
+
item = json.loads(line)
|
44 |
+
mus_infos.append(item)
|
45 |
+
|
46 |
+
tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
|
47 |
+
|
48 |
+
|
49 |
+
# Feature extraction loop
|
50 |
+
# for i in tqdm(range(2000)):
|
51 |
+
with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
|
52 |
+
print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
|
53 |
+
bar = torch.zeros(4, 16384)
|
54 |
+
for item_idx, item in tqdm(enumerate(mus_infos)):
|
55 |
+
try:
|
56 |
+
# if True:
|
57 |
+
idx = item['idx']
|
58 |
+
# print(idx)
|
59 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
60 |
+
if(os.path.exists(item['path'])):
|
61 |
+
codes = tango.file2code_ds(item['path'], ds)
|
62 |
+
else:
|
63 |
+
codes = tango.file2code_ds('/mnt/share/' + item['path'], ds)
|
64 |
+
codes = codes.cpu()
|
65 |
+
writer(str(idx), codes)
|
66 |
+
for i0 in range(codes.shape[-1]):
|
67 |
+
bar[0, codes[0, 0, i0]] += 1
|
68 |
+
bar[1, codes[0, 1, i0]] += 1
|
69 |
+
bar[2, codes[0, 2, i0]] += 1
|
70 |
+
bar[3, codes[0, 3, i0]] += 1
|
71 |
+
except Exception as e:
|
72 |
+
print(item['path'])
|
73 |
+
# print(e.message, e.args)
|
74 |
+
# exit(1)
|
75 |
+
continue
|
76 |
+
|
77 |
+
if(item_idx % 1000 == 0):
|
78 |
+
print("=========")
|
79 |
+
print(1 - (bar[0]==0).sum() / bar.shape[-1])
|
80 |
+
print("=========")
|
81 |
+
|
82 |
+
# idx = item['idx']
|
83 |
+
# # print(idx)
|
84 |
+
# with torch.autocast(device_type="cuda", dtype=torch.float16):
|
85 |
+
# codes = tango.file2code(item['path'])
|
86 |
+
# writer(str(idx), codes.cpu())
|
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from model_1rvq import PromptCondAudioDiffusion
|
5 |
+
from diffusers import DDIMScheduler, DDPMScheduler
|
6 |
+
import torchaudio
|
7 |
+
import librosa
|
8 |
+
import os
|
9 |
+
import math
|
10 |
+
import numpy as np
|
11 |
+
from tools.get_1dvae_large import get_model
|
12 |
+
import tools.torch_tools as torch_tools
|
13 |
+
from safetensors.torch import load_file
|
14 |
+
|
15 |
+
class Tango:
|
16 |
+
def __init__(self, \
|
17 |
+
model_path, \
|
18 |
+
vae_config="",
|
19 |
+
vae_model="",
|
20 |
+
layer_num=6, \
|
21 |
+
device="cuda:0"):
|
22 |
+
|
23 |
+
self.sample_rate = 48000
|
24 |
+
scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
|
25 |
+
self.device = device
|
26 |
+
|
27 |
+
self.vae = get_model(vae_config, vae_model)
|
28 |
+
self.vae = self.vae.to(device)
|
29 |
+
self.vae=self.vae.eval()
|
30 |
+
self.layer_num = layer_num
|
31 |
+
|
32 |
+
self.MAX_DURATION = 360
|
33 |
+
main_config = {
|
34 |
+
"num_channels":32,
|
35 |
+
"unet_model_name":None,
|
36 |
+
"unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
|
37 |
+
"snr_gamma":None,
|
38 |
+
}
|
39 |
+
self.model = PromptCondAudioDiffusion(**main_config).to(device)
|
40 |
+
if model_path.endswith(".safetensors"):
|
41 |
+
main_weights = load_file(model_path)
|
42 |
+
else:
|
43 |
+
main_weights = torch.load(model_path, map_location=device)
|
44 |
+
self.model.load_state_dict(main_weights, strict=False)
|
45 |
+
print ("Successfully loaded checkpoint from:", model_path)
|
46 |
+
|
47 |
+
self.model.eval()
|
48 |
+
self.model.init_device_dtype(torch.device(device), torch.float32)
|
49 |
+
print("scaling factor: ", self.model.normfeat.std)
|
50 |
+
|
51 |
+
# self.scheduler = DDIMScheduler.from_pretrained( \
|
52 |
+
# scheduler_name, subfolder="scheduler")
|
53 |
+
# self.scheduler = DDPMScheduler.from_pretrained( \
|
54 |
+
# scheduler_name, subfolder="scheduler")
|
55 |
+
# print("Successfully loaded inference scheduler from {}".format(scheduler_name))
|
56 |
+
|
57 |
+
# def sound2sound(self, orig_samples, lyric, st_et, batch_size=1, duration=40.96, steps=200, disable_progress=False,scenario = "start_seg"):
|
58 |
+
# """ Genrate audio without condition. """
|
59 |
+
# with torch.no_grad():
|
60 |
+
# if(orig_samples.shape[-1]<int(duration*48000)+480):
|
61 |
+
# orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration*48000+480)-orig_samples.shape[-1], \
|
62 |
+
# dtype=orig_samples.dtype, device=orig_samples.device)], -1)
|
63 |
+
|
64 |
+
# orig_samples = orig_samples.to(self.device)
|
65 |
+
# saved_samples = orig_samples[:,0:40*48000].clamp(-1,1)
|
66 |
+
# orig_samples = orig_samples[:,0:40*48000].clamp(-1,1)
|
67 |
+
# max_volume = orig_samples.abs().max(dim=-1)[0]
|
68 |
+
# orig_samples = orig_samples/max_volume.unsqueeze(-1)
|
69 |
+
# print("orig_samples.shape", orig_samples.shape)
|
70 |
+
|
71 |
+
# latent_length = int((st_et[1] - st_et[0]) * 48000) // 1920 + 1
|
72 |
+
|
73 |
+
# true_latents = self.vae.encode_audio(orig_samples).permute(0,2,1)
|
74 |
+
|
75 |
+
# print("true_latents.shape", true_latents.shape)
|
76 |
+
# latents = self.model.inference(orig_samples.repeat(batch_size, 1), [lyric, ]*batch_size, true_latents, latent_length, additional_feats=[], guidance_scale=1.5, num_steps = steps, disable_progress=disable_progress,layer=6, scenario = scenario)
|
77 |
+
# print("latents.shape", latents.shape)
|
78 |
+
# print("latent_length", latent_length)
|
79 |
+
|
80 |
+
# latents = latents[:,:,:latent_length]
|
81 |
+
# audio = self.vae.decode_audio(latents)
|
82 |
+
# print("audio.shape:",audio.shape)
|
83 |
+
# audio = torch.cat((audio, torch.zeros(audio.shape[0],audio.shape[1], 48000*40 - audio.shape[-1], dtype=audio.dtype, device=audio.device)), dim=-1)
|
84 |
+
# print("audio.shape:",audio.shape)
|
85 |
+
# # audio = audio.reshape(audio.shape[0]//2, 2, -1)
|
86 |
+
# # audio = torch.from_numpy(audio)
|
87 |
+
|
88 |
+
# if(saved_samples.shape[-1]<audio.shape[-1]):
|
89 |
+
# saved_samples = torch.cat([saved_samples, torch.zeros(saved_samples.shape[0], audio.shape[-1]-saved_samples.shape[-1], dtype=saved_samples.dtype, device=saved_samples.device)],-1)
|
90 |
+
# else:
|
91 |
+
# saved_samples = saved_samples[:,0:audio.shape[-1]]
|
92 |
+
# output = torch.cat([saved_samples.detach().cpu(),audio[0].detach().cpu()],0)
|
93 |
+
# return output
|
94 |
+
|
95 |
+
@torch.no_grad()
|
96 |
+
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
97 |
+
def sound2code(self, orig_samples, batch_size=3):
|
98 |
+
if(orig_samples.ndim == 2):
|
99 |
+
audios = orig_samples.unsqueeze(0).to(self.device)
|
100 |
+
elif(orig_samples.ndim == 3):
|
101 |
+
audios = orig_samples.to(self.device)
|
102 |
+
else:
|
103 |
+
assert orig_samples.ndim in (2,3), orig_samples.shape
|
104 |
+
audios = self.preprocess_audio(audios)
|
105 |
+
audios = audios.squeeze(0)
|
106 |
+
orig_length = audios.shape[-1]
|
107 |
+
min_samples = int(40 * self.sample_rate)
|
108 |
+
# 40秒对应10个token
|
109 |
+
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
110 |
+
print("output_len: ", output_len)
|
111 |
+
|
112 |
+
while(audios.shape[-1] < min_samples):
|
113 |
+
audios = torch.cat([audios, audios], -1)
|
114 |
+
int_max_len=audios.shape[-1]//min_samples+1
|
115 |
+
audios = torch.cat([audios, audios], -1)
|
116 |
+
audios=audios[:,:int(int_max_len*(min_samples))]
|
117 |
+
codes_list=[]
|
118 |
+
|
119 |
+
audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
120 |
+
|
121 |
+
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
122 |
+
# import pdb; pdb.set_trace()
|
123 |
+
codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
|
124 |
+
codes_list.append(torch.cat(codes, 1))
|
125 |
+
# print("codes_list",codes_list[0].shape)
|
126 |
+
|
127 |
+
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None] # B 3 T -> 3 B T
|
128 |
+
codes=codes[:,:,:output_len]
|
129 |
+
|
130 |
+
return codes
|
131 |
+
|
132 |
+
@torch.no_grad()
|
133 |
+
def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
|
134 |
+
codes = codes.to(self.device)
|
135 |
+
|
136 |
+
min_samples = int(duration * 25) # 40ms per frame
|
137 |
+
hop_samples = min_samples // 4 * 3
|
138 |
+
ovlp_samples = min_samples - hop_samples
|
139 |
+
hop_frames = hop_samples
|
140 |
+
ovlp_frames = ovlp_samples
|
141 |
+
first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
|
142 |
+
first_latent_length = 0
|
143 |
+
first_latent_codes_length = 0
|
144 |
+
|
145 |
+
if(isinstance(prompt, torch.Tensor)):
|
146 |
+
# prepare prompt
|
147 |
+
prompt = prompt.to(self.device)
|
148 |
+
if(prompt.ndim == 3):
|
149 |
+
assert prompt.shape[0] == 1, prompt.shape
|
150 |
+
prompt = prompt[0]
|
151 |
+
elif(prompt.ndim == 1):
|
152 |
+
prompt = prompt.unsqueeze(0).repeat(2,1)
|
153 |
+
elif(prompt.ndim == 2):
|
154 |
+
if(prompt.shape[0] == 1):
|
155 |
+
prompt = prompt.repeat(2,1)
|
156 |
+
|
157 |
+
if(prompt.shape[-1] < int(30 * self.sample_rate)):
|
158 |
+
# if less than 30s, just choose the first 10s
|
159 |
+
prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
|
160 |
+
else:
|
161 |
+
# else choose from 20.48s which might includes verse or chorus
|
162 |
+
prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
163 |
+
|
164 |
+
true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
|
165 |
+
# print("true_latent.shape", true_latent.shape)
|
166 |
+
# print("first_latent.shape", first_latent.shape)
|
167 |
+
#true_latent.shape torch.Size([1, 250, 64])
|
168 |
+
# first_latent.shape torch.Size([1, 1000, 64])
|
169 |
+
|
170 |
+
first_latent[:,0:true_latent.shape[1],:] = true_latent
|
171 |
+
first_latent_length = true_latent.shape[1]
|
172 |
+
first_latent_codes = self.sound2code(prompt)
|
173 |
+
first_latent_codes_length = first_latent_codes.shape[-1]
|
174 |
+
codes = torch.cat([first_latent_codes, codes], -1)
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
codes_len= codes.shape[-1]
|
180 |
+
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
181 |
+
# target_len = int(codes_len / 100 * 4 * self.sample_rate)
|
182 |
+
# code repeat
|
183 |
+
if(codes_len < min_samples):
|
184 |
+
while(codes.shape[-1] < min_samples):
|
185 |
+
codes = torch.cat([codes, codes], -1)
|
186 |
+
codes = codes[:,:,0:min_samples]
|
187 |
+
codes_len = codes.shape[-1]
|
188 |
+
if((codes_len - ovlp_samples) % hop_samples > 0):
|
189 |
+
len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
|
190 |
+
while(codes.shape[-1] < len_codes):
|
191 |
+
codes = torch.cat([codes, codes], -1)
|
192 |
+
codes = codes[:,:,0:len_codes]
|
193 |
+
latent_length = min_samples
|
194 |
+
latent_list = []
|
195 |
+
spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
|
196 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
197 |
+
for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
|
198 |
+
codes_input=[]
|
199 |
+
codes_input.append(codes[:,:,sinx:sinx+min_samples])
|
200 |
+
if(sinx == 0):
|
201 |
+
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
202 |
+
incontext_length = first_latent_length
|
203 |
+
latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
204 |
+
latent_list.append(latents)
|
205 |
+
else:
|
206 |
+
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
207 |
+
true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
|
208 |
+
print("true_latent.shape", true_latent.shape)
|
209 |
+
len_add_to_1000 = min_samples - true_latent.shape[-2]
|
210 |
+
# print("len_add_to_1000", len_add_to_1000)
|
211 |
+
# exit()
|
212 |
+
incontext_length = true_latent.shape[-2]
|
213 |
+
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
|
214 |
+
latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
215 |
+
latent_list.append(latents)
|
216 |
+
|
217 |
+
latent_list = [l.float() for l in latent_list]
|
218 |
+
latent_list[0] = latent_list[0][:,:,first_latent_length:]
|
219 |
+
min_samples = int(min_samples * self.sample_rate // 1000 * 40)
|
220 |
+
hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
|
221 |
+
ovlp_samples = min_samples - hop_samples
|
222 |
+
with torch.no_grad():
|
223 |
+
output = None
|
224 |
+
for i in range(len(latent_list)):
|
225 |
+
latent = latent_list[i]
|
226 |
+
cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
|
227 |
+
|
228 |
+
if output is None:
|
229 |
+
output = cur_output
|
230 |
+
else:
|
231 |
+
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
232 |
+
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
233 |
+
print("output.shape", output.shape)
|
234 |
+
print("ov_win.shape", ov_win.shape)
|
235 |
+
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
236 |
+
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
237 |
+
output = output[:, 0:target_len]
|
238 |
+
return output
|
239 |
+
|
240 |
+
@torch.no_grad()
|
241 |
+
def preprocess_audio(self, input_audios, threshold=0.8):
|
242 |
+
assert len(input_audios.shape) == 3, input_audios.shape
|
243 |
+
nchan = input_audios.shape[1]
|
244 |
+
input_audios = input_audios.reshape(input_audios.shape[0], -1)
|
245 |
+
norm_value = torch.ones_like(input_audios[:,0])
|
246 |
+
max_volume = input_audios.abs().max(dim=-1)[0]
|
247 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
248 |
+
return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
|
249 |
+
|
250 |
+
@torch.no_grad()
|
251 |
+
def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
|
252 |
+
codes = self.sound2code(sound)
|
253 |
+
# print(codes.shape)
|
254 |
+
wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
255 |
+
# print(fname, wave.shape)
|
256 |
+
return wave
|
257 |
+
|
258 |
+
@torch.no_grad()
|
259 |
+
def sound2sound_vae(self, sound, prompt=None, steps=50, disable_progress=False):
|
260 |
+
min_samples = int(40 * 25) # 40ms per frame
|
261 |
+
hop_samples = min_samples // 4 * 3
|
262 |
+
ovlp_samples = min_samples - hop_samples
|
263 |
+
dur = 20
|
264 |
+
|
265 |
+
latent_list = []
|
266 |
+
for i in range(0, sound.shape[-1], dur*48000):
|
267 |
+
if(i+dur*2*48000 > sound.shape[-1]):
|
268 |
+
latent = tango.vae.encode_audio(sound.cuda()[None,:,i:])
|
269 |
+
break
|
270 |
+
else:
|
271 |
+
latent = tango.vae.encode_audio(sound.cuda()[None,:,i:i+dur*48000])
|
272 |
+
latent_list.append(latent)
|
273 |
+
|
274 |
+
output = None
|
275 |
+
for i in range(len(latent_list)):
|
276 |
+
print(i)
|
277 |
+
latent = latent_list[i]
|
278 |
+
cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
|
279 |
+
if output is None:
|
280 |
+
output = cur_output
|
281 |
+
else:
|
282 |
+
output = torch.cat([output, cur_output], -1)
|
283 |
+
return output
|
codeclm/tokenizer/Flow1dVAE/generate_2rvq.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from model_2rvq import PromptCondAudioDiffusion
|
5 |
+
from diffusers import DDIMScheduler, DDPMScheduler
|
6 |
+
import torchaudio
|
7 |
+
import librosa
|
8 |
+
import os
|
9 |
+
import math
|
10 |
+
import numpy as np
|
11 |
+
# from tools.get_mulan import get_mulan
|
12 |
+
from tools.get_1dvae_large import get_model
|
13 |
+
import tools.torch_tools as torch_tools
|
14 |
+
from safetensors.torch import load_file
|
15 |
+
from audio import AudioFile
|
16 |
+
import kaldiio
|
17 |
+
|
18 |
+
class Tango:
|
19 |
+
def __init__(self, \
|
20 |
+
model_path, \
|
21 |
+
layer_num=6, \
|
22 |
+
rvq_num=1, \
|
23 |
+
device="cuda:0"):
|
24 |
+
|
25 |
+
self.sample_rate = 48000
|
26 |
+
scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
|
27 |
+
self.device = device
|
28 |
+
|
29 |
+
self.vae = get_model()
|
30 |
+
self.vae = self.vae.to(device)
|
31 |
+
self.vae=self.vae.eval()
|
32 |
+
self.layer_num = layer_num
|
33 |
+
|
34 |
+
self.MAX_DURATION = 360
|
35 |
+
main_config = {
|
36 |
+
"num_channels":32,
|
37 |
+
"unet_model_name":None,
|
38 |
+
"unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
|
39 |
+
"snr_gamma":None,
|
40 |
+
}
|
41 |
+
self.rvq_num = rvq_num
|
42 |
+
# print("rvq_num: ", self.rvq_num)
|
43 |
+
# exit()
|
44 |
+
self.model = PromptCondAudioDiffusion(**main_config).to(device)
|
45 |
+
if model_path.endswith(".safetensors"):
|
46 |
+
main_weights = load_file(model_path)
|
47 |
+
else:
|
48 |
+
main_weights = torch.load(model_path, map_location=device)
|
49 |
+
self.model.load_state_dict(main_weights, strict=False)
|
50 |
+
print ("Successfully loaded checkpoint from:", model_path)
|
51 |
+
|
52 |
+
self.model.eval()
|
53 |
+
self.model.init_device_dtype(torch.device(device), torch.float32)
|
54 |
+
print("scaling factor: ", self.model.normfeat.std)
|
55 |
+
|
56 |
+
# self.scheduler = DDIMScheduler.from_pretrained( \
|
57 |
+
# scheduler_name, subfolder="scheduler")
|
58 |
+
# self.scheduler = DDPMScheduler.from_pretrained( \
|
59 |
+
# scheduler_name, subfolder="scheduler")
|
60 |
+
print("Successfully loaded inference scheduler from {}".format(scheduler_name))
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
@torch.no_grad()
|
65 |
+
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
66 |
+
def sound2code(self, orig_samples, batch_size=8):
|
67 |
+
if(orig_samples.ndim == 2):
|
68 |
+
audios = orig_samples.unsqueeze(0).to(self.device)
|
69 |
+
elif(orig_samples.ndim == 3):
|
70 |
+
audios = orig_samples.to(self.device)
|
71 |
+
else:
|
72 |
+
assert orig_samples.ndim in (2,3), orig_samples.shape
|
73 |
+
audios = self.preprocess_audio(audios)
|
74 |
+
audios = audios.squeeze(0)
|
75 |
+
orig_length = audios.shape[-1]
|
76 |
+
min_samples = int(40 * self.sample_rate)
|
77 |
+
# 40秒对应10个token
|
78 |
+
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
79 |
+
# print("output_len: ", output_len)
|
80 |
+
|
81 |
+
while(audios.shape[-1] < min_samples):
|
82 |
+
audios = torch.cat([audios, audios], -1)
|
83 |
+
int_max_len=audios.shape[-1]//min_samples+1
|
84 |
+
audios = torch.cat([audios, audios], -1)
|
85 |
+
audios=audios[:,:int(int_max_len*(min_samples))]
|
86 |
+
codes_list=[]
|
87 |
+
|
88 |
+
audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
89 |
+
|
90 |
+
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
91 |
+
# import pdb; pdb.set_trace()
|
92 |
+
codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
|
93 |
+
# print("codes",codes[0].shape)
|
94 |
+
|
95 |
+
codes_list.append(torch.cat(codes, 1))
|
96 |
+
# print("codes_list",codes_list[0].shape)
|
97 |
+
|
98 |
+
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
|
99 |
+
codes=codes[:,:,:output_len]
|
100 |
+
|
101 |
+
return codes
|
102 |
+
|
103 |
+
@torch.no_grad()
|
104 |
+
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
105 |
+
def sound2code_ds(self, orig_samples, ds, batch_size=8):
|
106 |
+
if(orig_samples.ndim == 2):
|
107 |
+
audios = orig_samples.unsqueeze(0).to(self.device)
|
108 |
+
elif(orig_samples.ndim == 3):
|
109 |
+
audios = orig_samples.to(self.device)
|
110 |
+
else:
|
111 |
+
assert orig_samples.ndim in (2,3), orig_samples.shape
|
112 |
+
audios = self.preprocess_audio(audios)
|
113 |
+
audios = audios.squeeze(0)
|
114 |
+
orig_length = audios.shape[-1]
|
115 |
+
min_samples = int(40 * self.sample_rate)
|
116 |
+
# 40秒对应10个token
|
117 |
+
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
118 |
+
# print("output_len: ", output_len)
|
119 |
+
|
120 |
+
while(audios.shape[-1] < min_samples):
|
121 |
+
audios = torch.cat([audios, audios], -1)
|
122 |
+
int_max_len=audios.shape[-1]//min_samples+1
|
123 |
+
audios = torch.cat([audios, audios], -1)
|
124 |
+
audios=audios[:,:int(int_max_len*(min_samples))]
|
125 |
+
codes_list=[]
|
126 |
+
|
127 |
+
audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
128 |
+
|
129 |
+
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
130 |
+
# import pdb; pdb.set_trace()
|
131 |
+
codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
|
132 |
+
# print("codes",codes[0].shape)
|
133 |
+
|
134 |
+
codes_list.append(torch.cat(codes, 1))
|
135 |
+
# print("codes_list",codes_list[0].shape)
|
136 |
+
|
137 |
+
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
|
138 |
+
codes=codes[:,:,:output_len]
|
139 |
+
|
140 |
+
return codes
|
141 |
+
|
142 |
+
@torch.no_grad()
|
143 |
+
def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
|
144 |
+
codes = codes.to(self.device)
|
145 |
+
|
146 |
+
min_samples = duration * 25 # 40ms per frame
|
147 |
+
hop_samples = min_samples // 4 * 3
|
148 |
+
ovlp_samples = min_samples - hop_samples
|
149 |
+
hop_frames = hop_samples
|
150 |
+
ovlp_frames = ovlp_samples
|
151 |
+
first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
|
152 |
+
first_latent_length = 0
|
153 |
+
first_latent_codes_length = 0
|
154 |
+
|
155 |
+
if(isinstance(prompt, torch.Tensor)):
|
156 |
+
# prepare prompt
|
157 |
+
prompt = prompt.to(self.device)
|
158 |
+
if(prompt.ndim == 3):
|
159 |
+
assert prompt.shape[0] == 1, prompt.shape
|
160 |
+
prompt = prompt[0]
|
161 |
+
elif(prompt.ndim == 1):
|
162 |
+
prompt = prompt.unsqueeze(0).repeat(2,1)
|
163 |
+
elif(prompt.ndim == 2):
|
164 |
+
if(prompt.shape[0] == 1):
|
165 |
+
prompt = prompt.repeat(2,1)
|
166 |
+
|
167 |
+
if(prompt.shape[-1] < int(30 * self.sample_rate)):
|
168 |
+
# if less than 30s, just choose the first 10s
|
169 |
+
prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
|
170 |
+
else:
|
171 |
+
# else choose from 20.48s which might includes verse or chorus
|
172 |
+
prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
173 |
+
|
174 |
+
true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
|
175 |
+
# print("true_latent.shape", true_latent.shape)
|
176 |
+
# print("first_latent.shape", first_latent.shape)
|
177 |
+
#true_latent.shape torch.Size([1, 250, 64])
|
178 |
+
# first_latent.shape torch.Size([1, 1000, 64])
|
179 |
+
|
180 |
+
first_latent[:,0:true_latent.shape[1],:] = true_latent
|
181 |
+
first_latent_length = true_latent.shape[1]
|
182 |
+
first_latent_codes = self.sound2code(prompt)
|
183 |
+
first_latent_codes_length = first_latent_codes.shape[-1]
|
184 |
+
codes = torch.cat([first_latent_codes, codes], -1)
|
185 |
+
|
186 |
+
codes_len= codes.shape[-1]
|
187 |
+
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
188 |
+
# target_len = int(codes_len / 100 * 4 * self.sample_rate)
|
189 |
+
# code repeat
|
190 |
+
if(codes_len < min_samples):
|
191 |
+
while(codes.shape[-1] < min_samples):
|
192 |
+
codes = torch.cat([codes, codes], -1)
|
193 |
+
codes = codes[:,:,0:min_samples]
|
194 |
+
codes_len = codes.shape[-1]
|
195 |
+
if((codes_len - ovlp_samples) % hop_samples > 0):
|
196 |
+
len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
|
197 |
+
while(codes.shape[-1] < len_codes):
|
198 |
+
codes = torch.cat([codes, codes], -1)
|
199 |
+
codes = codes[:,:,0:len_codes]
|
200 |
+
latent_length = min_samples
|
201 |
+
latent_list = []
|
202 |
+
spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
|
203 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
204 |
+
for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
|
205 |
+
codes_input=[]
|
206 |
+
codes_input.append(codes[:,:,sinx:sinx+min_samples])
|
207 |
+
if(sinx == 0):
|
208 |
+
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
209 |
+
incontext_length = first_latent_length
|
210 |
+
latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
211 |
+
latent_list.append(latents)
|
212 |
+
else:
|
213 |
+
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
214 |
+
true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
|
215 |
+
print("true_latent.shape", true_latent.shape)
|
216 |
+
len_add_to_1000 = 1000 - true_latent.shape[-2]
|
217 |
+
# print("len_add_to_1000", len_add_to_1000)
|
218 |
+
# exit()
|
219 |
+
incontext_length = true_latent.shape[-2]
|
220 |
+
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
|
221 |
+
latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
222 |
+
latent_list.append(latents)
|
223 |
+
|
224 |
+
latent_list = [l.float() for l in latent_list]
|
225 |
+
latent_list[0] = latent_list[0][:,:,first_latent_length:]
|
226 |
+
min_samples = int(min_samples * self.sample_rate // 1000 * 40)
|
227 |
+
hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
|
228 |
+
ovlp_samples = min_samples - hop_samples
|
229 |
+
with torch.no_grad():
|
230 |
+
output = None
|
231 |
+
for i in range(len(latent_list)):
|
232 |
+
latent = latent_list[i]
|
233 |
+
cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
|
234 |
+
|
235 |
+
if output is None:
|
236 |
+
output = cur_output
|
237 |
+
else:
|
238 |
+
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
239 |
+
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
240 |
+
print("output.shape", output.shape)
|
241 |
+
print("ov_win.shape", ov_win.shape)
|
242 |
+
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
243 |
+
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
244 |
+
output = output[:, 0:target_len]
|
245 |
+
return output
|
246 |
+
|
247 |
+
@torch.no_grad()
|
248 |
+
def preprocess_audio(self, input_audios, threshold=0.8):
|
249 |
+
assert len(input_audios.shape) == 3, input_audios.shape
|
250 |
+
nchan = input_audios.shape[1]
|
251 |
+
input_audios = input_audios.reshape(input_audios.shape[0], -1)
|
252 |
+
norm_value = torch.ones_like(input_audios[:,0])
|
253 |
+
max_volume = input_audios.abs().max(dim=-1)[0]
|
254 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
255 |
+
return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
|
256 |
+
|
257 |
+
@torch.no_grad()
|
258 |
+
def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
|
259 |
+
codes = self.sound2code(sound)
|
260 |
+
# print(codes.shape)
|
261 |
+
# exit()
|
262 |
+
wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
263 |
+
# print(fname, wave.shape)
|
264 |
+
return wave
|
265 |
+
|
266 |
+
def file2code(self, fname):
|
267 |
+
try:
|
268 |
+
orig_samples, fs = torchaudio.load(fname)
|
269 |
+
except:
|
270 |
+
af = AudioFile(fname)
|
271 |
+
orig_samples = af.read()
|
272 |
+
fs = af.samplerate()
|
273 |
+
orig_samples = orig_samples[0]
|
274 |
+
if(fs!=self.sample_rate):
|
275 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
|
276 |
+
fs = self.sample_rate
|
277 |
+
if orig_samples.shape[0] == 1:
|
278 |
+
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
279 |
+
return self.sound2code(orig_samples)
|
280 |
+
|
281 |
+
def file2code_ds(self, fname, ds):
|
282 |
+
try:
|
283 |
+
orig_samples, fs = torchaudio.load(fname)
|
284 |
+
except:
|
285 |
+
af = AudioFile(fname)
|
286 |
+
orig_samples = af.read()
|
287 |
+
fs = af.samplerate()
|
288 |
+
orig_samples = orig_samples[0]
|
289 |
+
if(fs!=self.sample_rate):
|
290 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
|
291 |
+
fs = self.sample_rate
|
292 |
+
if orig_samples.shape[0] == 1:
|
293 |
+
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
294 |
+
return self.sound2code_ds(orig_samples, ds)
|
codeclm/tokenizer/Flow1dVAE/generate_4rvq.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from model_4rvq import PromptCondAudioDiffusion
|
5 |
+
from diffusers import DDIMScheduler, DDPMScheduler
|
6 |
+
import torchaudio
|
7 |
+
import librosa
|
8 |
+
import os
|
9 |
+
import math
|
10 |
+
import numpy as np
|
11 |
+
# from tools.get_mulan import get_mulan
|
12 |
+
from tools.get_1dvae_large import get_model
|
13 |
+
import tools.torch_tools as torch_tools
|
14 |
+
from safetensors.torch import load_file
|
15 |
+
from audio import AudioFile
|
16 |
+
|
17 |
+
class Tango:
|
18 |
+
def __init__(self, \
|
19 |
+
model_path, \
|
20 |
+
layer_num=6, \
|
21 |
+
rvq_num=1, \
|
22 |
+
device="cuda:0"):
|
23 |
+
|
24 |
+
self.sample_rate = 48000
|
25 |
+
scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
|
26 |
+
self.device = device
|
27 |
+
|
28 |
+
self.vae = get_model()
|
29 |
+
self.vae = self.vae.to(device)
|
30 |
+
self.vae=self.vae.eval()
|
31 |
+
self.layer_num = layer_num
|
32 |
+
|
33 |
+
self.MAX_DURATION = 360
|
34 |
+
main_config = {
|
35 |
+
"num_channels":32,
|
36 |
+
"unet_model_name":None,
|
37 |
+
"unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
|
38 |
+
"snr_gamma":None,
|
39 |
+
}
|
40 |
+
self.rvq_num = rvq_num
|
41 |
+
# print("rvq_num: ", self.rvq_num)
|
42 |
+
# exit()
|
43 |
+
self.model = PromptCondAudioDiffusion(**main_config).to(device)
|
44 |
+
if model_path.endswith(".safetensors"):
|
45 |
+
main_weights = load_file(model_path)
|
46 |
+
else:
|
47 |
+
main_weights = torch.load(model_path, map_location=device)
|
48 |
+
self.model.load_state_dict(main_weights, strict=False)
|
49 |
+
print ("Successfully loaded checkpoint from:", model_path)
|
50 |
+
|
51 |
+
self.model.eval()
|
52 |
+
self.model.init_device_dtype(torch.device(device), torch.float32)
|
53 |
+
print("scaling factor: ", self.model.normfeat.std)
|
54 |
+
|
55 |
+
# self.scheduler = DDIMScheduler.from_pretrained( \
|
56 |
+
# scheduler_name, subfolder="scheduler")
|
57 |
+
# self.scheduler = DDPMScheduler.from_pretrained( \
|
58 |
+
# scheduler_name, subfolder="scheduler")
|
59 |
+
print("Successfully loaded inference scheduler from {}".format(scheduler_name))
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
@torch.no_grad()
|
64 |
+
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
65 |
+
def sound2code(self, orig_samples, batch_size=8):
|
66 |
+
if(orig_samples.ndim == 2):
|
67 |
+
audios = orig_samples.unsqueeze(0).to(self.device)
|
68 |
+
elif(orig_samples.ndim == 3):
|
69 |
+
audios = orig_samples.to(self.device)
|
70 |
+
else:
|
71 |
+
assert orig_samples.ndim in (2,3), orig_samples.shape
|
72 |
+
audios = self.preprocess_audio(audios)
|
73 |
+
audios = audios.squeeze(0)
|
74 |
+
orig_length = audios.shape[-1]
|
75 |
+
min_samples = int(40 * self.sample_rate)
|
76 |
+
# 40秒对应10个token
|
77 |
+
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
78 |
+
# print("output_len: ", output_len)
|
79 |
+
|
80 |
+
while(audios.shape[-1] < min_samples):
|
81 |
+
audios = torch.cat([audios, audios], -1)
|
82 |
+
int_max_len=audios.shape[-1]//min_samples+1
|
83 |
+
audios = torch.cat([audios, audios], -1)
|
84 |
+
audios=audios[:,:int(int_max_len*(min_samples))]
|
85 |
+
codes_list=[]
|
86 |
+
|
87 |
+
audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
88 |
+
|
89 |
+
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
90 |
+
# import pdb; pdb.set_trace()
|
91 |
+
codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
|
92 |
+
# print("codes",codes[0].shape)
|
93 |
+
|
94 |
+
codes_list.append(torch.cat(codes, 1))
|
95 |
+
# print("codes_list",codes_list[0].shape)
|
96 |
+
|
97 |
+
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
|
98 |
+
codes=codes[:,:,:output_len]
|
99 |
+
|
100 |
+
return codes
|
101 |
+
|
102 |
+
@torch.no_grad()
|
103 |
+
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
104 |
+
def sound2code_ds(self, orig_samples, ds, batch_size=6):
|
105 |
+
if(orig_samples.ndim == 2):
|
106 |
+
audios = orig_samples.unsqueeze(0).to(self.device)
|
107 |
+
elif(orig_samples.ndim == 3):
|
108 |
+
audios = orig_samples.to(self.device)
|
109 |
+
else:
|
110 |
+
assert orig_samples.ndim in (2,3), orig_samples.shape
|
111 |
+
audios = self.preprocess_audio(audios)
|
112 |
+
audios = audios.squeeze(0)
|
113 |
+
orig_length = audios.shape[-1]
|
114 |
+
min_samples = int(40 * self.sample_rate)
|
115 |
+
# 40秒对应10个token
|
116 |
+
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
117 |
+
# print("output_len: ", output_len)
|
118 |
+
|
119 |
+
while(audios.shape[-1] < min_samples):
|
120 |
+
audios = torch.cat([audios, audios], -1)
|
121 |
+
int_max_len=audios.shape[-1]//min_samples+1
|
122 |
+
audios = torch.cat([audios, audios], -1)
|
123 |
+
audios=audios[:,:int(int_max_len*(min_samples))]
|
124 |
+
codes_list=[]
|
125 |
+
|
126 |
+
audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
127 |
+
|
128 |
+
for audio_inx in range(0, audio_input.shape[0], batch_size):
|
129 |
+
# import pdb; pdb.set_trace()
|
130 |
+
codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
|
131 |
+
# print("codes",codes[0].shape)
|
132 |
+
|
133 |
+
codes_list.append(torch.cat(codes, 1))
|
134 |
+
# print("codes_list",codes_list[0].shape)
|
135 |
+
|
136 |
+
codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
|
137 |
+
codes=codes[:,:,:output_len]
|
138 |
+
|
139 |
+
return codes
|
140 |
+
|
141 |
+
@torch.no_grad()
|
142 |
+
def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
|
143 |
+
codes = codes.to(self.device)
|
144 |
+
|
145 |
+
min_samples = duration * 25 # 40ms per frame
|
146 |
+
hop_samples = min_samples // 4 * 3
|
147 |
+
ovlp_samples = min_samples - hop_samples
|
148 |
+
hop_frames = hop_samples
|
149 |
+
ovlp_frames = ovlp_samples
|
150 |
+
first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
|
151 |
+
first_latent_length = 0
|
152 |
+
first_latent_codes_length = 0
|
153 |
+
|
154 |
+
if(isinstance(prompt, torch.Tensor)):
|
155 |
+
# prepare prompt
|
156 |
+
prompt = prompt.to(self.device)
|
157 |
+
if(prompt.ndim == 3):
|
158 |
+
assert prompt.shape[0] == 1, prompt.shape
|
159 |
+
prompt = prompt[0]
|
160 |
+
elif(prompt.ndim == 1):
|
161 |
+
prompt = prompt.unsqueeze(0).repeat(2,1)
|
162 |
+
elif(prompt.ndim == 2):
|
163 |
+
if(prompt.shape[0] == 1):
|
164 |
+
prompt = prompt.repeat(2,1)
|
165 |
+
|
166 |
+
if(prompt.shape[-1] < int(30 * self.sample_rate)):
|
167 |
+
# if less than 30s, just choose the first 10s
|
168 |
+
prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
|
169 |
+
else:
|
170 |
+
# else choose from 20.48s which might includes verse or chorus
|
171 |
+
prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
172 |
+
|
173 |
+
true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
|
174 |
+
# print("true_latent.shape", true_latent.shape)
|
175 |
+
# print("first_latent.shape", first_latent.shape)
|
176 |
+
#true_latent.shape torch.Size([1, 250, 64])
|
177 |
+
# first_latent.shape torch.Size([1, 1000, 64])
|
178 |
+
|
179 |
+
first_latent[:,0:true_latent.shape[1],:] = true_latent
|
180 |
+
first_latent_length = true_latent.shape[1]
|
181 |
+
first_latent_codes = self.sound2code(prompt)
|
182 |
+
first_latent_codes_length = first_latent_codes.shape[-1]
|
183 |
+
codes = torch.cat([first_latent_codes, codes], -1)
|
184 |
+
|
185 |
+
codes_len= codes.shape[-1]
|
186 |
+
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
187 |
+
# target_len = int(codes_len / 100 * 4 * self.sample_rate)
|
188 |
+
# code repeat
|
189 |
+
if(codes_len < min_samples):
|
190 |
+
while(codes.shape[-1] < min_samples):
|
191 |
+
codes = torch.cat([codes, codes], -1)
|
192 |
+
codes = codes[:,:,0:min_samples]
|
193 |
+
codes_len = codes.shape[-1]
|
194 |
+
if((codes_len - ovlp_samples) % hop_samples > 0):
|
195 |
+
len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
|
196 |
+
while(codes.shape[-1] < len_codes):
|
197 |
+
codes = torch.cat([codes, codes], -1)
|
198 |
+
codes = codes[:,:,0:len_codes]
|
199 |
+
latent_length = min_samples
|
200 |
+
latent_list = []
|
201 |
+
spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
|
202 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
203 |
+
for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
|
204 |
+
codes_input=[]
|
205 |
+
codes_input.append(codes[:,:,sinx:sinx+min_samples])
|
206 |
+
if(sinx == 0):
|
207 |
+
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
208 |
+
incontext_length = first_latent_length
|
209 |
+
latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
210 |
+
latent_list.append(latents)
|
211 |
+
else:
|
212 |
+
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
213 |
+
true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
|
214 |
+
print("true_latent.shape", true_latent.shape)
|
215 |
+
len_add_to_1000 = 1000 - true_latent.shape[-2]
|
216 |
+
# print("len_add_to_1000", len_add_to_1000)
|
217 |
+
# exit()
|
218 |
+
incontext_length = true_latent.shape[-2]
|
219 |
+
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
|
220 |
+
latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
221 |
+
latent_list.append(latents)
|
222 |
+
|
223 |
+
latent_list = [l.float() for l in latent_list]
|
224 |
+
latent_list[0] = latent_list[0][:,:,first_latent_length:]
|
225 |
+
min_samples = int(min_samples * self.sample_rate // 1000 * 40)
|
226 |
+
hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
|
227 |
+
ovlp_samples = min_samples - hop_samples
|
228 |
+
with torch.no_grad():
|
229 |
+
output = None
|
230 |
+
for i in range(len(latent_list)):
|
231 |
+
latent = latent_list[i]
|
232 |
+
cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
|
233 |
+
|
234 |
+
if output is None:
|
235 |
+
output = cur_output
|
236 |
+
else:
|
237 |
+
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
238 |
+
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
239 |
+
print("output.shape", output.shape)
|
240 |
+
print("ov_win.shape", ov_win.shape)
|
241 |
+
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
242 |
+
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
243 |
+
output = output[:, 0:target_len]
|
244 |
+
return output
|
245 |
+
|
246 |
+
@torch.no_grad()
|
247 |
+
def preprocess_audio(self, input_audios, threshold=0.8):
|
248 |
+
assert len(input_audios.shape) == 3, input_audios.shape
|
249 |
+
nchan = input_audios.shape[1]
|
250 |
+
input_audios = input_audios.reshape(input_audios.shape[0], -1)
|
251 |
+
norm_value = torch.ones_like(input_audios[:,0])
|
252 |
+
max_volume = input_audios.abs().max(dim=-1)[0]
|
253 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
254 |
+
return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
|
255 |
+
|
256 |
+
@torch.no_grad()
|
257 |
+
def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
|
258 |
+
codes = self.sound2code(sound)
|
259 |
+
# print(codes.shape)
|
260 |
+
# exit()
|
261 |
+
wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
262 |
+
# print(fname, wave.shape)
|
263 |
+
return wave
|
264 |
+
|
265 |
+
def file2code(self, fname):
|
266 |
+
try:
|
267 |
+
orig_samples, fs = torchaudio.load(fname)
|
268 |
+
except:
|
269 |
+
af = AudioFile(fname)
|
270 |
+
orig_samples = af.read()
|
271 |
+
fs = af.samplerate()
|
272 |
+
orig_samples = orig_samples[0]
|
273 |
+
if(fs!=self.sample_rate):
|
274 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
|
275 |
+
fs = self.sample_rate
|
276 |
+
if orig_samples.shape[0] == 1:
|
277 |
+
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
278 |
+
return self.sound2code(orig_samples)
|
279 |
+
|
280 |
+
def file2code_ds(self, fname, ds):
|
281 |
+
try:
|
282 |
+
orig_samples, fs = torchaudio.load(fname)
|
283 |
+
except:
|
284 |
+
af = AudioFile(fname)
|
285 |
+
orig_samples = af.read()
|
286 |
+
fs = af.samplerate()
|
287 |
+
orig_samples = orig_samples[0]
|
288 |
+
if(fs!=self.sample_rate):
|
289 |
+
orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
|
290 |
+
fs = self.sample_rate
|
291 |
+
if orig_samples.shape[0] == 1:
|
292 |
+
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
293 |
+
return self.sound2code_ds(orig_samples, ds)
|
codeclm/tokenizer/Flow1dVAE/generate_septoken.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
from model_septoken import PromptCondAudioDiffusion
|
5 |
+
from diffusers import DDIMScheduler, DDPMScheduler
|
6 |
+
import torchaudio
|
7 |
+
import librosa
|
8 |
+
import os
|
9 |
+
import math
|
10 |
+
import numpy as np
|
11 |
+
# from tools.get_mulan import get_mulan
|
12 |
+
from tools.get_1dvae_large import get_model
|
13 |
+
import tools.torch_tools as torch_tools
|
14 |
+
from safetensors.torch import load_file
|
15 |
+
from third_party.demucs.models.pretrained import get_model_from_yaml
|
16 |
+
from filelock import FileLock
|
17 |
+
import kaldiio
|
18 |
+
# os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
|
19 |
+
class Separator:
|
20 |
+
def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
|
21 |
+
if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
|
22 |
+
self.device = torch.device(f"cuda:{gpu_id}")
|
23 |
+
else:
|
24 |
+
self.device = torch.device("cpu")
|
25 |
+
self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
|
26 |
+
|
27 |
+
def init_demucs_model(self, model_path, config_path):
|
28 |
+
model = get_model_from_yaml(config_path, model_path)
|
29 |
+
model.to(self.device)
|
30 |
+
model.eval()
|
31 |
+
return model
|
32 |
+
|
33 |
+
def load_audio(self, f):
|
34 |
+
a, fs = torchaudio.load(f)
|
35 |
+
if (fs != 48000):
|
36 |
+
a = torchaudio.functional.resample(a, fs, 48000)
|
37 |
+
# if a.shape[-1] >= 48000*10:
|
38 |
+
# a = a[..., :48000*10]
|
39 |
+
# else:
|
40 |
+
# a = torch.cat([a, a], -1)
|
41 |
+
# return a[:, 0:48000*10]
|
42 |
+
return a
|
43 |
+
|
44 |
+
def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"):
|
45 |
+
name, _ = os.path.splitext(os.path.split(audio_path)[-1])
|
46 |
+
output_paths = []
|
47 |
+
# lock_path = os.path.join(output_dir, f"{name}.lock")
|
48 |
+
# with FileLock(lock_path): # 加一个避免多卡访问时死锁
|
49 |
+
for stem in self.demucs_model.sources:
|
50 |
+
output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
|
51 |
+
if os.path.exists(output_path):
|
52 |
+
output_paths.append(output_path)
|
53 |
+
if len(output_paths) == 1: # 4
|
54 |
+
# drums_path, bass_path, other_path, vocal_path = output_paths
|
55 |
+
vocal_path = output_paths[0]
|
56 |
+
else:
|
57 |
+
lock_path = os.path.join(output_dir, f"{name}_separate.lock")
|
58 |
+
with FileLock(lock_path):
|
59 |
+
drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
|
60 |
+
full_audio = self.load_audio(audio_path)
|
61 |
+
vocal_audio = self.load_audio(vocal_path)
|
62 |
+
minlen = min(full_audio.shape[-1], vocal_audio.shape[-1])
|
63 |
+
# bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen]
|
64 |
+
bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path)
|
65 |
+
for path in [drums_path, bass_path, other_path, vocal_path]:
|
66 |
+
os.remove(path)
|
67 |
+
return full_audio, vocal_audio, bgm_audio
|
68 |
+
|
69 |
+
class Tango:
|
70 |
+
def __init__(self, \
|
71 |
+
model_path, \
|
72 |
+
vae_config,
|
73 |
+
vae_model,
|
74 |
+
layer_vocal=7,\
|
75 |
+
layer_bgm=3,\
|
76 |
+
device="cuda:0"):
|
77 |
+
|
78 |
+
self.sample_rate = 48000
|
79 |
+
scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
|
80 |
+
self.device = device
|
81 |
+
|
82 |
+
self.vae = get_model(vae_config, vae_model)
|
83 |
+
self.vae = self.vae.to(device)
|
84 |
+
self.vae=self.vae.eval()
|
85 |
+
self.layer_vocal=layer_vocal
|
86 |
+
self.layer_bgm=layer_bgm
|
87 |
+
|
88 |
+
self.MAX_DURATION = 360
|
89 |
+
main_config = {
|
90 |
+
"num_channels":32,
|
91 |
+
"unet_model_name":None,
|
92 |
+
"unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
|
93 |
+
"snr_gamma":None,
|
94 |
+
}
|
95 |
+
self.model = PromptCondAudioDiffusion(**main_config).to(device)
|
96 |
+
if model_path.endswith(".safetensors"):
|
97 |
+
main_weights = load_file(model_path)
|
98 |
+
else:
|
99 |
+
main_weights = torch.load(model_path, map_location=device)
|
100 |
+
self.model.load_state_dict(main_weights, strict=False)
|
101 |
+
print ("Successfully loaded checkpoint from:", model_path)
|
102 |
+
|
103 |
+
self.model.eval()
|
104 |
+
self.model.init_device_dtype(torch.device(device), torch.float32)
|
105 |
+
print("scaling factor: ", self.model.normfeat.std)
|
106 |
+
|
107 |
+
# self.scheduler = DDIMScheduler.from_pretrained( \
|
108 |
+
# scheduler_name, subfolder="scheduler")
|
109 |
+
# self.scheduler = DDPMScheduler.from_pretrained( \
|
110 |
+
# scheduler_name, subfolder="scheduler")
|
111 |
+
print("Successfully loaded inference scheduler from {}".format(scheduler_name))
|
112 |
+
|
113 |
+
|
114 |
+
@torch.no_grad()
|
115 |
+
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
116 |
+
def sound2code(self, orig_vocal, orig_bgm, batch_size=8):
|
117 |
+
if(orig_vocal.ndim == 2):
|
118 |
+
audios_vocal = orig_vocal.unsqueeze(0).to(self.device)
|
119 |
+
elif(orig_vocal.ndim == 3):
|
120 |
+
audios_vocal = orig_vocal.to(self.device)
|
121 |
+
else:
|
122 |
+
assert orig_vocal.ndim in (2,3), orig_vocal.shape
|
123 |
+
|
124 |
+
if(orig_bgm.ndim == 2):
|
125 |
+
audios_bgm = orig_bgm.unsqueeze(0).to(self.device)
|
126 |
+
elif(orig_bgm.ndim == 3):
|
127 |
+
audios_bgm = orig_bgm.to(self.device)
|
128 |
+
else:
|
129 |
+
assert orig_bgm.ndim in (2,3), orig_bgm.shape
|
130 |
+
|
131 |
+
|
132 |
+
audios_vocal = self.preprocess_audio(audios_vocal)
|
133 |
+
audios_vocal = audios_vocal.squeeze(0)
|
134 |
+
audios_bgm = self.preprocess_audio(audios_bgm)
|
135 |
+
audios_bgm = audios_bgm.squeeze(0)
|
136 |
+
if audios_vocal.shape[-1] > audios_bgm.shape[-1]:
|
137 |
+
audios_vocal = audios_vocal[:,:audios_bgm.shape[-1]]
|
138 |
+
else:
|
139 |
+
audios_bgm = audios_bgm[:,:audios_vocal.shape[-1]]
|
140 |
+
|
141 |
+
|
142 |
+
orig_length = audios_vocal.shape[-1]
|
143 |
+
min_samples = int(40 * self.sample_rate)
|
144 |
+
# 40秒对应10个token
|
145 |
+
output_len = int(orig_length / float(self.sample_rate) * 25) + 1
|
146 |
+
|
147 |
+
while(audios_vocal.shape[-1] < min_samples):
|
148 |
+
audios_vocal = torch.cat([audios_vocal, audios_vocal], -1)
|
149 |
+
audios_bgm = torch.cat([audios_bgm, audios_bgm], -1)
|
150 |
+
int_max_len=audios_vocal.shape[-1]//min_samples+1
|
151 |
+
audios_vocal = torch.cat([audios_vocal, audios_vocal], -1)
|
152 |
+
audios_bgm = torch.cat([audios_bgm, audios_bgm], -1)
|
153 |
+
audios_vocal=audios_vocal[:,:int(int_max_len*(min_samples))]
|
154 |
+
audios_bgm=audios_bgm[:,:int(int_max_len*(min_samples))]
|
155 |
+
codes_vocal_list=[]
|
156 |
+
codes_bgm_list=[]
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
audio_vocal_input = audios_vocal.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
161 |
+
audio_bgm_input = audios_bgm.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
|
162 |
+
|
163 |
+
for audio_inx in range(0, audio_vocal_input.shape[0], batch_size):
|
164 |
+
[codes_vocal,codes_bgm], _, spk_embeds = self.model.fetch_codes_batch((audio_vocal_input[audio_inx:audio_inx+batch_size]), (audio_bgm_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer_vocal=self.layer_vocal,layer_bgm=self.layer_bgm)
|
165 |
+
codes_vocal_list.append(codes_vocal)
|
166 |
+
codes_bgm_list.append(codes_bgm)
|
167 |
+
|
168 |
+
codes_vocal = torch.cat(codes_vocal_list, 0).permute(1,0,2).reshape(1, -1)[None]
|
169 |
+
codes_bgm = torch.cat(codes_bgm_list, 0).permute(1,0,2).reshape(1, -1)[None]
|
170 |
+
codes_vocal=codes_vocal[:,:,:output_len]
|
171 |
+
codes_bgm=codes_bgm[:,:,:output_len]
|
172 |
+
|
173 |
+
return codes_vocal, codes_bgm
|
174 |
+
|
175 |
+
@torch.no_grad()
|
176 |
+
def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
|
177 |
+
codes_vocal,codes_bgm = codes
|
178 |
+
codes_vocal = codes_vocal.to(self.device)
|
179 |
+
codes_bgm = codes_bgm.to(self.device)
|
180 |
+
|
181 |
+
min_samples = duration * 25 # 40ms per frame
|
182 |
+
hop_samples = min_samples // 4 * 3
|
183 |
+
ovlp_samples = min_samples - hop_samples
|
184 |
+
hop_frames = hop_samples
|
185 |
+
ovlp_frames = ovlp_samples
|
186 |
+
first_latent = torch.randn(codes_vocal.shape[0], min_samples, 64).to(self.device)
|
187 |
+
first_latent_length = 0
|
188 |
+
first_latent_codes_length = 0
|
189 |
+
|
190 |
+
|
191 |
+
if(isinstance(prompt_vocal, torch.Tensor)):
|
192 |
+
# prepare prompt
|
193 |
+
prompt_vocal = prompt_vocal.to(self.device)
|
194 |
+
prompt_bgm = prompt_bgm.to(self.device)
|
195 |
+
if(prompt_vocal.ndim == 3):
|
196 |
+
assert prompt_vocal.shape[0] == 1, prompt_vocal.shape
|
197 |
+
prompt_vocal = prompt_vocal[0]
|
198 |
+
prompt_bgm = prompt_bgm[0]
|
199 |
+
elif(prompt_vocal.ndim == 1):
|
200 |
+
prompt_vocal = prompt_vocal.unsqueeze(0).repeat(2,1)
|
201 |
+
prompt_bgm = prompt_bgm.unsqueeze(0).repeat(2,1)
|
202 |
+
elif(prompt_vocal.ndim == 2):
|
203 |
+
if(prompt_vocal.shape[0] == 1):
|
204 |
+
prompt_vocal = prompt_vocal.repeat(2,1)
|
205 |
+
prompt_bgm = prompt_bgm.repeat(2,1)
|
206 |
+
|
207 |
+
if(prompt_vocal.shape[-1] < int(30 * self.sample_rate)):
|
208 |
+
# if less than 30s, just choose the first 10s
|
209 |
+
prompt_vocal = prompt_vocal[:,:int(10*self.sample_rate)] # limit max length to 10.24
|
210 |
+
prompt_bgm = prompt_bgm[:,:int(10*self.sample_rate)] # limit max length to 10.24
|
211 |
+
else:
|
212 |
+
# else choose from 20.48s which might includes verse or chorus
|
213 |
+
prompt_vocal = prompt_vocal[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
214 |
+
prompt_bgm = prompt_bgm[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
215 |
+
|
216 |
+
true_latent = self.vae.encode_audio(prompt_vocal+prompt_bgm).permute(0,2,1)
|
217 |
+
|
218 |
+
first_latent[:,0:true_latent.shape[1],:] = true_latent
|
219 |
+
first_latent_length = true_latent.shape[1]
|
220 |
+
first_latent_codes = self.sound2code(prompt_vocal, prompt_bgm)
|
221 |
+
first_latent_codes_vocal = first_latent_codes[0]
|
222 |
+
first_latent_codes_bgm = first_latent_codes[1]
|
223 |
+
first_latent_codes_length = first_latent_codes_vocal.shape[-1]
|
224 |
+
codes_vocal = torch.cat([first_latent_codes_vocal, codes_vocal], -1)
|
225 |
+
codes_bgm = torch.cat([first_latent_codes_bgm, codes_bgm], -1)
|
226 |
+
|
227 |
+
|
228 |
+
codes_len= codes_vocal.shape[-1]
|
229 |
+
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
230 |
+
# target_len = int(codes_len / 100 * 4 * self.sample_rate)
|
231 |
+
# code repeat
|
232 |
+
if(codes_len < min_samples):
|
233 |
+
while(codes_vocal.shape[-1] < min_samples):
|
234 |
+
codes_vocal = torch.cat([codes_vocal, codes_vocal], -1)
|
235 |
+
codes_bgm = torch.cat([codes_bgm, codes_bgm], -1)
|
236 |
+
|
237 |
+
codes_vocal = codes_vocal[:,:,0:min_samples]
|
238 |
+
codes_bgm = codes_bgm[:,:,0:min_samples]
|
239 |
+
codes_len = codes_vocal.shape[-1]
|
240 |
+
if((codes_len - ovlp_samples) % hop_samples > 0):
|
241 |
+
len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
|
242 |
+
while(codes_vocal.shape[-1] < len_codes):
|
243 |
+
codes_vocal = torch.cat([codes_vocal, codes_vocal], -1)
|
244 |
+
codes_bgm = torch.cat([codes_bgm, codes_bgm], -1)
|
245 |
+
codes_vocal = codes_vocal[:,:,0:len_codes]
|
246 |
+
codes_bgm = codes_bgm[:,:,0:len_codes]
|
247 |
+
latent_length = min_samples
|
248 |
+
latent_list = []
|
249 |
+
spk_embeds = torch.zeros([1, 32, 1, 32], device=codes_vocal.device)
|
250 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
251 |
+
for sinx in range(0, codes_vocal.shape[-1]-hop_samples, hop_samples):
|
252 |
+
codes_vocal_input=codes_vocal[:,:,sinx:sinx+min_samples]
|
253 |
+
codes_bgm_input=codes_bgm[:,:,sinx:sinx+min_samples]
|
254 |
+
if(sinx == 0):
|
255 |
+
incontext_length = first_latent_length
|
256 |
+
latents = self.model.inference_codes([codes_vocal_input,codes_bgm_input], spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
257 |
+
latent_list.append(latents)
|
258 |
+
else:
|
259 |
+
true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
|
260 |
+
len_add_to_1000 = min_samples - true_latent.shape[-2]
|
261 |
+
incontext_length = true_latent.shape[-2]
|
262 |
+
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
|
263 |
+
latents = self.model.inference_codes([codes_vocal_input,codes_bgm_input], spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
264 |
+
latent_list.append(latents)
|
265 |
+
|
266 |
+
latent_list = [l.float() for l in latent_list]
|
267 |
+
latent_list[0] = latent_list[0][:,:,first_latent_length:]
|
268 |
+
min_samples = int(min_samples * self.sample_rate // 1000 * 40)
|
269 |
+
hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
|
270 |
+
ovlp_samples = min_samples - hop_samples
|
271 |
+
with torch.no_grad():
|
272 |
+
output = None
|
273 |
+
for i in range(len(latent_list)):
|
274 |
+
latent = latent_list[i]
|
275 |
+
cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
|
276 |
+
|
277 |
+
if output is None:
|
278 |
+
output = cur_output
|
279 |
+
else:
|
280 |
+
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
281 |
+
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
282 |
+
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
283 |
+
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
284 |
+
output = output[:, 0:target_len]
|
285 |
+
return output
|
286 |
+
|
287 |
+
@torch.no_grad()
|
288 |
+
def preprocess_audio(self, input_audios_vocal, threshold=0.8):
|
289 |
+
assert len(input_audios_vocal.shape) == 3, input_audios_vocal.shape
|
290 |
+
nchan = input_audios_vocal.shape[1]
|
291 |
+
input_audios_vocal = input_audios_vocal.reshape(input_audios_vocal.shape[0], -1)
|
292 |
+
norm_value = torch.ones_like(input_audios_vocal[:,0])
|
293 |
+
max_volume = input_audios_vocal.abs().max(dim=-1)[0]
|
294 |
+
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
295 |
+
return input_audios_vocal.reshape(input_audios_vocal.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
|
296 |
+
|
297 |
+
@torch.no_grad()
|
298 |
+
def sound2sound(self, orig_vocal,orig_bgm, prompt_vocal=None,prompt_bgm=None, steps=50, disable_progress=False):
|
299 |
+
codes_vocal, codes_bgm = self.sound2code(orig_vocal,orig_bgm)
|
300 |
+
codes=[codes_vocal, codes_bgm]
|
301 |
+
wave = self.code2sound(codes, prompt_vocal,prompt_bgm, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
302 |
+
return wave
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py
ADDED
@@ -0,0 +1,1278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List, Union
|
3 |
+
from beartype import beartype
|
4 |
+
from beartype.door import is_bearable
|
5 |
+
import random
|
6 |
+
import pandas as pd
|
7 |
+
import os
|
8 |
+
from torchaudio.functional import resample
|
9 |
+
import torch
|
10 |
+
import typing as tp
|
11 |
+
from pathlib import Path
|
12 |
+
import torchaudio as ta
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import numpy as np
|
15 |
+
import json
|
16 |
+
import yaml
|
17 |
+
import torchaudio
|
18 |
+
import math
|
19 |
+
import re
|
20 |
+
from loguru import logger
|
21 |
+
import ffmpeg
|
22 |
+
|
23 |
+
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
24 |
+
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
25 |
+
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.n_samples = n_samples
|
29 |
+
self.sample_rate = sample_rate
|
30 |
+
self.randomize = randomize
|
31 |
+
|
32 |
+
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
33 |
+
if self.n_samples < 0: #means not clip
|
34 |
+
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
35 |
+
t_start = 0.
|
36 |
+
t_end = 1.0
|
37 |
+
offset = 0
|
38 |
+
else:
|
39 |
+
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
40 |
+
# print(duration,(float(self.n_samples)/self.sample_rate+1))
|
41 |
+
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
42 |
+
t_start = 0.
|
43 |
+
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
44 |
+
offset = 0
|
45 |
+
# print('c1:',chunk.shape)
|
46 |
+
else:
|
47 |
+
offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
48 |
+
t_start = offset / float(cur_sample_rate) / duration
|
49 |
+
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
50 |
+
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
51 |
+
# print('offset:',offset)
|
52 |
+
# print('c0:',chunk.shape)
|
53 |
+
# Pad with silence if necessary.
|
54 |
+
if(chunk.shape[0]>1):
|
55 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
56 |
+
else:
|
57 |
+
chunk = chunk[[0],:].float()
|
58 |
+
if(cur_sample_rate!=self.sample_rate):
|
59 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
60 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
61 |
+
# print('b:',self.sample_rate,chunk.shape)
|
62 |
+
|
63 |
+
if self.n_samples > 0:
|
64 |
+
if chunk.shape[-1] < self.n_samples:
|
65 |
+
chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
66 |
+
else:
|
67 |
+
chunk = chunk[:,0:self.n_samples]
|
68 |
+
seconds_start = math.floor(offset / cur_sample_rate)
|
69 |
+
seconds_total = math.floor(duration)
|
70 |
+
|
71 |
+
return (
|
72 |
+
chunk,
|
73 |
+
t_start,
|
74 |
+
t_end,
|
75 |
+
seconds_start,
|
76 |
+
seconds_total
|
77 |
+
)
|
78 |
+
|
79 |
+
class Read_and_PadCrop_Normalized_T_Avoid_Watermark(torch.nn.Module):
|
80 |
+
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, w_start = 0, w_interval = 11.3):
|
81 |
+
|
82 |
+
super().__init__()
|
83 |
+
|
84 |
+
self.n_samples = n_samples
|
85 |
+
self.sample_rate = sample_rate
|
86 |
+
self.randomize = randomize
|
87 |
+
|
88 |
+
self.w_start = w_start
|
89 |
+
self.w_interval = w_interval
|
90 |
+
|
91 |
+
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
92 |
+
if self.n_samples < 0: #means not clip
|
93 |
+
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
94 |
+
t_start = 0.
|
95 |
+
t_end = 1.0
|
96 |
+
offset = 0
|
97 |
+
else:
|
98 |
+
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
99 |
+
# print(duration,(float(self.n_samples)/self.sample_rate+1))
|
100 |
+
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
101 |
+
t_start = 0.
|
102 |
+
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
103 |
+
offset = 0
|
104 |
+
# print('c1:',chunk.shape)
|
105 |
+
else:
|
106 |
+
n_offset_option = (duration - self.w_start) // self.w_interval
|
107 |
+
if n_offset_option <= 1:
|
108 |
+
offset = 0
|
109 |
+
else:
|
110 |
+
offset = int((random.randint(0,n_offset_option-1) * self.w_interval + self.w_start) * cur_sample_rate)
|
111 |
+
# offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
112 |
+
t_start = offset / float(cur_sample_rate) / duration
|
113 |
+
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
114 |
+
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
115 |
+
# print('offset:',offset)
|
116 |
+
# print('c0:',chunk.shape)
|
117 |
+
# Pad with silence if necessary.
|
118 |
+
if(chunk.shape[0]>1):
|
119 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
120 |
+
else:
|
121 |
+
chunk = chunk[[0],:].float()
|
122 |
+
if(cur_sample_rate!=self.sample_rate):
|
123 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
124 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
125 |
+
# print('b:',self.sample_rate,chunk.shape)
|
126 |
+
|
127 |
+
if self.n_samples > 0:
|
128 |
+
if chunk.shape[-1] < self.n_samples:
|
129 |
+
chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
130 |
+
else:
|
131 |
+
chunk = chunk[:,0:self.n_samples]
|
132 |
+
seconds_start = math.floor(offset / cur_sample_rate)
|
133 |
+
seconds_total = math.floor(duration)
|
134 |
+
|
135 |
+
return (
|
136 |
+
chunk,
|
137 |
+
t_start,
|
138 |
+
t_end,
|
139 |
+
seconds_start,
|
140 |
+
seconds_total
|
141 |
+
)
|
142 |
+
|
143 |
+
USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替
|
144 |
+
if USE_DUMMY_AUDIO:
|
145 |
+
logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
|
146 |
+
|
147 |
+
class SafeAudioReader:
|
148 |
+
"""
|
149 |
+
This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
|
150 |
+
"""
|
151 |
+
def __init__(self,
|
152 |
+
duration: float, # 返回音频长度
|
153 |
+
sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample
|
154 |
+
randomize: bool = True,
|
155 |
+
use_avoid_watermark_policy = False,
|
156 |
+
):
|
157 |
+
self.n_samples = int(sample_rate * duration)
|
158 |
+
self.reader = (
|
159 |
+
Read_and_PadCrop_Normalized_T_Avoid_Watermark if use_avoid_watermark_policy \
|
160 |
+
else Read_and_PadCrop_Normalized_T
|
161 |
+
)(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
|
162 |
+
|
163 |
+
#NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数!
|
164 |
+
def __call__(self,
|
165 |
+
filepath: os.PathLike, # 音频路径
|
166 |
+
origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取
|
167 |
+
origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取
|
168 |
+
) -> torch.Tensor:
|
169 |
+
if USE_DUMMY_AUDIO:
|
170 |
+
wav = torch.zeros(self.n_samples, dtype=torch.float32)
|
171 |
+
return wav
|
172 |
+
try:
|
173 |
+
if origin_sample_rate is None or origin_duration is None:
|
174 |
+
# audio_info = torchaudio.info(filepath)
|
175 |
+
# origin_sample_rate = audio_info.sample_rate
|
176 |
+
# origin_duration = audio_info.num_frames / origin_sample_rate
|
177 |
+
info = ffmpeg.probe(filepath)
|
178 |
+
origin_duration = float(info['format']['duration'])
|
179 |
+
origin_sample_rate = int(info['streams'][0]['sample_rate'])
|
180 |
+
wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
|
181 |
+
wav = wav.squeeze_(0)
|
182 |
+
except Exception as e:
|
183 |
+
logger.error(f"Error reading {filepath}: {e}")
|
184 |
+
wav = torch.zeros(self.n_samples, dtype=torch.float32)
|
185 |
+
return wav
|
186 |
+
|
187 |
+
|
188 |
+
class PromptTemplate:
|
189 |
+
def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
|
190 |
+
self.template_text = template_text
|
191 |
+
self.tag_map = tag_map
|
192 |
+
self.lang = lang
|
193 |
+
|
194 |
+
@property
|
195 |
+
def tags(self):
|
196 |
+
return tuple(self.tag_map.keys())
|
197 |
+
|
198 |
+
def apply(self, **kwargs):
|
199 |
+
for tag in list(kwargs.keys()):
|
200 |
+
if kwargs[tag] == '':
|
201 |
+
kwargs.pop(tag)
|
202 |
+
for tag in self.tags:
|
203 |
+
if tag in kwargs:
|
204 |
+
kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
|
205 |
+
else:
|
206 |
+
kwargs[tag] = ''
|
207 |
+
prompt = self.template_text.format(**kwargs)
|
208 |
+
|
209 |
+
return self.beautify(prompt)
|
210 |
+
|
211 |
+
def beautify(self, text):
|
212 |
+
if self.lang == 'en':
|
213 |
+
return self._beautify_en(text)
|
214 |
+
elif self.lang == 'zh':
|
215 |
+
return self._beautify_zh(text)
|
216 |
+
else:
|
217 |
+
raise ValueError(f'Unknown language {self.lang}')
|
218 |
+
|
219 |
+
@staticmethod
|
220 |
+
def _beautify_en(text):
|
221 |
+
# no continuous commas without content between them
|
222 |
+
text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
|
223 |
+
# no continuous whitespace
|
224 |
+
text = re.sub(r'\s+', ' ', text)
|
225 |
+
# the comma is NOT followed by whitespace, and should be followed by ONE whitespace
|
226 |
+
text = re.sub(r'\s+,', r',', text)
|
227 |
+
text = re.sub(r',\s+', r', ', text)
|
228 |
+
# no whitespace before the full stop
|
229 |
+
text = re.sub(r'\s+\.', r'.', text)
|
230 |
+
# strip whitespace, comma, and replace ',.'
|
231 |
+
text = text.strip(' ,')
|
232 |
+
text = text.replace(',.', '.')
|
233 |
+
return text
|
234 |
+
|
235 |
+
@staticmethod
|
236 |
+
def _beautify_zh(text):
|
237 |
+
# no continuous commas without content between them
|
238 |
+
text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
|
239 |
+
text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
|
240 |
+
# assume there should be NO whitespace in Chinese
|
241 |
+
text = re.sub(r'\s+', r'', text)
|
242 |
+
# strip whitespace, comma, and replace ',。'
|
243 |
+
text = text.strip(', 、')
|
244 |
+
text = text.replace(',。', '。')
|
245 |
+
return text
|
246 |
+
|
247 |
+
def __repr__(self):
|
248 |
+
return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
|
249 |
+
|
250 |
+
__str__ = __repr__
|
251 |
+
|
252 |
+
def parse_prompt_template(prompt_template_text, lang='en'):
|
253 |
+
span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
|
254 |
+
tag_pattern = re.compile(r'{.+?}', re.DOTALL)
|
255 |
+
|
256 |
+
template_text = prompt_template_text.strip()
|
257 |
+
span_texts = span_pattern.findall(prompt_template_text)
|
258 |
+
tag_map = {}
|
259 |
+
for span_text in span_texts:
|
260 |
+
tag = tag_pattern.findall(span_text)[0].strip('{}')
|
261 |
+
tag_map[tag] = span_text
|
262 |
+
template_text = template_text.replace(span_text, '{'+tag+'}')
|
263 |
+
|
264 |
+
return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
|
265 |
+
|
266 |
+
def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
|
267 |
+
with open(path, 'r') as f:
|
268 |
+
lines = f.readlines()
|
269 |
+
cnt = 0
|
270 |
+
pts = []
|
271 |
+
for line in lines:
|
272 |
+
pt = parse_prompt_template(line, lang=lang)
|
273 |
+
cnt += 1
|
274 |
+
if len(pt.tags) < num:
|
275 |
+
logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
|
276 |
+
pts.append(pt)
|
277 |
+
|
278 |
+
return pts
|
279 |
+
|
280 |
+
|
281 |
+
def get_base_dir_file(key: os.PathLike):
|
282 |
+
base = os.path.basename(key)
|
283 |
+
dirname = os.path.basename(os.path.dirname(key))
|
284 |
+
return os.path.join(dirname, base)
|
285 |
+
|
286 |
+
def read_jsonlike(path: os.PathLike):
|
287 |
+
#json or jsonl
|
288 |
+
if str(path).endswith(".json"):
|
289 |
+
with open(path, 'r', encoding='utf8') as f:
|
290 |
+
data = json.load(f)
|
291 |
+
return data
|
292 |
+
elif str(path).endswith(".jsonl"):
|
293 |
+
with open(path, 'r', encoding='utf8') as f:
|
294 |
+
data = [json.loads(line) for line in f.readlines()]
|
295 |
+
return data
|
296 |
+
else:
|
297 |
+
raise ValueError("Unknown file format")
|
298 |
+
|
299 |
+
dist_prob_map = {
|
300 |
+
1: (1.0,),
|
301 |
+
2: (0.5, 0.5),
|
302 |
+
3: (0.3, 0.4, 0.3),
|
303 |
+
4: (0.2, 0.3, 0.3, 0.2),
|
304 |
+
5: (0.2, 0.2, 0.3, 0.2, 0.1),
|
305 |
+
6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
|
306 |
+
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
307 |
+
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
308 |
+
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
309 |
+
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
310 |
+
}
|
311 |
+
|
312 |
+
'''
|
313 |
+
#更加偏向短文本的方案
|
314 |
+
dist_prob_map = {
|
315 |
+
1: (1.0,),
|
316 |
+
2: (0.7, 0.3),
|
317 |
+
3: (0.7, 0.2, 0.1),
|
318 |
+
4: (0.6, 0.2, 0.1, 0.1),
|
319 |
+
5: (0.6, 0.2, 0.1, 0.05, 0.05),
|
320 |
+
6: (0.6, 0.15, 0.1, 0.05, 0.05, 0.05),
|
321 |
+
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
322 |
+
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
323 |
+
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
324 |
+
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
325 |
+
}
|
326 |
+
'''
|
327 |
+
|
328 |
+
#全部都用的方案
|
329 |
+
# dist_prob_map = {
|
330 |
+
# 1: (1.0,),
|
331 |
+
# 2: (0, 1.0),
|
332 |
+
# 3: (0, 0, 1.0),
|
333 |
+
# 4: (0, 0, 0, 1.0),
|
334 |
+
# 5: (0, 0, 0, 0, 1.0),
|
335 |
+
# 6: (0, 0, 0, 0, 0, 1.0),
|
336 |
+
# 7: (0, 0, 0, 0, 0, 0, 1.0),
|
337 |
+
# 8: (0, 0, 0, 0, 0, 0, 0, 1.0),
|
338 |
+
# 9: (0, 0, 0, 0, 0, 0, 0, 0, 1.0),
|
339 |
+
# 10: (0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0)
|
340 |
+
# }
|
341 |
+
|
342 |
+
dist_prob_map_low = {
|
343 |
+
1: (1.0,),
|
344 |
+
2: (0.8, 0.2),
|
345 |
+
3: (0.8, 0.1, 0.1),
|
346 |
+
4: (0.7, 0.1, 0.1, 0.1),
|
347 |
+
5: (0.7, 0.1, 0.1, 0.05, 0.05),
|
348 |
+
6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
|
349 |
+
}
|
350 |
+
|
351 |
+
_bpm_range_rights = (
|
352 |
+
(40, '20-40'),
|
353 |
+
(60, '40-60'),
|
354 |
+
(66, '60-66'),
|
355 |
+
(76, '66-76'),
|
356 |
+
(108, '76-108'),
|
357 |
+
(120, '108-120'),
|
358 |
+
(168, '120-168'),
|
359 |
+
(176, '168-176'),
|
360 |
+
(200, '176-200')
|
361 |
+
)
|
362 |
+
_bpm_desc_map = {
|
363 |
+
'20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
|
364 |
+
'40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
|
365 |
+
'60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
|
366 |
+
'66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
|
367 |
+
'76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
|
368 |
+
'108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
|
369 |
+
'120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
|
370 |
+
'168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
|
371 |
+
'176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
|
372 |
+
'>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
|
373 |
+
}
|
374 |
+
_bpm_desc_map_zh = {
|
375 |
+
'20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
|
376 |
+
'40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
|
377 |
+
'60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
|
378 |
+
'66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
|
379 |
+
'76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
|
380 |
+
'108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
|
381 |
+
'120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
|
382 |
+
'168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
|
383 |
+
'176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
|
384 |
+
'>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
|
385 |
+
}
|
386 |
+
def get_bpm_range(bpm):
|
387 |
+
bpm = int(bpm)
|
388 |
+
for right, tag in _bpm_range_rights:
|
389 |
+
if bpm <= right:
|
390 |
+
return tag
|
391 |
+
return '>200'
|
392 |
+
|
393 |
+
def gen_bpm_descript(bpm, lang='en'):
|
394 |
+
bpm_range = get_bpm_range(bpm)
|
395 |
+
if lang == 'en':
|
396 |
+
return random.choice(_bpm_desc_map[bpm_range])
|
397 |
+
elif lang == 'zh':
|
398 |
+
return random.choice(_bpm_desc_map_zh[bpm_range])
|
399 |
+
else:
|
400 |
+
raise ValueError(f"Unknown language {lang}")
|
401 |
+
|
402 |
+
def read_translate(translate: Union[Dict[str, os.PathLike], os.PathLike, None]):
|
403 |
+
if translate is None:
|
404 |
+
return None
|
405 |
+
if isinstance(translate, str):
|
406 |
+
return read_jsonlike(translate)
|
407 |
+
return {k: read_jsonlike(path) for k, path in translate.items()}
|
408 |
+
|
409 |
+
|
410 |
+
def gen_plain_prompt(key_list, sep=', '):
|
411 |
+
if len(key_list) == 0:
|
412 |
+
return 'none'
|
413 |
+
|
414 |
+
key_list = [k.strip() for k in key_list]
|
415 |
+
|
416 |
+
if len(key_list) > 10:
|
417 |
+
random.shuffle(key_list)
|
418 |
+
key_list = key_list[:10]
|
419 |
+
|
420 |
+
probs = dist_prob_map[len(key_list)]
|
421 |
+
|
422 |
+
num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
|
423 |
+
|
424 |
+
random.shuffle(key_list)
|
425 |
+
tags = key_list[:num_tags]
|
426 |
+
tags_str = sep.join(tags)
|
427 |
+
return tags_str
|
428 |
+
|
429 |
+
|
430 |
+
class MagnaTagATuneDataset(Dataset):
|
431 |
+
def __init__(self):
|
432 |
+
pass
|
433 |
+
|
434 |
+
|
435 |
+
def tags_to_desc(tag_list, sep=',') -> str:
|
436 |
+
if not isinstance(tag_list, Sequence):
|
437 |
+
return str(tag_list)
|
438 |
+
if isinstance(tag_list, str):
|
439 |
+
return tag_list
|
440 |
+
if len(tag_list) <= 0:
|
441 |
+
return ''
|
442 |
+
elif len(tag_list) <= 5:
|
443 |
+
probs = dist_prob_map[len(tag_list)]
|
444 |
+
tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
|
445 |
+
random.shuffle(tag_list)
|
446 |
+
tag_list = tag_list[:tags_num]
|
447 |
+
return sep.join(tag_list)
|
448 |
+
else:
|
449 |
+
probs = dist_prob_map[5]
|
450 |
+
tags_num = random.choices(range(1, 6), probs)[0]
|
451 |
+
random.shuffle(tag_list)
|
452 |
+
tag_list = tag_list[:tags_num]
|
453 |
+
return sep.join(tag_list)
|
454 |
+
|
455 |
+
def get_sr_and_duration_info(item):
|
456 |
+
return item.get('sample_rate', None), item.get('duration', None)
|
457 |
+
|
458 |
+
class MtgJamendoDatasetFromJson(Dataset):
|
459 |
+
def __init__(self,
|
460 |
+
data_dir:str,
|
461 |
+
json_path:str,
|
462 |
+
duration:float=10,
|
463 |
+
sr:int = 0,
|
464 |
+
lang = 'en',
|
465 |
+
plain_rate = 0,
|
466 |
+
return_audio = True,
|
467 |
+
return_path = False,
|
468 |
+
prompt_template_path: os.PathLike = None,
|
469 |
+
tag_types = [],
|
470 |
+
translate:Optional[Dict[str, os.PathLike]] = None,
|
471 |
+
use_literal_none = True,
|
472 |
+
):
|
473 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
474 |
+
|
475 |
+
self.data_dir = data_dir
|
476 |
+
self._load_metadata_json(json_path)
|
477 |
+
self.sr = sr
|
478 |
+
self.duration = duration
|
479 |
+
self.plain_rate = plain_rate
|
480 |
+
self.return_audio = return_audio
|
481 |
+
self.return_path = return_path
|
482 |
+
self.use_literal_none = use_literal_none
|
483 |
+
self.lang = lang
|
484 |
+
|
485 |
+
self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
|
486 |
+
if self.use_dynamic_prompt:
|
487 |
+
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
|
488 |
+
self.tag_types = tag_types
|
489 |
+
|
490 |
+
self.translate = read_translate(translate)
|
491 |
+
|
492 |
+
#这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示
|
493 |
+
WEAK_TAG_LIST = ["title", "artist"]
|
494 |
+
|
495 |
+
def _load_metadata_json(self, json_path):
|
496 |
+
with open(json_path) as fp:
|
497 |
+
self.data = json.load(fp)
|
498 |
+
|
499 |
+
def convert_key_to_path(self, key):
|
500 |
+
return os.path.join(self.data_dir, get_base_dir_file(key))
|
501 |
+
|
502 |
+
def __len__(self):
|
503 |
+
return len(self.data)
|
504 |
+
|
505 |
+
def __getitem__(self, idx):
|
506 |
+
item = self.data[idx]
|
507 |
+
path = self.convert_key_to_path(item['key'])
|
508 |
+
description = self.generate_description(item)
|
509 |
+
|
510 |
+
if self.return_audio:
|
511 |
+
sr, duration = get_sr_and_duration_info(item)
|
512 |
+
audio = self.audio_reader(path, sr, duration)
|
513 |
+
else:
|
514 |
+
audio = None
|
515 |
+
|
516 |
+
if self.return_path:
|
517 |
+
return audio, description, path
|
518 |
+
return audio, description
|
519 |
+
|
520 |
+
def tags_to_desc(self, tag_list, tag_type) -> str:
|
521 |
+
if self.lang == 'en':
|
522 |
+
return tags_to_desc(tag_list)
|
523 |
+
elif self.lang == 'zh':
|
524 |
+
translator = self.translate[tag_type]
|
525 |
+
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
526 |
+
return tags_to_desc(translated_tag_list, sep='、')
|
527 |
+
|
528 |
+
def generate_description(self, item):
|
529 |
+
if random.random() > self.plain_rate:
|
530 |
+
# dynamically generate prompt from given prompt template
|
531 |
+
prompt_template = random.choice(self.prompt_templates)
|
532 |
+
description = self.generate_description_dynamic(item, prompt_template)
|
533 |
+
else:
|
534 |
+
# use plain prompt, i.e. tags sequence separated by comma
|
535 |
+
description = self.generate_description_plain(item)
|
536 |
+
return description
|
537 |
+
|
538 |
+
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
539 |
+
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
540 |
+
exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
|
541 |
+
exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
|
542 |
+
|
543 |
+
if len(exists_strong_tag) > 0:
|
544 |
+
probs = dist_prob_map[len(exists_strong_tag)]
|
545 |
+
tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
|
546 |
+
random.shuffle(exists_strong_tag)
|
547 |
+
tags = exists_strong_tag[:tags_num]
|
548 |
+
weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
|
549 |
+
weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
|
550 |
+
random.shuffle(exists_weak_tag)
|
551 |
+
weak_tags = exists_weak_tag[:weak_tags_num]
|
552 |
+
tags += weak_tags
|
553 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
554 |
+
prompt = prompt_template.apply(**tags_args)
|
555 |
+
else:
|
556 |
+
# no strong tags, use all weak tags instead
|
557 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
|
558 |
+
prompt = prompt_template.apply(**tags_args)
|
559 |
+
|
560 |
+
if self.use_literal_none and len(tags_args) == 0:
|
561 |
+
return 'none'
|
562 |
+
|
563 |
+
return prompt
|
564 |
+
|
565 |
+
def generate_description_plain(self, item):
|
566 |
+
keywords = []
|
567 |
+
for tag_t in self.tag_types:
|
568 |
+
this_key = item[tag_t]
|
569 |
+
if this_key is None:
|
570 |
+
continue
|
571 |
+
if isinstance(this_key, str):
|
572 |
+
this_key = [this_key]
|
573 |
+
if self.lang != 'en':
|
574 |
+
this_key = [self.get_translation(tag_t, k) for k in this_key]
|
575 |
+
keywords += this_key
|
576 |
+
return gen_plain_prompt(keywords, sep=self.keysep)
|
577 |
+
|
578 |
+
def get_translation(self, tag_t, k):
|
579 |
+
k = k.strip()
|
580 |
+
if k in self.translate[tag_t]:
|
581 |
+
return self.translate[tag_t][k]
|
582 |
+
else:
|
583 |
+
return k
|
584 |
+
|
585 |
+
@property
|
586 |
+
def keysep(self):
|
587 |
+
if self.lang == 'zh':
|
588 |
+
return ',' if random.random() > 0.5 else '、'
|
589 |
+
elif self.lang == 'en':
|
590 |
+
return ', '
|
591 |
+
|
592 |
+
class AudioStockDataset(Dataset):
|
593 |
+
def __init__(self,
|
594 |
+
metadata_path:str,
|
595 |
+
duration:float=10,
|
596 |
+
sr:int = 0,
|
597 |
+
plain_rate = 0,
|
598 |
+
return_path = False,
|
599 |
+
return_audio = True,
|
600 |
+
prompt_template_path: os.PathLike = None,
|
601 |
+
tag_types = [],
|
602 |
+
lang = 'en',
|
603 |
+
translate:Optional[Dict[str, os.PathLike]] = None,
|
604 |
+
use_literal_none = True,
|
605 |
+
):
|
606 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
607 |
+
|
608 |
+
self._load_metadata(metadata_path)
|
609 |
+
self.sr = sr
|
610 |
+
self.duration = duration
|
611 |
+
self.plain_rate = plain_rate
|
612 |
+
self.return_path = return_path
|
613 |
+
self.return_audio = return_audio
|
614 |
+
self.use_literal_none = use_literal_none
|
615 |
+
|
616 |
+
self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
|
617 |
+
if self.use_dynamic_prompt:
|
618 |
+
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
|
619 |
+
self.tag_types = tag_types
|
620 |
+
|
621 |
+
self.lang = lang
|
622 |
+
self.translate = read_translate(translate)
|
623 |
+
|
624 |
+
def _load_metadata(self, metadata_path):
|
625 |
+
with open(metadata_path) as fp:
|
626 |
+
lines = fp.readlines()
|
627 |
+
self.data = []
|
628 |
+
for line in lines:
|
629 |
+
item = json.loads(line)
|
630 |
+
self.data.append(item)
|
631 |
+
self.is_info_recorded = bool('Tags' in self.data[0])
|
632 |
+
|
633 |
+
def __len__(self):
|
634 |
+
return len(self.data)
|
635 |
+
|
636 |
+
def __getitem__(self, idx):
|
637 |
+
path:str = self.data[idx]["path"]
|
638 |
+
json_path = path[:path.rfind('.')] + ".json"
|
639 |
+
if self.is_info_recorded:
|
640 |
+
item = self.data[idx]
|
641 |
+
else:
|
642 |
+
try:
|
643 |
+
with open(json_path) as fp:
|
644 |
+
item:dict = json.load(fp)
|
645 |
+
except Exception as e:
|
646 |
+
print(f"Error loading json file {json_path} :\n{e}")
|
647 |
+
item = {}
|
648 |
+
description = self.generate_description(item)
|
649 |
+
if self.return_audio:
|
650 |
+
sr, duration = get_sr_and_duration_info(item)
|
651 |
+
audio = self.audio_reader(path, sr, duration)
|
652 |
+
else:
|
653 |
+
audio = None
|
654 |
+
if self.return_path:
|
655 |
+
return audio, description, path
|
656 |
+
return audio, description
|
657 |
+
|
658 |
+
def generate_description(self, item):
|
659 |
+
if random.random() > self.plain_rate:
|
660 |
+
# dynamically generate prompt from given prompt template
|
661 |
+
prompt_template = random.choice(self.prompt_templates)
|
662 |
+
description = self.generate_description_dynamic(item, prompt_template)
|
663 |
+
else:
|
664 |
+
# use plain prompt, i.e. tags sequence separated by comma
|
665 |
+
description = self.generate_description_plain(item)
|
666 |
+
return description
|
667 |
+
|
668 |
+
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
669 |
+
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
670 |
+
|
671 |
+
if len(exists_tag) > 0:
|
672 |
+
probs = dist_prob_map[len(exists_tag)]
|
673 |
+
tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
|
674 |
+
random.shuffle(exists_tag)
|
675 |
+
tags = exists_tag[:tags_num]
|
676 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
677 |
+
tags_args = self.handle_BPM_tag(tags_args)
|
678 |
+
prompt = prompt_template.apply(**tags_args)
|
679 |
+
else:
|
680 |
+
return 'none'
|
681 |
+
|
682 |
+
if self.use_literal_none and len(tags_args) == 0:
|
683 |
+
return 'none'
|
684 |
+
|
685 |
+
return prompt
|
686 |
+
|
687 |
+
def get_translation(self, tag_t, k):
|
688 |
+
k = k.strip()
|
689 |
+
if k in self.translate[tag_t]:
|
690 |
+
return self.translate[tag_t][k]
|
691 |
+
else:
|
692 |
+
return k
|
693 |
+
|
694 |
+
def generate_description_plain(self, item):
|
695 |
+
keywords = []
|
696 |
+
for tag_t in self.tag_types:
|
697 |
+
if tag_t == 'BPMDescript':
|
698 |
+
bpm = item['BPM']
|
699 |
+
if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
|
700 |
+
continue
|
701 |
+
this_key = gen_bpm_descript(bpm.strip(), lang=self.lang)
|
702 |
+
elif tag_t == 'BPM':
|
703 |
+
bpm = item['BPM']
|
704 |
+
if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
|
705 |
+
continue
|
706 |
+
this_key = f"{bpm.strip()} bpm"
|
707 |
+
else:
|
708 |
+
this_key = item[tag_t]
|
709 |
+
if this_key is None:
|
710 |
+
continue
|
711 |
+
if isinstance(this_key, str):
|
712 |
+
this_key = [this_key]
|
713 |
+
if self.lang != 'en':
|
714 |
+
this_key = [self.get_translation(tag_t, k) for k in this_key]
|
715 |
+
if this_key is None:
|
716 |
+
continue
|
717 |
+
if isinstance(this_key, str):
|
718 |
+
this_key = [this_key]
|
719 |
+
keywords += this_key
|
720 |
+
return gen_plain_prompt(keywords, sep=self.keysep)
|
721 |
+
|
722 |
+
@property
|
723 |
+
def keysep(self):
|
724 |
+
if self.lang == 'zh':
|
725 |
+
return ',' if random.random() > 0.5 else '、'
|
726 |
+
elif self.lang == 'en':
|
727 |
+
return ', '
|
728 |
+
|
729 |
+
def tags_to_desc(self, tag_list, tag_type) -> str:
|
730 |
+
if self.lang == 'en':
|
731 |
+
return tags_to_desc(tag_list)
|
732 |
+
elif self.lang == 'zh':
|
733 |
+
if tag_type == 'BPM':
|
734 |
+
return tags_to_desc(tag_list, sep='、')
|
735 |
+
translator = self.translate[tag_type]
|
736 |
+
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
737 |
+
return tags_to_desc(translated_tag_list, sep='、')
|
738 |
+
|
739 |
+
def handle_BPM_tag(self, tags_args):
|
740 |
+
if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
|
741 |
+
bpm = tags_args["BPM"]
|
742 |
+
del tags_args["BPM"]
|
743 |
+
tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
|
744 |
+
for tag_type in tag_types_used:
|
745 |
+
tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
|
746 |
+
return tags_args
|
747 |
+
|
748 |
+
def mp3_path_to_id(mp3_path):
|
749 |
+
return int(
|
750 |
+
mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.')]
|
751 |
+
)
|
752 |
+
|
753 |
+
class TmeDataset(Dataset):
|
754 |
+
def __init__(self,
|
755 |
+
data_index:str,
|
756 |
+
music_info:str = None,
|
757 |
+
duration:float = 10,
|
758 |
+
sr:int = 0,
|
759 |
+
plain_rate = 0,
|
760 |
+
return_path = False,
|
761 |
+
return_audio = True,
|
762 |
+
return_ID = False,
|
763 |
+
prompt_format_path: os.PathLike = None,
|
764 |
+
tag_types = ['*'],
|
765 |
+
lang = 'zh',
|
766 |
+
translate: Optional[os.PathLike] = None,
|
767 |
+
prompt_dir: os.PathLike = None, #使用GPT生成的预有的prompt
|
768 |
+
):
|
769 |
+
if plain_rate > 0:
|
770 |
+
print("Tme Dataset do not support plain rate > 0, use plain_rate = 0 instead.")
|
771 |
+
plain_rate = 0
|
772 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
773 |
+
|
774 |
+
self.sr = sr
|
775 |
+
self.duration = duration
|
776 |
+
self.plain_rate = plain_rate
|
777 |
+
self.return_path = return_path
|
778 |
+
self.return_audio = return_audio
|
779 |
+
self.return_ID = return_ID
|
780 |
+
self.lang = lang
|
781 |
+
|
782 |
+
self.use_ready_prompt = prompt_dir is not None
|
783 |
+
|
784 |
+
data_index = read_jsonlike(data_index)
|
785 |
+
self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
|
786 |
+
self.data_ids = list(self.data_index_dict.keys())
|
787 |
+
|
788 |
+
if not self.use_ready_prompt:
|
789 |
+
#读取音乐的信息文件
|
790 |
+
music_info = read_jsonlike(music_info)
|
791 |
+
if 'music' in music_info:
|
792 |
+
music_info = music_info['music']
|
793 |
+
self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
|
794 |
+
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
|
795 |
+
self.data_ids = list(self.data_index_dict.keys())
|
796 |
+
|
797 |
+
with open(prompt_format_path) as fp:
|
798 |
+
self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
|
799 |
+
|
800 |
+
#加载tag types,并分成一般的tag_types和关键的key_tag_types
|
801 |
+
if '*' in tag_types:
|
802 |
+
self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
|
803 |
+
else:
|
804 |
+
self.tag_types = tag_types
|
805 |
+
|
806 |
+
self.key_tag_types = []
|
807 |
+
if 'tag' in self.tag_types:
|
808 |
+
self.tag_types.remove('tag')
|
809 |
+
self.key_tag_types = list(self.prompt_formats['tag'].keys())
|
810 |
+
|
811 |
+
#加载translate翻译
|
812 |
+
if translate is not None:
|
813 |
+
self.translator = read_jsonlike(translate)
|
814 |
+
else:
|
815 |
+
data_ids_set = set(self.data_ids)
|
816 |
+
self.prompts_dict = {}
|
817 |
+
for fname in os.listdir(prompt_dir):
|
818 |
+
items = read_jsonlike(os.path.join(prompt_dir, fname))
|
819 |
+
for item in items:
|
820 |
+
if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
|
821 |
+
continue
|
822 |
+
if item['ID'] not in self.prompts_dict:
|
823 |
+
self.prompts_dict[item['ID']] = []
|
824 |
+
self.prompts_dict[item['ID']].append(item['Text'])
|
825 |
+
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
|
826 |
+
self.data_ids = list(self.data_index_dict.keys())
|
827 |
+
|
828 |
+
def tags_to_desc(self, tag_list) -> str:
|
829 |
+
if is_bearable(tag_list, int):
|
830 |
+
return str(tag_list)
|
831 |
+
if self.lang == 'zh':
|
832 |
+
return tags_to_desc(tag_list, sep=self.sep)
|
833 |
+
else:
|
834 |
+
translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
|
835 |
+
return tags_to_desc(translated_tag_list, sep=self.sep)
|
836 |
+
|
837 |
+
def gen_desc_of_tag(self, formats, tags):
|
838 |
+
fmt = random.choice(formats)
|
839 |
+
return fmt.format(self.tags_to_desc(tags))
|
840 |
+
|
841 |
+
@staticmethod
|
842 |
+
def check_valid(value):
|
843 |
+
if isinstance(value, int) or isinstance(value, float):
|
844 |
+
return value > 0
|
845 |
+
if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
|
846 |
+
return True
|
847 |
+
return False
|
848 |
+
|
849 |
+
@staticmethod
|
850 |
+
def remove_repeat(data):
|
851 |
+
#若专辑名和歌曲名相同,则只使用后者
|
852 |
+
album_name = data.get('专辑名', None)
|
853 |
+
if album_name is not None and album_name == data.get('歌曲名', None):
|
854 |
+
del data['专辑名']
|
855 |
+
return data
|
856 |
+
|
857 |
+
@property
|
858 |
+
def comma(self):
|
859 |
+
if self.lang == 'zh':
|
860 |
+
return ','
|
861 |
+
elif self.lang == 'en':
|
862 |
+
return ', '
|
863 |
+
|
864 |
+
@property
|
865 |
+
def sep(self):
|
866 |
+
if self.lang == 'zh':
|
867 |
+
return '、'
|
868 |
+
elif self.lang == 'en':
|
869 |
+
return ', '
|
870 |
+
|
871 |
+
|
872 |
+
def generate_description(self, item):
|
873 |
+
if random.random() > self.plain_rate:
|
874 |
+
# dynamically generate prompt from given prompt template
|
875 |
+
description = self.generate_description_dynamic(item)
|
876 |
+
else:
|
877 |
+
# use plain prompt, i.e. tags sequence separated by comma
|
878 |
+
description = self.generate_description_plain(item)
|
879 |
+
return description
|
880 |
+
|
881 |
+
def generate_description_dynamic(self, data):
|
882 |
+
data = self.remove_repeat(data)
|
883 |
+
|
884 |
+
weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低
|
885 |
+
|
886 |
+
key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个
|
887 |
+
|
888 |
+
prompts = []
|
889 |
+
if len(weak_tags) > 0:
|
890 |
+
probs = dist_prob_map_low[len(weak_tags)]
|
891 |
+
if len(key_tags) > 0:
|
892 |
+
tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
|
893 |
+
else:
|
894 |
+
tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
|
895 |
+
random.shuffle(weak_tags)
|
896 |
+
tags = weak_tags[:tags_num]
|
897 |
+
for tag_type in tags:
|
898 |
+
tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
|
899 |
+
prompts.append(tag_desc)
|
900 |
+
|
901 |
+
if len(key_tags) > 0:
|
902 |
+
probs = dist_prob_map[len(key_tags)]
|
903 |
+
tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
|
904 |
+
random.shuffle(key_tags)
|
905 |
+
tags = key_tags[:tags_num]
|
906 |
+
for tag_type in tags:
|
907 |
+
tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
|
908 |
+
prompts.append(tag_desc)
|
909 |
+
|
910 |
+
random.shuffle(prompts)
|
911 |
+
return self.comma.join(prompts)
|
912 |
+
|
913 |
+
def generate_description_plain(self, item):
|
914 |
+
keywords = item['tag']
|
915 |
+
if self.lang != 'en':
|
916 |
+
keywords = [self.translator[k.strip()] for k in keywords]
|
917 |
+
return gen_plain_prompt(keywords, sep=self.keysep)
|
918 |
+
|
919 |
+
@property
|
920 |
+
def keysep(self):
|
921 |
+
if self.lang == 'zh':
|
922 |
+
return ',' if random.random() > 0.5 else '、'
|
923 |
+
elif self.lang == 'en':
|
924 |
+
return ', '
|
925 |
+
|
926 |
+
def is_valid_prompt_text(self, text):
|
927 |
+
for bad in ('抱歉','sorry', 'Sorry'):
|
928 |
+
if bad in text:
|
929 |
+
return False
|
930 |
+
return True
|
931 |
+
|
932 |
+
def get_ready_prompt(self, path):
|
933 |
+
sid = mp3_path_to_id(path)
|
934 |
+
return random.choice(self.prompts_dict[sid])
|
935 |
+
|
936 |
+
def __len__(self):
|
937 |
+
return len(self.data_ids)
|
938 |
+
|
939 |
+
def __getitem__(self, idx):
|
940 |
+
data_id = self.data_ids[idx]
|
941 |
+
item = self.data_index_dict[data_id]
|
942 |
+
path = item['path']
|
943 |
+
if not self.use_ready_prompt:
|
944 |
+
info = self.music_info_dict[data_id]
|
945 |
+
description = self.generate_description(info)
|
946 |
+
else:
|
947 |
+
description = self.get_ready_prompt(path)
|
948 |
+
if self.return_audio:
|
949 |
+
sr, duration = get_sr_and_duration_info(item)
|
950 |
+
audio = self.audio_reader(path, sr, duration)
|
951 |
+
else:
|
952 |
+
audio = None
|
953 |
+
if self.return_path:
|
954 |
+
if self.return_ID:
|
955 |
+
return audio, description, path, info['歌曲ID']
|
956 |
+
return audio, description, path
|
957 |
+
if self.return_ID:
|
958 |
+
return audio, description, info['歌曲ID']
|
959 |
+
return audio, description
|
960 |
+
|
961 |
+
|
962 |
+
class Pond5Dataset(Dataset):
|
963 |
+
MAX_PROMPT_LEN = 200
|
964 |
+
def __init__(self,
|
965 |
+
metadata_path:str,
|
966 |
+
index_path:str,
|
967 |
+
duration:float=10,
|
968 |
+
sr:int = 0,
|
969 |
+
plain_rate = 0,
|
970 |
+
return_path = False,
|
971 |
+
return_audio = True,
|
972 |
+
lang = 'en',
|
973 |
+
translate:Optional[Dict[str, os.PathLike]] = None,
|
974 |
+
use_literal_none = True,
|
975 |
+
use_avoid_watermark_policy = None,
|
976 |
+
):
|
977 |
+
|
978 |
+
if use_avoid_watermark_policy is None:
|
979 |
+
raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type")
|
980 |
+
self.use_avoid_watermark_policy = use_avoid_watermark_policy
|
981 |
+
self.audio_reader = SafeAudioReader(duration, sr, use_avoid_watermark_policy=use_avoid_watermark_policy)
|
982 |
+
|
983 |
+
self._load_metadata(metadata_path, index_path)
|
984 |
+
self.sr = sr
|
985 |
+
self.duration = duration
|
986 |
+
self.plain_rate = plain_rate
|
987 |
+
self.return_path = return_path
|
988 |
+
self.return_audio = return_audio
|
989 |
+
self.use_literal_none = use_literal_none
|
990 |
+
|
991 |
+
self.lang = lang
|
992 |
+
self.translate = read_translate(translate)
|
993 |
+
|
994 |
+
def _load_metadata(self, metadata_path, index_path):
|
995 |
+
data_index = read_jsonlike(index_path)
|
996 |
+
data_ids = set([item['id'] for item in data_index])
|
997 |
+
|
998 |
+
with open(metadata_path) as fp:
|
999 |
+
lines = fp.readlines()
|
1000 |
+
|
1001 |
+
append_ids = set()
|
1002 |
+
|
1003 |
+
self.data = []
|
1004 |
+
for line in lines:
|
1005 |
+
item = json.loads(line)
|
1006 |
+
if item['id'] in data_ids and item['id'] not in append_ids:
|
1007 |
+
self.data.append(item)
|
1008 |
+
append_ids.add(item['id'])
|
1009 |
+
|
1010 |
+
def __len__(self):
|
1011 |
+
return len(self.data)
|
1012 |
+
|
1013 |
+
def __getitem__(self, idx):
|
1014 |
+
item = self.data[idx]
|
1015 |
+
path:str = item["path"]
|
1016 |
+
description = self.generate_description(item)
|
1017 |
+
if self.return_audio:
|
1018 |
+
sr, duration = get_sr_and_duration_info(item)
|
1019 |
+
audio = self.audio_reader(path, sr, duration)
|
1020 |
+
else:
|
1021 |
+
audio = None
|
1022 |
+
if self.return_path:
|
1023 |
+
return audio, description, path
|
1024 |
+
return audio, description
|
1025 |
+
|
1026 |
+
@property
|
1027 |
+
def keysep(self):
|
1028 |
+
if self.lang == 'zh':
|
1029 |
+
return ',' if random.random() > 0.5 else '、'
|
1030 |
+
elif self.lang == 'en':
|
1031 |
+
return ', '
|
1032 |
+
|
1033 |
+
def generate_description(self, item):
|
1034 |
+
if random.random() > self.plain_rate:
|
1035 |
+
# dynamically generate prompt from given prompt template
|
1036 |
+
description = self.generate_description_dynamic(item)
|
1037 |
+
else:
|
1038 |
+
# use plain prompt, i.e. tags sequence separated by comma
|
1039 |
+
description = self.generate_description_plain(item)
|
1040 |
+
return description
|
1041 |
+
|
1042 |
+
def get_translation(self, k):
|
1043 |
+
k = k.strip()
|
1044 |
+
if k in self.translate:
|
1045 |
+
return self.translate[k]
|
1046 |
+
else:
|
1047 |
+
return k
|
1048 |
+
|
1049 |
+
def generate_description_plain(self, item):
|
1050 |
+
keywords = item['keywords']
|
1051 |
+
if self.lang != 'en':
|
1052 |
+
keywords = [self.get_translation(k) for k in keywords]
|
1053 |
+
return gen_plain_prompt(keywords, sep=self.keysep)
|
1054 |
+
|
1055 |
+
def generate_description_dynamic(self,item):
|
1056 |
+
desc = item.get('desc', 'none')
|
1057 |
+
if desc is None:
|
1058 |
+
desc = 'none'
|
1059 |
+
desc = desc.strip()
|
1060 |
+
if len(desc) > self.MAX_PROMPT_LEN:
|
1061 |
+
shorter_desc = desc[:self.MAX_PROMPT_LEN]
|
1062 |
+
# find last stop
|
1063 |
+
stop_idx = shorter_desc.rfind('.')
|
1064 |
+
if stop_idx == -1:
|
1065 |
+
stop_idx = shorter_desc.rfind('!')
|
1066 |
+
if stop_idx == -1:
|
1067 |
+
stop_idx = shorter_desc.rfind(',')
|
1068 |
+
if stop_idx == -1:
|
1069 |
+
stop_idx = self.MAX_PROMPT_LEN - 1
|
1070 |
+
desc = desc[:stop_idx+1]
|
1071 |
+
return desc
|
1072 |
+
|
1073 |
+
class SoundDataset(Dataset):
|
1074 |
+
def __init__(self,
|
1075 |
+
metadata_index: str,
|
1076 |
+
duration:float = 10,
|
1077 |
+
min_non_silent_duration:float = 3,
|
1078 |
+
sr:int = 0,
|
1079 |
+
return_path = False,
|
1080 |
+
return_audio = True,
|
1081 |
+
):
|
1082 |
+
self.data = read_jsonlike(metadata_index)
|
1083 |
+
self.sr = sr
|
1084 |
+
self.reader = SafeAudioReader(duration, sr)
|
1085 |
+
self.duration = duration
|
1086 |
+
self.min_non_silent_duration = min_non_silent_duration
|
1087 |
+
self.return_audio = return_audio
|
1088 |
+
self.return_path = return_path
|
1089 |
+
|
1090 |
+
def __getitem__(self, index):
|
1091 |
+
item = self.data[index]
|
1092 |
+
if self.return_audio:
|
1093 |
+
origin_duration = item['duration']
|
1094 |
+
if origin_duration < self.min_non_silent_duration:
|
1095 |
+
audio = self.read_and_repeat_and_pad(item)
|
1096 |
+
else:
|
1097 |
+
audio = self.reader(item['path'], item['sample_rate'], origin_duration)
|
1098 |
+
else:
|
1099 |
+
audio = None
|
1100 |
+
desc = item['caption']
|
1101 |
+
if self.return_path:
|
1102 |
+
return audio, desc, item['path']
|
1103 |
+
else:
|
1104 |
+
return audio, desc
|
1105 |
+
|
1106 |
+
def __len__(self):
|
1107 |
+
return len(self.data)
|
1108 |
+
|
1109 |
+
def read_and_repeat_and_pad(self, item):
|
1110 |
+
path = item['path']
|
1111 |
+
try:
|
1112 |
+
# read
|
1113 |
+
clip, sr = torchaudio.load(path)
|
1114 |
+
if len(clip.shape) > 1:
|
1115 |
+
clip = torch.mean(clip, dim=0, keepdim=True)
|
1116 |
+
clip = resample(clip, sr, self.sr)
|
1117 |
+
#repeat
|
1118 |
+
n_repeats = math.ceil(self.min_non_silent_duration/item['duration'])
|
1119 |
+
clip = torch.repeat_interleave(clip, n_repeats, dim=0).reshape(-1)
|
1120 |
+
#pad
|
1121 |
+
n_samples = int(self.duration * self.sr)
|
1122 |
+
if clip.shape[0] >= n_samples:
|
1123 |
+
audio = clip[:n_samples]
|
1124 |
+
else:
|
1125 |
+
audio = torch.zeros(int(self.duration * self.sr), dtype=clip.dtype)
|
1126 |
+
start_pos = np.random.randint(0, max(0,(n_samples - clip.shape[0])))
|
1127 |
+
audio[start_pos:start_pos+clip.shape[0]] = clip
|
1128 |
+
return audio
|
1129 |
+
|
1130 |
+
except Exception as e:
|
1131 |
+
logger.error(f"Error reading {path}: {e}")
|
1132 |
+
wav = torch.zeros(int(self.duration * self.sr), dtype=torch.float32)
|
1133 |
+
return wav
|
1134 |
+
|
1135 |
+
class CombinedDataset(Dataset):
|
1136 |
+
@beartype
|
1137 |
+
def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
|
1138 |
+
self.datasets = datasets
|
1139 |
+
self.datasets_index = []
|
1140 |
+
|
1141 |
+
for i,dataset in enumerate(datasets):
|
1142 |
+
if dataset is None:
|
1143 |
+
continue
|
1144 |
+
for dup in range(ratios[i]):
|
1145 |
+
for j in range(len(dataset)):
|
1146 |
+
self.datasets_index.append((i,j))
|
1147 |
+
|
1148 |
+
def __len__(self):
|
1149 |
+
return len(self.datasets_index)
|
1150 |
+
|
1151 |
+
def __getitem__(self, idx):
|
1152 |
+
index = self.datasets_index[idx]
|
1153 |
+
i,j = index
|
1154 |
+
return self.datasets[i][j]
|
1155 |
+
|
1156 |
+
class CombinedDataset_random(Dataset):
|
1157 |
+
@beartype
|
1158 |
+
def __init__(self, num_examples:int, datasets: Sequence[Dataset], ratios: Sequence[int]):
|
1159 |
+
self.datasets = datasets
|
1160 |
+
self.datasets_index = []
|
1161 |
+
|
1162 |
+
for i,dataset in enumerate(datasets):
|
1163 |
+
if dataset is None:
|
1164 |
+
continue
|
1165 |
+
for dup in range(ratios[i]):
|
1166 |
+
for j in range(len(dataset)):
|
1167 |
+
self.datasets_index.append((i,j))
|
1168 |
+
|
1169 |
+
if num_examples > 0:
|
1170 |
+
self.random_choose = True
|
1171 |
+
self.dataset_len = num_examples
|
1172 |
+
else:
|
1173 |
+
self.random_choose = False
|
1174 |
+
self.dataset_len = len(self.datasets_index)
|
1175 |
+
|
1176 |
+
def __len__(self):
|
1177 |
+
return self.dataset_len
|
1178 |
+
|
1179 |
+
def __getitem__(self, idx):
|
1180 |
+
first_try = True
|
1181 |
+
try_cnt = 0
|
1182 |
+
while True:
|
1183 |
+
try:
|
1184 |
+
if(self.random_choose or not first_try):
|
1185 |
+
index2 = []
|
1186 |
+
index2.append(np.random.randint(0,len(self.datasets)))
|
1187 |
+
index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
|
1188 |
+
else:
|
1189 |
+
index2 = self.datasets_index[idx]
|
1190 |
+
first_try = False
|
1191 |
+
out = list(self.datasets[index2[0]][index2[1]])
|
1192 |
+
return out
|
1193 |
+
except:
|
1194 |
+
print("Error loadding ", index2)
|
1195 |
+
try_cnt += 1
|
1196 |
+
if(try_cnt>10):
|
1197 |
+
raise ValueError()
|
1198 |
+
|
1199 |
+
class SoundMixedDataset(Dataset):
|
1200 |
+
@staticmethod
|
1201 |
+
def music_desc(desc):
|
1202 |
+
return f'Music:<{desc}>'
|
1203 |
+
@staticmethod
|
1204 |
+
def sound_desc(desc):
|
1205 |
+
return f'Effect:<{desc}>'
|
1206 |
+
|
1207 |
+
def __init__(self,
|
1208 |
+
music_dataset: Dataset,
|
1209 |
+
sound_dataset: Dataset,
|
1210 |
+
mixed_ratios: Tuple[float, float, float] = (0.3, 0.3, 0.4) # 只有音乐:只有音效:音乐音效混合 的比例
|
1211 |
+
) -> None:
|
1212 |
+
self.music_dataset = music_dataset
|
1213 |
+
self.sound_dataset = sound_dataset
|
1214 |
+
music_r, sound_r, mix_r = [r/sum(mixed_ratios) for r in mixed_ratios] #化为0-1间的比例
|
1215 |
+
#三个概率区间的左端点
|
1216 |
+
self.music_anchor = 0
|
1217 |
+
self.sound_anchor = music_r
|
1218 |
+
self.mix_anchor = music_r + sound_r
|
1219 |
+
|
1220 |
+
def __len__(self):
|
1221 |
+
return len(self.music_dataset)
|
1222 |
+
|
1223 |
+
def get_random_sound_data(self):
|
1224 |
+
idx = random.randint(0, len(self.sound_dataset)-1)
|
1225 |
+
return self.sound_dataset[idx]
|
1226 |
+
|
1227 |
+
def __getitem__(self, idx):
|
1228 |
+
p = random.random()
|
1229 |
+
if p >= self.mix_anchor:
|
1230 |
+
music, m_desc = self.music_dataset[idx]
|
1231 |
+
sound, s_desc = self.get_random_sound_data()
|
1232 |
+
audio = music + sound
|
1233 |
+
if(audio.abs().max()>1.0):
|
1234 |
+
music = music / audio.abs().max() * 0.95
|
1235 |
+
audio = audio / audio.abs().max() * 0.95
|
1236 |
+
desc = self.music_desc(m_desc) + self.sound_desc(s_desc)
|
1237 |
+
return audio[None,:], music[None,:], desc
|
1238 |
+
elif p >= self.sound_anchor:
|
1239 |
+
audio, desc = self.get_random_sound_data()
|
1240 |
+
return audio[None,:], torch.zeros_like(audio[None,:]), self.sound_desc(desc)
|
1241 |
+
else:
|
1242 |
+
audio, desc = self.music_dataset[idx]
|
1243 |
+
return audio[None,:], audio[None,:], self.music_desc(desc)
|
1244 |
+
|
1245 |
+
|
1246 |
+
class DecoTagDataset(Dataset):
|
1247 |
+
'''这个类把普通的datatset包装成适用于标签解耦学习的dataset'''
|
1248 |
+
|
1249 |
+
TAG_TYPES = ('genre', 'mood', 'insrument')
|
1250 |
+
|
1251 |
+
def __init__(self, dataset_class: type, tag_map: Dict[str, str], *args, **kwargs):
|
1252 |
+
self.datasets = []
|
1253 |
+
for i, tag_t in enumerate(self.TAG_TYPES):
|
1254 |
+
kwargs['tag_types'] = [tag_map[tag_t]]
|
1255 |
+
kwargs['return_audio'] = (i == 0) #只有第0个需要返回音频和文本,其余只需要返回文本
|
1256 |
+
self.datasets.append(dataset_class(*args, **kwargs))
|
1257 |
+
|
1258 |
+
def __len__(self):
|
1259 |
+
return len(self.datasets[0])
|
1260 |
+
|
1261 |
+
def __getitem__(self, idx):
|
1262 |
+
audio, text = self.datasets[0][idx]
|
1263 |
+
texts = (text, self.datasets[1][idx][1], self.datasets[2][idx][1])
|
1264 |
+
return audio, texts
|
1265 |
+
|
1266 |
+
|
1267 |
+
class DecoTagWrapper:
|
1268 |
+
'''这是一个包装器,便于选择是否使用标签解耦学习'''
|
1269 |
+
def __init__(self, dataset_class: Dataset, deco_tag_types: List[str] = list(), switch_on: bool = False):
|
1270 |
+
self.dataset_class = dataset_class
|
1271 |
+
self.tag_map = dict(zip(DecoTagDataset.TAG_TYPES, deco_tag_types))
|
1272 |
+
self.switch_on = switch_on
|
1273 |
+
|
1274 |
+
def __call__(self, *args, **kwargs):
|
1275 |
+
if self.switch_on:
|
1276 |
+
return DecoTagDataset(self.dataset_class, self.tag_map, *args, **kwargs)
|
1277 |
+
else:
|
1278 |
+
return self.dataset_class(*args, **kwargs)
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
from typing import List, Union
|
5 |
+
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
import torchaudio
|
8 |
+
from torchaudio.functional import resample
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from torch.nn.utils.rnn import pad_sequence
|
13 |
+
|
14 |
+
PARAGRAPH_GAP = 6
|
15 |
+
MIN_MUSIC_LEN = 3
|
16 |
+
|
17 |
+
def check_lryics(lyric):
|
18 |
+
_FILTER_STRING = [
|
19 |
+
'作词', '作曲', '编曲', '【', '策划',
|
20 |
+
'录音', '混音', '母带', ':', '制作',
|
21 |
+
'版权', '校对', '演奏', '制作', '伴奏'
|
22 |
+
]
|
23 |
+
for item in _FILTER_STRING:
|
24 |
+
if item in lyric:
|
25 |
+
return True
|
26 |
+
|
27 |
+
return False
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
def process_lyrics(lines):
|
32 |
+
lyric_part = []
|
33 |
+
timestamp_part = []
|
34 |
+
|
35 |
+
timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
|
36 |
+
|
37 |
+
for i, line in enumerate(lines):
|
38 |
+
|
39 |
+
# 删除前几行的特定信息
|
40 |
+
if i<10 and check_lryics(line):
|
41 |
+
continue
|
42 |
+
|
43 |
+
# 检查是否包含有效的时间戳和歌词内容
|
44 |
+
if timestamp_pattern.match(line):
|
45 |
+
timestamp_end = line.rfind(']')
|
46 |
+
lyrics = line[timestamp_end + 1:].strip()
|
47 |
+
timestamps = line[:timestamp_end + 1]
|
48 |
+
|
49 |
+
if ':' in lyrics:
|
50 |
+
if len(lyrics.split(":")[0]) <=5:
|
51 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
52 |
+
# if lyrics: # 确保歌词部分不是空的
|
53 |
+
# lyric_part.append(lyrics)
|
54 |
+
# timestamp_part.append(timestamps)
|
55 |
+
# print(processed_lyrics)
|
56 |
+
return timestamp_part, lyric_part
|
57 |
+
|
58 |
+
def get_timestamps(timestamp_part):
|
59 |
+
|
60 |
+
# 转换为秒
|
61 |
+
|
62 |
+
timestamps = []
|
63 |
+
|
64 |
+
for line in timestamp_part:
|
65 |
+
match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
|
66 |
+
if match:
|
67 |
+
minutes = int(match.group(1))
|
68 |
+
seconds = float(match.group(2))
|
69 |
+
millis = float(match.group(3)) if match.group(3) else 0
|
70 |
+
total_seconds = minutes * 60 + seconds + millis
|
71 |
+
timestamps.append(total_seconds)
|
72 |
+
|
73 |
+
|
74 |
+
return timestamps
|
75 |
+
|
76 |
+
def process_lyrics_lrc(lyrics):
|
77 |
+
timestamp_part, lyric_part = process_lyrics(lyrics)
|
78 |
+
# print(timestamp_part)
|
79 |
+
# print(lyric_part)
|
80 |
+
timestamps = get_timestamps(timestamp_part)
|
81 |
+
# print(timestamps)
|
82 |
+
if len(timestamps) == 0:
|
83 |
+
# print(f'{lyric_path}')
|
84 |
+
return []
|
85 |
+
|
86 |
+
slice_start = timestamps[0]
|
87 |
+
slice_start_idx = 0
|
88 |
+
|
89 |
+
output_list = []
|
90 |
+
for i in range(1, len(timestamps)):
|
91 |
+
# 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉
|
92 |
+
if timestamps[i] - slice_start > 30:
|
93 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
94 |
+
|
95 |
+
slice_start = timestamps[i]
|
96 |
+
slice_start_idx = i
|
97 |
+
|
98 |
+
return output_list
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
def process_lyrics_yrc(lyrics):
|
103 |
+
|
104 |
+
timestamps, lyric_part = extract_lrc(lyrics)
|
105 |
+
|
106 |
+
# timestamp_part, lyric_part = process_lyrics(lyrics)
|
107 |
+
# import pdb; pdb.set_trace()
|
108 |
+
# print(timestamp_part)
|
109 |
+
# print(lyric_part)
|
110 |
+
# timestamps = get_timestamps(timestamp_part)
|
111 |
+
# print(timestamps)
|
112 |
+
if len(timestamps) == 0:
|
113 |
+
# print(f'{lyric_path}')
|
114 |
+
return []
|
115 |
+
|
116 |
+
slice_start = timestamps[0]
|
117 |
+
slice_start_idx = 0
|
118 |
+
|
119 |
+
output_list = []
|
120 |
+
for i in range(1, len(timestamps)):
|
121 |
+
# 如果累积时间超过30秒,则进行切分
|
122 |
+
if timestamps[i] - slice_start > 30:
|
123 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
124 |
+
|
125 |
+
slice_start = timestamps[i]
|
126 |
+
slice_start_idx = i
|
127 |
+
# import pdb; pdb.set_trace()
|
128 |
+
return output_list
|
129 |
+
|
130 |
+
def extract_lrc(lyrics):
|
131 |
+
timestamp_part, lyric_part = [], []
|
132 |
+
|
133 |
+
for i, text in enumerate(lyrics):
|
134 |
+
# 提取中括号内的内容
|
135 |
+
bracket_content = re.search(r'\[(.*?)\]', text).group(1)
|
136 |
+
bracket_content = bracket_content.split(',')
|
137 |
+
# 提取小括号内的内容
|
138 |
+
parentheses_content = re.findall(r'\((.*?)\)', text)
|
139 |
+
# 提取其他内容
|
140 |
+
other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
|
141 |
+
|
142 |
+
# 数据怎么处理?
|
143 |
+
if i<10 and check_lryics(other_content):
|
144 |
+
continue
|
145 |
+
timestamp_part.append(float(bracket_content[0])/1000)
|
146 |
+
lyric_part.append(other_content)
|
147 |
+
return timestamp_part, lyric_part
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
class WYYSongDataset(Dataset):
|
152 |
+
def __init__(self,
|
153 |
+
metadata_path: Union[str, List[str]],
|
154 |
+
sr:int = 0,
|
155 |
+
use_lang = ['en', 'zh-cn'],
|
156 |
+
num_examples = -1,
|
157 |
+
max_dur = 20,
|
158 |
+
min_dur=0,
|
159 |
+
add_music=False,
|
160 |
+
pad_to_max= True,
|
161 |
+
):
|
162 |
+
|
163 |
+
self.sr = sr
|
164 |
+
self.use_lang = use_lang
|
165 |
+
self.data = []
|
166 |
+
if type(metadata_path) == str:
|
167 |
+
metadata_path = [metadata_path]
|
168 |
+
for _meta in metadata_path:
|
169 |
+
self._load_metadata(_meta)
|
170 |
+
self.max_dur = max_dur
|
171 |
+
self.min_dur = min_dur
|
172 |
+
self.pad_to_max = pad_to_max
|
173 |
+
self.add_music = add_music
|
174 |
+
|
175 |
+
# buffer
|
176 |
+
self.lyric_buffer = {}
|
177 |
+
|
178 |
+
if(num_examples<=0):
|
179 |
+
self.dataset_len = len(self.data)
|
180 |
+
self.random_slc = False
|
181 |
+
else:
|
182 |
+
self.dataset_len = num_examples
|
183 |
+
self.random_slc = True
|
184 |
+
|
185 |
+
|
186 |
+
# 读取jsonl文件
|
187 |
+
def _load_metadata(self, metadata_path):
|
188 |
+
with open(metadata_path) as fp:
|
189 |
+
lines = fp.readlines()
|
190 |
+
for line in lines:
|
191 |
+
item = json.loads(line)
|
192 |
+
if '伴奏' not in item['path']:
|
193 |
+
# if "lang_type" in item and item['lang_type'] == 'en':
|
194 |
+
if "lang_type" in item:
|
195 |
+
self.data.append(item)
|
196 |
+
|
197 |
+
|
198 |
+
def __len__(self):
|
199 |
+
return self.dataset_len
|
200 |
+
|
201 |
+
|
202 |
+
def __getitem__(self, idx):
|
203 |
+
try_cnt = 0
|
204 |
+
while True:
|
205 |
+
if(self.random_slc):
|
206 |
+
idx = np.random.randint(0, len(self.data))
|
207 |
+
yrc_lyrics = []
|
208 |
+
lrc_lyrics = []
|
209 |
+
try:
|
210 |
+
info = self.data[idx]
|
211 |
+
|
212 |
+
# audio path
|
213 |
+
path = info["path"]
|
214 |
+
lang_type = info["lang_type"]
|
215 |
+
lyrics = info['lyrics'] # chinese
|
216 |
+
# lyrics = info['lyrics_phone']
|
217 |
+
|
218 |
+
# 随机选取一个lyric段落
|
219 |
+
|
220 |
+
parsed_lyrics = []
|
221 |
+
# st_idx = np.random.randint(0, len(lyrics))
|
222 |
+
for ly_id in range(len(lyrics)):
|
223 |
+
lyric = lyrics[ly_id].strip()
|
224 |
+
st, et, lyric = self.parse_lyric(lyric)
|
225 |
+
|
226 |
+
if et - st >= self.max_dur:
|
227 |
+
continue #TODO 前后外沿 [MUSIC]
|
228 |
+
|
229 |
+
if parsed_lyrics != []:
|
230 |
+
if st - parsed_lyrics[-1][1] >= PARAGRAPH_GAP: # 大gap
|
231 |
+
parsed_lyrics.append((parsed_lyrics[-1][1], st, '[GAP]'))
|
232 |
+
elif self.add_music and st - parsed_lyrics[-1][1] >= MIN_MUSIC_LEN:
|
233 |
+
parsed_lyrics.append((parsed_lyrics[-1][1], st, '[MUSIC]'))
|
234 |
+
|
235 |
+
lyric = lyric.replace("\xa0", " ")
|
236 |
+
lyric = " ".join(lyric.split())
|
237 |
+
parsed_lyrics.append((st, et, lyric))
|
238 |
+
|
239 |
+
assert parsed_lyrics != []
|
240 |
+
# if parsed_lyrics[-1][1] - parsed_lyrics[0][0] > self.max_dur:
|
241 |
+
# print(f"{parsed_lyrics[0][0]}-{parsed_lyrics[-1][1]} {parsed_lyrics}", file=open('tmp.txt', 'a'))
|
242 |
+
|
243 |
+
parsed_lyrics = [(0, parsed_lyrics[0][0], '[GAP]')] + parsed_lyrics
|
244 |
+
|
245 |
+
possible_starts = [e for e,i in enumerate(parsed_lyrics) if i[2]=='[GAP]']
|
246 |
+
st_idx = np.random.choice(possible_starts)
|
247 |
+
|
248 |
+
paraphrase = []
|
249 |
+
for i in parsed_lyrics[st_idx+1:]:
|
250 |
+
if i[2] == '[GAP]':
|
251 |
+
break
|
252 |
+
paraphrase.append(i)
|
253 |
+
# print(paraphrase, lyrics)
|
254 |
+
|
255 |
+
while paraphrase[-1][1] - paraphrase[0][0] > self.max_dur:
|
256 |
+
if np.random.rand() > 0.2:
|
257 |
+
paraphrase.pop(-1) # 大概率从后面截断
|
258 |
+
else:
|
259 |
+
paraphrase.pop(0) # 小概率截前面
|
260 |
+
|
261 |
+
st, et, lyric = paraphrase[0][0], paraphrase[-1][1], ', '.join([i[2] for i in paraphrase]) # [SEP]
|
262 |
+
# print(st, et, lyric)
|
263 |
+
# import pdb; pdb.set_trace()
|
264 |
+
assert self.min_dur < et - st < self.max_dur, f"{st}-{et} {lyric}"
|
265 |
+
# print(et-st, lyric)
|
266 |
+
# import pdb; pdb.set_trace()
|
267 |
+
|
268 |
+
if info["lang_type"] == 'en':
|
269 |
+
# print(len(lyric.split())/(et-st))
|
270 |
+
char_num = sum([len(lrc[-1].split()) for lrc in paraphrase])
|
271 |
+
assert 6 > char_num / (et-st) > 1
|
272 |
+
else:
|
273 |
+
# print(len(lyric.split())/(et-st))
|
274 |
+
char_num = sum([len(lrc[-1]) for lrc in paraphrase])
|
275 |
+
assert 6 > char_num / (et-st) > 1
|
276 |
+
|
277 |
+
# 读取音频文件
|
278 |
+
cur_sample_rate = torchaudio.info(path).sample_rate
|
279 |
+
offset = int(cur_sample_rate*st)
|
280 |
+
num_frames = int(cur_sample_rate * (et -st))
|
281 |
+
chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
|
282 |
+
# chunk = torch.zeros(1, 48000*15)
|
283 |
+
if abs(chunk.shape[-1] - num_frames) > num_frames * 0.05: # 音频文件长度与歌词不一致
|
284 |
+
print(f"fail to load {path} from {st} to {et} !")
|
285 |
+
raise FileNotFoundError
|
286 |
+
# 随机选取一个channel
|
287 |
+
if(chunk.shape[0]>1):
|
288 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
289 |
+
else:
|
290 |
+
chunk = chunk[[0],:].float()
|
291 |
+
|
292 |
+
if(cur_sample_rate!=self.sr):
|
293 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
294 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
|
295 |
+
|
296 |
+
if self.pad_to_max:
|
297 |
+
chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)
|
298 |
+
|
299 |
+
# print(self.sz_cnt)
|
300 |
+
return chunk, lyric, [st, et], path, lang_type
|
301 |
+
except (AssertionError, FileNotFoundError, RuntimeError) as e: # 其他Error不ok
|
302 |
+
# print("Error loadding ", info["path"])
|
303 |
+
try_cnt += 1
|
304 |
+
idx = np.random.randint(0, len(self.data))
|
305 |
+
if(try_cnt>100):
|
306 |
+
raise e
|
307 |
+
|
308 |
+
def parse_lyric(self, lyric):
|
309 |
+
pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
|
310 |
+
match = re.search(pattern, lyric)
|
311 |
+
|
312 |
+
start_time = float(match.group(1))
|
313 |
+
end_time = float(match.group(2))
|
314 |
+
content = match.group(3)
|
315 |
+
return start_time, end_time, content
|
316 |
+
|
317 |
+
def pad_2d_tensor(self, x, max_len, pad_id):
|
318 |
+
# 获取输入 tensor 的形状
|
319 |
+
batch_size, seq_len = x.size()
|
320 |
+
max_len = max(max_len, seq_len)
|
321 |
+
# 计算需要填充的长度
|
322 |
+
pad_len = max_len - seq_len
|
323 |
+
|
324 |
+
# 如果需要填充
|
325 |
+
if pad_len > 0:
|
326 |
+
# 创建填充 tensor
|
327 |
+
pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
|
328 |
+
|
329 |
+
# 沿第二个维度(列)连接输入 tensor 和填充 tensor
|
330 |
+
padded_tensor = torch.cat([x, pad_tensor], dim=1)
|
331 |
+
else:
|
332 |
+
# 如果不需要填充,直接返回输入 tensor
|
333 |
+
padded_tensor = x
|
334 |
+
|
335 |
+
return padded_tensor
|
336 |
+
|
337 |
+
def collect_data(data_list):
|
338 |
+
audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
|
339 |
+
lyrics = [data[1] for data in data_list]
|
340 |
+
st_et = [data[2] for data in data_list]
|
341 |
+
paths = [data[3] for data in data_list]
|
342 |
+
lang_types = [data[4] for data in data_list]
|
343 |
+
return audios, lyrics, st_et
|
344 |
+
# return audios, lyrics, st_et
|
345 |
+
|
346 |
+
|
347 |
+
def build_dataset(train_jsonl_list, val_jsonl_list, min_dur=0, max_dur=20, add_music=False):
|
348 |
+
print(min_dur,max_dur)
|
349 |
+
print(train_jsonl_list)
|
350 |
+
# ["exp/wyy3_20240418_v2f.jsonl",
|
351 |
+
# "exp/tme_lyric_baokuan.jsonl"]
|
352 |
+
train_dataset = WYYSongDataset(
|
353 |
+
metadata_path = train_jsonl_list,
|
354 |
+
sr = 48000,
|
355 |
+
use_lang = ['zh-cn', 'en'],
|
356 |
+
num_examples = 10*10000,
|
357 |
+
min_dur=min_dur,
|
358 |
+
max_dur=max_dur,
|
359 |
+
add_music=add_music
|
360 |
+
)
|
361 |
+
|
362 |
+
valid_dataset = WYYSongDataset(
|
363 |
+
metadata_path = val_jsonl_list,
|
364 |
+
sr = 48000,
|
365 |
+
use_lang = ['zh-cn', 'en'],
|
366 |
+
num_examples = 500,
|
367 |
+
min_dur=min_dur,
|
368 |
+
max_dur=max_dur,
|
369 |
+
add_music=add_music
|
370 |
+
)
|
371 |
+
print(train_jsonl_list, "\t total_song = ", len(train_dataset.data))
|
372 |
+
return train_dataset, valid_dataset
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py
ADDED
@@ -0,0 +1,830 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
|
3 |
+
from beartype import beartype
|
4 |
+
from beartype.door import is_bearable
|
5 |
+
import random
|
6 |
+
import pandas as pd
|
7 |
+
import os
|
8 |
+
from torchaudio.functional import resample
|
9 |
+
import torch
|
10 |
+
import typing as tp
|
11 |
+
from pathlib import Path
|
12 |
+
import torchaudio as ta
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import numpy as np
|
15 |
+
import json
|
16 |
+
import yaml
|
17 |
+
import torchaudio
|
18 |
+
import math
|
19 |
+
import re
|
20 |
+
from loguru import logger
|
21 |
+
|
22 |
+
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
23 |
+
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
24 |
+
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.n_samples = n_samples
|
28 |
+
self.sample_rate = sample_rate
|
29 |
+
self.randomize = randomize
|
30 |
+
|
31 |
+
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
32 |
+
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
33 |
+
# print(duration,(float(self.n_samples)/self.sample_rate+1))
|
34 |
+
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
35 |
+
t_start = 0.
|
36 |
+
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
37 |
+
offset = 0
|
38 |
+
# print('c1:',chunk.shape)
|
39 |
+
else:
|
40 |
+
offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
41 |
+
t_start = offset / float(cur_sample_rate) / duration
|
42 |
+
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
43 |
+
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
44 |
+
# print('offset:',offset)
|
45 |
+
# print('c0:',chunk.shape)
|
46 |
+
# Pad with silence if necessary.
|
47 |
+
if(chunk.shape[0]>1):
|
48 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
49 |
+
else:
|
50 |
+
chunk = chunk[[0],:].float()
|
51 |
+
if(cur_sample_rate!=self.sample_rate):
|
52 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
53 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
54 |
+
# print('b:',self.sample_rate,chunk.shape)
|
55 |
+
if chunk.shape[-1] < self.n_samples:
|
56 |
+
chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
57 |
+
else:
|
58 |
+
chunk = chunk[:,0:self.n_samples]
|
59 |
+
seconds_start = math.floor(offset / cur_sample_rate)
|
60 |
+
seconds_total = math.floor(duration)
|
61 |
+
|
62 |
+
return (
|
63 |
+
chunk,
|
64 |
+
t_start,
|
65 |
+
t_end,
|
66 |
+
seconds_start,
|
67 |
+
seconds_total
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替
|
72 |
+
if USE_DUMMY_AUDIO:
|
73 |
+
logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
|
74 |
+
|
75 |
+
class SafeAudioReader:
|
76 |
+
"""
|
77 |
+
This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
|
78 |
+
"""
|
79 |
+
def __init__(self,
|
80 |
+
duration: float, # 返回音频长度
|
81 |
+
sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample
|
82 |
+
randomize: bool = True
|
83 |
+
):
|
84 |
+
self.n_samples = int(sample_rate * max(duration, 0))
|
85 |
+
self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
|
86 |
+
|
87 |
+
#NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数!
|
88 |
+
def __call__(self,
|
89 |
+
filepath: os.PathLike, # 音频路径
|
90 |
+
origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取
|
91 |
+
origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取
|
92 |
+
) -> torch.Tensor:
|
93 |
+
if USE_DUMMY_AUDIO:
|
94 |
+
wav = torch.zeros(self.n_samples, dtype=torch.float32)
|
95 |
+
return wav
|
96 |
+
try:
|
97 |
+
if origin_sample_rate is None or origin_duration is None:
|
98 |
+
audio_info = torchaudio.info(filepath)
|
99 |
+
origin_sample_rate = audio_info.sample_rate
|
100 |
+
origin_duration = audio_info.num_frames / origin_sample_rate
|
101 |
+
wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
|
102 |
+
except Exception as e:
|
103 |
+
logger.error(f"Error reading {filepath}: {e}")
|
104 |
+
wav = torch.zeros(self.n_samples, dtype=torch.float32)
|
105 |
+
return wav
|
106 |
+
|
107 |
+
|
108 |
+
class PromptTemplate:
|
109 |
+
def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
|
110 |
+
self.template_text = template_text
|
111 |
+
self.tag_map = tag_map
|
112 |
+
self.lang = lang
|
113 |
+
|
114 |
+
@property
|
115 |
+
def tags(self):
|
116 |
+
return tuple(self.tag_map.keys())
|
117 |
+
|
118 |
+
def apply(self, **kwargs):
|
119 |
+
for tag in list(kwargs.keys()):
|
120 |
+
if kwargs[tag] == '':
|
121 |
+
kwargs.pop(tag)
|
122 |
+
for tag in self.tags:
|
123 |
+
if tag in kwargs:
|
124 |
+
kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
|
125 |
+
else:
|
126 |
+
kwargs[tag] = ''
|
127 |
+
prompt = self.template_text.format(**kwargs)
|
128 |
+
|
129 |
+
return self.beautify(prompt)
|
130 |
+
|
131 |
+
def beautify(self, text):
|
132 |
+
if self.lang == 'en':
|
133 |
+
return self._beautify_en(text)
|
134 |
+
elif self.lang == 'zh':
|
135 |
+
return self._beautify_zh(text)
|
136 |
+
else:
|
137 |
+
raise ValueError(f'Unknown language {self.lang}')
|
138 |
+
|
139 |
+
@staticmethod
|
140 |
+
def _beautify_en(text):
|
141 |
+
# no continuous commas without content between them
|
142 |
+
text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
|
143 |
+
# no continuous whitespace
|
144 |
+
text = re.sub(r'\s+', ' ', text)
|
145 |
+
# the comma is NOT followed by whitespace, and should be followed by ONE whitespace
|
146 |
+
text = re.sub(r'\s+,', r',', text)
|
147 |
+
text = re.sub(r',\s+', r', ', text)
|
148 |
+
# no whitespace before the full stop
|
149 |
+
text = re.sub(r'\s+\.', r'.', text)
|
150 |
+
# strip whitespace, comma, and replace ',.'
|
151 |
+
text = text.strip(' ,')
|
152 |
+
text = text.replace(',.', '.')
|
153 |
+
return text
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def _beautify_zh(text):
|
157 |
+
# no continuous commas without content between them
|
158 |
+
text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
|
159 |
+
text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
|
160 |
+
# assume there should be NO whitespace in Chinese
|
161 |
+
text = re.sub(r'\s+', r'', text)
|
162 |
+
# strip whitespace, comma, and replace ',。'
|
163 |
+
text = text.strip(', 、')
|
164 |
+
text = text.replace(',。', '。')
|
165 |
+
return text
|
166 |
+
|
167 |
+
def __repr__(self):
|
168 |
+
return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
|
169 |
+
|
170 |
+
__str__ = __repr__
|
171 |
+
|
172 |
+
def parse_prompt_template(prompt_template_text, lang='en'):
|
173 |
+
span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
|
174 |
+
tag_pattern = re.compile(r'{.+?}', re.DOTALL)
|
175 |
+
|
176 |
+
template_text = prompt_template_text.strip()
|
177 |
+
span_texts = span_pattern.findall(prompt_template_text)
|
178 |
+
tag_map = {}
|
179 |
+
for span_text in span_texts:
|
180 |
+
tag = tag_pattern.findall(span_text)[0].strip('{}')
|
181 |
+
tag_map[tag] = span_text
|
182 |
+
template_text = template_text.replace(span_text, '{'+tag+'}')
|
183 |
+
|
184 |
+
return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
|
185 |
+
|
186 |
+
def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
|
187 |
+
with open(path, 'r') as f:
|
188 |
+
lines = f.readlines()
|
189 |
+
cnt = 0
|
190 |
+
pts = []
|
191 |
+
for line in lines:
|
192 |
+
pt = parse_prompt_template(line, lang=lang)
|
193 |
+
cnt += 1
|
194 |
+
if len(pt.tags) < num:
|
195 |
+
logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
|
196 |
+
pts.append(pt)
|
197 |
+
|
198 |
+
return pts
|
199 |
+
|
200 |
+
|
201 |
+
def get_base_dir_file(key: os.PathLike):
|
202 |
+
base = os.path.basename(key)
|
203 |
+
dirname = os.path.basename(os.path.dirname(key))
|
204 |
+
return os.path.join(dirname, base)
|
205 |
+
|
206 |
+
def read_jsonlike(path: os.PathLike):
|
207 |
+
#json or jsonl
|
208 |
+
if str(path).endswith(".json"):
|
209 |
+
with open(path, 'r', encoding='utf8') as f:
|
210 |
+
data = json.load(f)
|
211 |
+
return data
|
212 |
+
elif str(path).endswith(".jsonl"):
|
213 |
+
with open(path, 'r', encoding='utf8') as f:
|
214 |
+
data = [json.loads(line) for line in f.readlines()]
|
215 |
+
return data
|
216 |
+
else:
|
217 |
+
raise ValueError("Unknown file format")
|
218 |
+
|
219 |
+
dist_prob_map = {
|
220 |
+
1: (1.0,),
|
221 |
+
2: (0.5, 0.5),
|
222 |
+
3: (0.3, 0.4, 0.3),
|
223 |
+
4: (0.2, 0.3, 0.3, 0.2),
|
224 |
+
5: (0.2, 0.2, 0.3, 0.2, 0.1),
|
225 |
+
6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
|
226 |
+
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
227 |
+
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
228 |
+
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
229 |
+
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
230 |
+
}
|
231 |
+
|
232 |
+
dist_prob_map_low = {
|
233 |
+
1: (1.0,),
|
234 |
+
2: (0.8, 0.2),
|
235 |
+
3: (0.8, 0.1, 0.1),
|
236 |
+
4: (0.7, 0.1, 0.1, 0.1),
|
237 |
+
5: (0.7, 0.1, 0.1, 0.05, 0.05),
|
238 |
+
6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
|
239 |
+
}
|
240 |
+
|
241 |
+
_bpm_range_rights = (
|
242 |
+
(40, '20-40'),
|
243 |
+
(60, '40-60'),
|
244 |
+
(66, '60-66'),
|
245 |
+
(76, '66-76'),
|
246 |
+
(108, '76-108'),
|
247 |
+
(120, '108-120'),
|
248 |
+
(168, '120-168'),
|
249 |
+
(176, '168-176'),
|
250 |
+
(200, '176-200')
|
251 |
+
)
|
252 |
+
_bpm_desc_map = {
|
253 |
+
'20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
|
254 |
+
'40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
|
255 |
+
'60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
|
256 |
+
'66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
|
257 |
+
'76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
|
258 |
+
'108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
|
259 |
+
'120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
|
260 |
+
'168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
|
261 |
+
'176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
|
262 |
+
'>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
|
263 |
+
}
|
264 |
+
_bpm_desc_map_zh = {
|
265 |
+
'20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
|
266 |
+
'40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
|
267 |
+
'60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
|
268 |
+
'66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
|
269 |
+
'76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
|
270 |
+
'108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
|
271 |
+
'120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
|
272 |
+
'168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
|
273 |
+
'176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
|
274 |
+
'>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
|
275 |
+
}
|
276 |
+
def get_bpm_range(bpm):
|
277 |
+
bpm = int(bpm)
|
278 |
+
for right, tag in _bpm_range_rights:
|
279 |
+
if bpm <= right:
|
280 |
+
return tag
|
281 |
+
return '>200'
|
282 |
+
|
283 |
+
def gen_bpm_descript(bpm, lang='en'):
|
284 |
+
bpm_range = get_bpm_range(bpm)
|
285 |
+
if lang == 'en':
|
286 |
+
return random.choice(_bpm_desc_map[bpm_range])
|
287 |
+
elif lang == 'zh':
|
288 |
+
return random.choice(_bpm_desc_map_zh[bpm_range])
|
289 |
+
else:
|
290 |
+
raise ValueError(f"Unknown language {lang}")
|
291 |
+
|
292 |
+
def read_translate(translate: Optional[Dict[str, os.PathLike]]):
|
293 |
+
if translate is None:
|
294 |
+
return None
|
295 |
+
return {k: read_jsonlike(path) for k, path in translate.items()}
|
296 |
+
|
297 |
+
|
298 |
+
class MagnaTagATuneDataset(Dataset):
|
299 |
+
def __init__(self):
|
300 |
+
pass
|
301 |
+
|
302 |
+
|
303 |
+
def tags_to_desc(tag_list, sep=',') -> str:
|
304 |
+
if not isinstance(tag_list, Sequence):
|
305 |
+
return str(tag_list)
|
306 |
+
if isinstance(tag_list, str):
|
307 |
+
return tag_list
|
308 |
+
if len(tag_list) <= 0:
|
309 |
+
return ''
|
310 |
+
elif len(tag_list) <= 5:
|
311 |
+
probs = dist_prob_map[len(tag_list)]
|
312 |
+
tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
|
313 |
+
random.shuffle(tag_list)
|
314 |
+
tag_list = tag_list[:tags_num]
|
315 |
+
return sep.join(tag_list)
|
316 |
+
else:
|
317 |
+
probs = dist_prob_map[5]
|
318 |
+
tags_num = random.choices(range(1, 6), probs)[0]
|
319 |
+
random.shuffle(tag_list)
|
320 |
+
tag_list = tag_list[:tags_num]
|
321 |
+
return sep.join(tag_list)
|
322 |
+
|
323 |
+
def get_sr_and_duration_info(item):
|
324 |
+
return item.get('sample_rate', None), item.get('duration', None)
|
325 |
+
|
326 |
+
class MtgJamendoDatasetFromJson(Dataset):
|
327 |
+
def __init__(self,
|
328 |
+
data_dir:str,
|
329 |
+
json_path:str,
|
330 |
+
duration:float=10,
|
331 |
+
sr:int = 0,
|
332 |
+
*,
|
333 |
+
lang = 'en',
|
334 |
+
return_path = False,
|
335 |
+
prompt_template_path: os.PathLike = None,
|
336 |
+
tag_types = [],
|
337 |
+
translate:Optional[Dict[str, os.PathLike]] = None,
|
338 |
+
):
|
339 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
340 |
+
|
341 |
+
self.data_dir = data_dir
|
342 |
+
self._load_metadata_json(json_path)
|
343 |
+
self.sr = sr
|
344 |
+
self.duration = duration
|
345 |
+
self.return_path = return_path
|
346 |
+
self.lang = lang
|
347 |
+
|
348 |
+
self.use_dynamic_prompt = prompt_template_path is not None
|
349 |
+
if self.use_dynamic_prompt:
|
350 |
+
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
|
351 |
+
self.tag_types = tag_types
|
352 |
+
|
353 |
+
self.translate = read_translate(translate)
|
354 |
+
if not self.use_dynamic_prompt and self.lang != 'en':
|
355 |
+
raise NotImplementedError
|
356 |
+
|
357 |
+
#这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示
|
358 |
+
WEAK_TAG_LIST = ["title", "artist"]
|
359 |
+
|
360 |
+
def _load_metadata_json(self, json_path):
|
361 |
+
with open(json_path) as fp:
|
362 |
+
self.data = json.load(fp)
|
363 |
+
|
364 |
+
def convert_key_to_path(self, key):
|
365 |
+
return os.path.join(self.data_dir, get_base_dir_file(key))
|
366 |
+
|
367 |
+
def __len__(self):
|
368 |
+
return len(self.data)
|
369 |
+
|
370 |
+
def __getitem__(self, idx):
|
371 |
+
item = self.data[idx]
|
372 |
+
path = self.convert_key_to_path(item['key'])
|
373 |
+
description = self.generate_description(item)
|
374 |
+
|
375 |
+
sr, duration = get_sr_and_duration_info(item)
|
376 |
+
audio = self.audio_reader(path, sr, duration)
|
377 |
+
|
378 |
+
if self.return_path:
|
379 |
+
return audio, description, path
|
380 |
+
return audio, description
|
381 |
+
|
382 |
+
def tags_to_desc(self, tag_list, tag_type) -> str:
|
383 |
+
if self.lang == 'en':
|
384 |
+
return tags_to_desc(tag_list)
|
385 |
+
elif self.lang == 'zh':
|
386 |
+
translator = self.translate[tag_type]
|
387 |
+
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
388 |
+
return tags_to_desc(translated_tag_list, sep='、')
|
389 |
+
|
390 |
+
def generate_description(self, item):
|
391 |
+
if self.use_dynamic_prompt:
|
392 |
+
# dynamically generate prompt from given prompt template
|
393 |
+
prompt_template = random.choice(self.prompt_templates)
|
394 |
+
description = self.generate_description_dynamic(item, prompt_template)
|
395 |
+
|
396 |
+
else:
|
397 |
+
# use ordinary static prompt instead
|
398 |
+
description = self.generate_description_ordinary(item)
|
399 |
+
return description
|
400 |
+
|
401 |
+
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
402 |
+
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
403 |
+
exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
|
404 |
+
exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
|
405 |
+
|
406 |
+
if len(exists_strong_tag) > 0:
|
407 |
+
probs = dist_prob_map[len(exists_strong_tag)]
|
408 |
+
tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
|
409 |
+
random.shuffle(exists_strong_tag)
|
410 |
+
tags = exists_strong_tag[:tags_num]
|
411 |
+
weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
|
412 |
+
weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
|
413 |
+
random.shuffle(exists_weak_tag)
|
414 |
+
weak_tags = exists_weak_tag[:weak_tags_num]
|
415 |
+
tags += weak_tags
|
416 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
417 |
+
prompt = prompt_template.apply(**tags_args)
|
418 |
+
else:
|
419 |
+
# no strong tags, use all weak tags instead
|
420 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
|
421 |
+
prompt = prompt_template.apply(**tags_args)
|
422 |
+
|
423 |
+
return prompt
|
424 |
+
|
425 |
+
def generate_description_ordinary(self, data, thresh = 0.3):
|
426 |
+
# Initialize the description with title and artist
|
427 |
+
description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}'
|
428 |
+
|
429 |
+
# Add genre if available
|
430 |
+
if data["genre"] and random.random() > thresh:
|
431 |
+
genres = ', '.join(data["genre"])
|
432 |
+
description += f', belonging to the {genres} genres'
|
433 |
+
|
434 |
+
# Add moods if available
|
435 |
+
if data["moods"] and random.random() > thresh:
|
436 |
+
moods = ', '.join(data["moods"])
|
437 |
+
description += f'. This track conveys a {moods} mood'
|
438 |
+
|
439 |
+
# Add instruments if available
|
440 |
+
if data["instrument"] and random.random() > thresh:
|
441 |
+
instruments = ', '.join(data["instrument"])
|
442 |
+
description += f', and primarily features the following instruments: {instruments}'
|
443 |
+
|
444 |
+
# Add a period to end the description
|
445 |
+
description += '.'
|
446 |
+
|
447 |
+
return description
|
448 |
+
|
449 |
+
class AudioStockDataset(Dataset):
|
450 |
+
def __init__(self,
|
451 |
+
metadata_path:str,
|
452 |
+
duration:float=10,
|
453 |
+
sr:int = 0,
|
454 |
+
return_path = False,
|
455 |
+
return_audio = True,
|
456 |
+
prompt_template_path: os.PathLike = None,
|
457 |
+
tag_types = [],
|
458 |
+
lang = 'en',
|
459 |
+
translate:Optional[Dict[str, os.PathLike]] = None
|
460 |
+
):
|
461 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
462 |
+
|
463 |
+
self._load_metadata(metadata_path)
|
464 |
+
self.sr = sr
|
465 |
+
self.duration = duration
|
466 |
+
self.return_path = return_path
|
467 |
+
self.return_audio = return_audio
|
468 |
+
|
469 |
+
self.use_dynamic_prompt = prompt_template_path is not None
|
470 |
+
if self.use_dynamic_prompt:
|
471 |
+
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
|
472 |
+
self.tag_types = tag_types
|
473 |
+
|
474 |
+
self.lang = lang
|
475 |
+
self.translate = read_translate(translate)
|
476 |
+
|
477 |
+
def _load_metadata(self, metadata_path):
|
478 |
+
with open(metadata_path) as fp:
|
479 |
+
lines = fp.readlines()
|
480 |
+
self.data = []
|
481 |
+
for line in lines:
|
482 |
+
item = json.loads(line)
|
483 |
+
self.data.append(item)
|
484 |
+
self.is_info_recorded = bool('Tags' in self.data[0])
|
485 |
+
|
486 |
+
def __len__(self):
|
487 |
+
return len(self.data)
|
488 |
+
|
489 |
+
def __getitem__(self, idx):
|
490 |
+
path:str = self.data[idx]["path"]
|
491 |
+
json_path = path[:path.rfind('.')] + ".json"
|
492 |
+
if self.is_info_recorded:
|
493 |
+
item = self.data[idx]
|
494 |
+
else:
|
495 |
+
try:
|
496 |
+
with open(json_path) as fp:
|
497 |
+
item:dict = json.load(fp)
|
498 |
+
except Exception as e:
|
499 |
+
print(f"Error loading json file {json_path} :\n{e}")
|
500 |
+
item = {}
|
501 |
+
description = self.generate_description(item)
|
502 |
+
if self.return_audio:
|
503 |
+
sr, duration = get_sr_and_duration_info(item)
|
504 |
+
audio = self.audio_reader(path, sr, duration)
|
505 |
+
else:
|
506 |
+
audio = None
|
507 |
+
if self.return_path:
|
508 |
+
return audio, description, path
|
509 |
+
return audio, description
|
510 |
+
|
511 |
+
def generate_description(self, item):
|
512 |
+
if self.use_dynamic_prompt:
|
513 |
+
# dynamically generate prompt from given prompt template
|
514 |
+
prompt_template = random.choice(self.prompt_templates)
|
515 |
+
description = self.generate_description_dynamic(item, prompt_template)
|
516 |
+
else:
|
517 |
+
# use ordinary static prompt instead
|
518 |
+
description = self.generate_description_ordinary(item)
|
519 |
+
return description
|
520 |
+
|
521 |
+
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
522 |
+
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
523 |
+
|
524 |
+
if len(exists_tag) > 0:
|
525 |
+
probs = dist_prob_map[len(exists_tag)]
|
526 |
+
tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
|
527 |
+
random.shuffle(exists_tag)
|
528 |
+
tags = exists_tag[:tags_num]
|
529 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
530 |
+
tags_args = self.handle_BPM_tag(tags_args)
|
531 |
+
prompt = prompt_template.apply(**tags_args)
|
532 |
+
else:
|
533 |
+
# no strong tags, use all weak tags instead
|
534 |
+
prompt = prompt_template.apply()
|
535 |
+
|
536 |
+
return prompt
|
537 |
+
|
538 |
+
def tags_to_desc(self, tag_list, tag_type) -> str:
|
539 |
+
if self.lang == 'en':
|
540 |
+
return tags_to_desc(tag_list)
|
541 |
+
elif self.lang == 'zh':
|
542 |
+
if tag_type == 'BPM':
|
543 |
+
return tags_to_desc(tag_list, sep='、')
|
544 |
+
translator = self.translate[tag_type]
|
545 |
+
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
546 |
+
return tags_to_desc(translated_tag_list, sep='、')
|
547 |
+
|
548 |
+
def handle_BPM_tag(self, tags_args):
|
549 |
+
if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
|
550 |
+
bpm = tags_args["BPM"]
|
551 |
+
del tags_args["BPM"]
|
552 |
+
tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
|
553 |
+
for tag_type in tag_types_used:
|
554 |
+
tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
|
555 |
+
return tags_args
|
556 |
+
|
557 |
+
def generate_description_ordinary(self, data, thresh = 0.3):
|
558 |
+
if self.lang != 'en':
|
559 |
+
raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
|
560 |
+
description = f'a piece of music by {data["Artist"]}'
|
561 |
+
|
562 |
+
# Add genre if available
|
563 |
+
if data["Genre"] and random.random() > thresh:
|
564 |
+
genres = ', '.join(data["Genre"])
|
565 |
+
description += f', belonging to the {genres} genres'
|
566 |
+
|
567 |
+
# Add moods if available
|
568 |
+
if data["Tags"] and random.random() > thresh:
|
569 |
+
tags = ', '.join(data["Tags"])
|
570 |
+
description += f'. This track contains the tags:{tags}'
|
571 |
+
|
572 |
+
# Add moods if available
|
573 |
+
if data["Mood"] and random.random() > thresh:
|
574 |
+
moods = ', '.join(data["Mood"])
|
575 |
+
description += f'. This track conveys a {moods} mood.'
|
576 |
+
|
577 |
+
# Add instruments if available
|
578 |
+
if data["Instrument"] and random.random() > thresh:
|
579 |
+
instruments = ', '.join(data["Instrument"])
|
580 |
+
description += f'. and primarily features the following instruments: {instruments}'
|
581 |
+
|
582 |
+
# Add a period to end the description
|
583 |
+
description += '.'
|
584 |
+
|
585 |
+
return description
|
586 |
+
|
587 |
+
def mp3_path_to_id(mp3_path):
|
588 |
+
return int(
|
589 |
+
mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
|
590 |
+
)
|
591 |
+
|
592 |
+
class TmeDataset(Dataset):
|
593 |
+
def __init__(self,
|
594 |
+
data_index:str,
|
595 |
+
music_info:str = None,
|
596 |
+
duration:float = 10,
|
597 |
+
sr:int = 0,
|
598 |
+
return_path = False,
|
599 |
+
return_audio = True,
|
600 |
+
prompt_format_path: os.PathLike = None,
|
601 |
+
tag_types = ['*'],
|
602 |
+
lang = 'zh',
|
603 |
+
translate: Optional[os.PathLike] = None,
|
604 |
+
prompt_dir: os.PathLike = None,
|
605 |
+
):
|
606 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
607 |
+
|
608 |
+
self.sr = sr
|
609 |
+
self.duration = duration
|
610 |
+
self.return_path = return_path
|
611 |
+
self.return_audio = return_audio
|
612 |
+
self.lang = lang
|
613 |
+
|
614 |
+
self.use_ready_prompt = prompt_dir is not None
|
615 |
+
|
616 |
+
data_index = read_jsonlike(data_index)
|
617 |
+
self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
|
618 |
+
self.data_ids = list(self.data_index_dict.keys())
|
619 |
+
|
620 |
+
if not self.use_ready_prompt:
|
621 |
+
#读取音乐的信息文件
|
622 |
+
music_info = read_jsonlike(music_info)
|
623 |
+
if 'music' in music_info:
|
624 |
+
music_info = music_info['music']
|
625 |
+
self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
|
626 |
+
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
|
627 |
+
self.data_ids = list(self.data_index_dict.keys())
|
628 |
+
|
629 |
+
with open(prompt_format_path) as fp:
|
630 |
+
self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
|
631 |
+
|
632 |
+
#加载tag types,并分成一般的tag_types和关键的key_tag_types
|
633 |
+
if '*' in tag_types:
|
634 |
+
self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
|
635 |
+
else:
|
636 |
+
self.tag_types = tag_types
|
637 |
+
|
638 |
+
self.key_tag_types = []
|
639 |
+
if 'tag' in self.tag_types:
|
640 |
+
self.tag_types.remove('tag')
|
641 |
+
self.key_tag_types = list(self.prompt_formats['tag'].keys())
|
642 |
+
|
643 |
+
#加载translate翻译
|
644 |
+
if translate is not None:
|
645 |
+
self.translator = read_jsonlike(translate)
|
646 |
+
else:
|
647 |
+
data_ids_set = set(self.data_ids)
|
648 |
+
self.prompts_dict = {}
|
649 |
+
for fname in os.listdir(prompt_dir):
|
650 |
+
items = read_jsonlike(os.path.join(prompt_dir, fname))
|
651 |
+
for item in items:
|
652 |
+
if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
|
653 |
+
continue
|
654 |
+
if item['ID'] not in self.prompts_dict:
|
655 |
+
self.prompts_dict[item['ID']] = []
|
656 |
+
self.prompts_dict[item['ID']].append(item['Text'])
|
657 |
+
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
|
658 |
+
self.data_ids = list(self.data_index_dict.keys())
|
659 |
+
|
660 |
+
def tags_to_desc(self, tag_list) -> str:
|
661 |
+
if is_bearable(tag_list, int):
|
662 |
+
return str(tag_list)
|
663 |
+
if self.lang == 'zh':
|
664 |
+
return tags_to_desc(tag_list, sep=self.sep)
|
665 |
+
else:
|
666 |
+
translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
|
667 |
+
return tags_to_desc(translated_tag_list, sep=self.sep)
|
668 |
+
|
669 |
+
def gen_desc_of_tag(self, formats, tags):
|
670 |
+
fmt = random.choice(formats)
|
671 |
+
return fmt.format(self.tags_to_desc(tags))
|
672 |
+
|
673 |
+
@staticmethod
|
674 |
+
def check_valid(value):
|
675 |
+
if isinstance(value, int) or isinstance(value, float):
|
676 |
+
return value > 0
|
677 |
+
if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
|
678 |
+
return True
|
679 |
+
return False
|
680 |
+
|
681 |
+
@staticmethod
|
682 |
+
def remove_repeat(data):
|
683 |
+
#若专辑名和歌曲名相同,则只使用后者
|
684 |
+
album_name = data.get('专辑名', None)
|
685 |
+
if album_name is not None and album_name == data.get('歌曲名', None):
|
686 |
+
del data['专辑名']
|
687 |
+
return data
|
688 |
+
|
689 |
+
@property
|
690 |
+
def comma(self):
|
691 |
+
if self.lang == 'zh':
|
692 |
+
return ','
|
693 |
+
elif self.lang == 'en':
|
694 |
+
return ', '
|
695 |
+
|
696 |
+
@property
|
697 |
+
def sep(self):
|
698 |
+
if self.lang == 'zh':
|
699 |
+
return '、'
|
700 |
+
elif self.lang == 'en':
|
701 |
+
return ', '
|
702 |
+
|
703 |
+
def generate_description(self, data):
|
704 |
+
data = self.remove_repeat(data)
|
705 |
+
weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低
|
706 |
+
|
707 |
+
key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个
|
708 |
+
|
709 |
+
prompts = []
|
710 |
+
if len(weak_tags) > 0:
|
711 |
+
probs = dist_prob_map_low[len(weak_tags)]
|
712 |
+
if len(key_tags) > 0:
|
713 |
+
tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
|
714 |
+
else:
|
715 |
+
tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
|
716 |
+
random.shuffle(weak_tags)
|
717 |
+
tags = weak_tags[:tags_num]
|
718 |
+
for tag_type in tags:
|
719 |
+
tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
|
720 |
+
prompts.append(tag_desc)
|
721 |
+
|
722 |
+
if len(key_tags) > 0:
|
723 |
+
probs = dist_prob_map[len(key_tags)]
|
724 |
+
tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
|
725 |
+
random.shuffle(key_tags)
|
726 |
+
tags = key_tags[:tags_num]
|
727 |
+
for tag_type in tags:
|
728 |
+
tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
|
729 |
+
prompts.append(tag_desc)
|
730 |
+
|
731 |
+
random.shuffle(prompts)
|
732 |
+
return self.comma.join(prompts)
|
733 |
+
|
734 |
+
def is_valid_prompt_text(self, text):
|
735 |
+
for bad in ('抱歉','sorry', 'Sorry'):
|
736 |
+
if bad in text:
|
737 |
+
return False
|
738 |
+
return True
|
739 |
+
|
740 |
+
def get_ready_prompt(self, path):
|
741 |
+
sid = mp3_path_to_id(path)
|
742 |
+
return random.choice(self.prompts_dict[sid])
|
743 |
+
|
744 |
+
def __len__(self):
|
745 |
+
return len(self.data_ids)
|
746 |
+
|
747 |
+
def __getitem__(self, idx):
|
748 |
+
data_id = self.data_ids[idx]
|
749 |
+
item = self.data_index_dict[data_id]
|
750 |
+
path = item['path']
|
751 |
+
if not self.use_ready_prompt:
|
752 |
+
info = self.music_info_dict[data_id]
|
753 |
+
description = self.generate_description(info)
|
754 |
+
else:
|
755 |
+
description = self.get_ready_prompt(path)
|
756 |
+
if self.return_audio:
|
757 |
+
sr, duration = get_sr_and_duration_info(item)
|
758 |
+
audio = self.audio_reader(path, sr, duration)
|
759 |
+
else:
|
760 |
+
audio = None
|
761 |
+
if self.return_path:
|
762 |
+
return audio, description, path
|
763 |
+
return audio, description
|
764 |
+
|
765 |
+
class CombinedDataset(Dataset):
|
766 |
+
@beartype
|
767 |
+
def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
|
768 |
+
self.datasets = datasets
|
769 |
+
self.datasets_index = []
|
770 |
+
|
771 |
+
for i,dataset in enumerate(datasets):
|
772 |
+
if dataset is None:
|
773 |
+
continue
|
774 |
+
for dup in range(ratios[i]):
|
775 |
+
for j in range(len(dataset)):
|
776 |
+
self.datasets_index.append((i,j))
|
777 |
+
|
778 |
+
def __len__(self):
|
779 |
+
return len(self.datasets_index)
|
780 |
+
|
781 |
+
def __getitem__(self, idx):
|
782 |
+
index = self.datasets_index[idx]
|
783 |
+
i,j = index
|
784 |
+
return self.datasets[i][j]
|
785 |
+
|
786 |
+
class CombinedDataset_random(Dataset):
|
787 |
+
@beartype
|
788 |
+
def __init__(self,
|
789 |
+
num_examples:int,
|
790 |
+
datasets: Sequence[Dataset], ratios: Sequence[int]
|
791 |
+
):
|
792 |
+
self.datasets = datasets
|
793 |
+
self.datasets_index = []
|
794 |
+
|
795 |
+
for i,dataset in enumerate(datasets):
|
796 |
+
if dataset is None:
|
797 |
+
continue
|
798 |
+
for dup in range(ratios[i]):
|
799 |
+
for j in range(len(dataset)):
|
800 |
+
self.datasets_index.append((i,j))
|
801 |
+
if num_examples > 0:
|
802 |
+
self.random_choose = True
|
803 |
+
self.dataset_len = num_examples
|
804 |
+
else:
|
805 |
+
self.random_choose = False
|
806 |
+
self.dataset_len = len(self.datasets_index)
|
807 |
+
|
808 |
+
def __len__(self):
|
809 |
+
return self.dataset_len
|
810 |
+
|
811 |
+
def __getitem__(self, idx):
|
812 |
+
first_try = True
|
813 |
+
try_cnt = 0
|
814 |
+
while True:
|
815 |
+
try:
|
816 |
+
if(self.random_choose or not first_try):
|
817 |
+
index2 = []
|
818 |
+
index2.append(np.random.randint(0,len(self.datasets)))
|
819 |
+
index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
|
820 |
+
else:
|
821 |
+
index2 = self.datasets_index[idx]
|
822 |
+
first_try = False
|
823 |
+
out = self.datasets[index2[0]][index2[1]]
|
824 |
+
if(len(out[0].shape)==1):out[0]=out[0][None,:]
|
825 |
+
return out
|
826 |
+
except:
|
827 |
+
print("Error loadding ", index2)
|
828 |
+
try_cnt += 1
|
829 |
+
if(try_cnt>10):
|
830 |
+
raise ValueError()
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py
ADDED
@@ -0,0 +1,994 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
|
3 |
+
from beartype import beartype
|
4 |
+
from beartype.door import is_bearable
|
5 |
+
import random
|
6 |
+
import pandas as pd
|
7 |
+
import os
|
8 |
+
from torchaudio.functional import resample
|
9 |
+
import torch
|
10 |
+
import typing as tp
|
11 |
+
from pathlib import Path
|
12 |
+
import torchaudio as ta
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import numpy as np
|
15 |
+
import json
|
16 |
+
import yaml
|
17 |
+
import torchaudio
|
18 |
+
import math
|
19 |
+
import re
|
20 |
+
from loguru import logger
|
21 |
+
|
22 |
+
def gen_plain_prompt(key_list, sep=', '):
|
23 |
+
if len(key_list) == 0:
|
24 |
+
return 'none'
|
25 |
+
|
26 |
+
key_list = [k.strip() for k in key_list]
|
27 |
+
|
28 |
+
if len(key_list) > 10:
|
29 |
+
random.shuffle(key_list)
|
30 |
+
key_list = key_list[:10]
|
31 |
+
|
32 |
+
probs = dist_prob_map[len(key_list)]
|
33 |
+
|
34 |
+
num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
|
35 |
+
|
36 |
+
random.shuffle(key_list)
|
37 |
+
tags = key_list[:num_tags]
|
38 |
+
tags_str = sep.join(tags)
|
39 |
+
return tags_str
|
40 |
+
|
41 |
+
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
42 |
+
|
43 |
+
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
44 |
+
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
self.n_samples = n_samples
|
48 |
+
self.sample_rate = sample_rate
|
49 |
+
self.randomize = randomize
|
50 |
+
self.prob = {"is_start":0.2, "is_end":0.9}
|
51 |
+
self.shift_secs = 5
|
52 |
+
|
53 |
+
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
54 |
+
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
55 |
+
raise ValueError(duration,float(self.n_samples),self.sample_rate)
|
56 |
+
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
57 |
+
t_start = 0.
|
58 |
+
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
59 |
+
offset = 0
|
60 |
+
is_start = True
|
61 |
+
is_end = True
|
62 |
+
else:
|
63 |
+
prob = random.uniform(0,1)
|
64 |
+
if(prob<self.prob['is_start']):
|
65 |
+
is_start = True
|
66 |
+
is_end = False
|
67 |
+
offset = 0
|
68 |
+
elif(prob>self.prob['is_end']):
|
69 |
+
is_start = False
|
70 |
+
is_end = True
|
71 |
+
offset = int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)
|
72 |
+
else:
|
73 |
+
is_start = False
|
74 |
+
is_end = False
|
75 |
+
offset = np.random.randint(self.shift_secs*cur_sample_rate, \
|
76 |
+
int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)-self.shift_secs*cur_sample_rate)
|
77 |
+
t_start = offset / float(cur_sample_rate) / duration
|
78 |
+
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
79 |
+
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
80 |
+
if(chunk.shape[0]>1):
|
81 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
82 |
+
else:
|
83 |
+
chunk = chunk[[0],:].float()
|
84 |
+
if(cur_sample_rate!=self.sample_rate):
|
85 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
86 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
87 |
+
# print('b:',self.sample_rate,chunk.shape)
|
88 |
+
if chunk.shape[-1] != self.n_samples:
|
89 |
+
raise ValueError(chunk.shape, self.n_samples, offset, int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
90 |
+
# if chunk.shape[-1] < self.n_samples:
|
91 |
+
# chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
92 |
+
# else:
|
93 |
+
# chunk = chunk[:,0:self.n_samples]
|
94 |
+
seconds_start = math.floor(offset / cur_sample_rate)
|
95 |
+
seconds_total = math.floor(duration)
|
96 |
+
|
97 |
+
# # In this dataset, we do not introduce zeros
|
98 |
+
# if(is_start):
|
99 |
+
# chunk = torch.cat([torch.zeros(1, self.shift_secs*self.sample_rate), chunk],1)[:,0:self.n_samples]
|
100 |
+
# elif(is_end):
|
101 |
+
# chunk = torch.cat([chunk, torch.zeros(1, self.shift_secs*self.sample_rate)],1)[:,self.shift_secs*self.sample_rate:]
|
102 |
+
|
103 |
+
return (
|
104 |
+
chunk,
|
105 |
+
t_start,
|
106 |
+
t_end,
|
107 |
+
seconds_start,
|
108 |
+
seconds_total,
|
109 |
+
is_start,
|
110 |
+
is_end,
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替
|
115 |
+
if USE_DUMMY_AUDIO:
|
116 |
+
logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
|
117 |
+
|
118 |
+
class SafeAudioReader:
|
119 |
+
"""
|
120 |
+
This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
|
121 |
+
"""
|
122 |
+
def __init__(self,
|
123 |
+
duration: float, # 返回音频长度
|
124 |
+
sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample
|
125 |
+
randomize: bool = True
|
126 |
+
):
|
127 |
+
self.n_samples = int(sample_rate * max(duration, 0))
|
128 |
+
self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
|
129 |
+
|
130 |
+
#NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数!
|
131 |
+
def __call__(self,
|
132 |
+
filepath: os.PathLike, # 音频路径
|
133 |
+
origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取
|
134 |
+
origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取
|
135 |
+
) -> torch.Tensor:
|
136 |
+
if USE_DUMMY_AUDIO:
|
137 |
+
wav = torch.zeros(self.n_samples, dtype=torch.float32)
|
138 |
+
return wav
|
139 |
+
try:
|
140 |
+
# if origin_sample_rate is None or origin_duration is None:
|
141 |
+
# audio_info = torchaudio.info(filepath)
|
142 |
+
# origin_sample_rate = audio_info.sample_rate
|
143 |
+
# origin_duration = audio_info.num_frames / origin_sample_rate
|
144 |
+
audio_info = torchaudio.info(filepath)
|
145 |
+
origin_sample_rate = audio_info.sample_rate
|
146 |
+
origin_duration = audio_info.num_frames / origin_sample_rate
|
147 |
+
wav, *ignored, is_start, is_end = self.reader(filepath, origin_duration, origin_sample_rate)
|
148 |
+
except Exception as e:
|
149 |
+
logger.error(f"Error reading {filepath}: {e}")
|
150 |
+
raise FileNotFoundError(filepath)
|
151 |
+
return wav, is_start, is_end
|
152 |
+
|
153 |
+
|
154 |
+
class PromptTemplate:
|
155 |
+
def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
|
156 |
+
self.template_text = template_text
|
157 |
+
self.tag_map = tag_map
|
158 |
+
self.lang = lang
|
159 |
+
|
160 |
+
@property
|
161 |
+
def tags(self):
|
162 |
+
return tuple(self.tag_map.keys())
|
163 |
+
|
164 |
+
def apply(self, **kwargs):
|
165 |
+
for tag in list(kwargs.keys()):
|
166 |
+
if kwargs[tag] == '':
|
167 |
+
kwargs.pop(tag)
|
168 |
+
for tag in self.tags:
|
169 |
+
if tag in kwargs:
|
170 |
+
kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
|
171 |
+
else:
|
172 |
+
kwargs[tag] = ''
|
173 |
+
prompt = self.template_text.format(**kwargs)
|
174 |
+
|
175 |
+
return self.beautify(prompt)
|
176 |
+
|
177 |
+
def beautify(self, text):
|
178 |
+
if self.lang == 'en':
|
179 |
+
return self._beautify_en(text)
|
180 |
+
elif self.lang == 'zh':
|
181 |
+
return self._beautify_zh(text)
|
182 |
+
else:
|
183 |
+
raise ValueError(f'Unknown language {self.lang}')
|
184 |
+
|
185 |
+
@staticmethod
|
186 |
+
def _beautify_en(text):
|
187 |
+
# no continuous commas without content between them
|
188 |
+
text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
|
189 |
+
# no continuous whitespace
|
190 |
+
text = re.sub(r'\s+', ' ', text)
|
191 |
+
# the comma is NOT followed by whitespace, and should be followed by ONE whitespace
|
192 |
+
text = re.sub(r'\s+,', r',', text)
|
193 |
+
text = re.sub(r',\s+', r', ', text)
|
194 |
+
# no whitespace before the full stop
|
195 |
+
text = re.sub(r'\s+\.', r'.', text)
|
196 |
+
# strip whitespace, comma, and replace ',.'
|
197 |
+
text = text.strip(' ,')
|
198 |
+
text = text.replace(',.', '.')
|
199 |
+
return text
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
def _beautify_zh(text):
|
203 |
+
# no continuous commas without content between them
|
204 |
+
text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
|
205 |
+
text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
|
206 |
+
# assume there should be NO whitespace in Chinese
|
207 |
+
text = re.sub(r'\s+', r'', text)
|
208 |
+
# strip whitespace, comma, and replace ',。'
|
209 |
+
text = text.strip(', 、')
|
210 |
+
text = text.replace(',。', '。')
|
211 |
+
return text
|
212 |
+
|
213 |
+
def __repr__(self):
|
214 |
+
return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
|
215 |
+
|
216 |
+
__str__ = __repr__
|
217 |
+
|
218 |
+
def parse_prompt_template(prompt_template_text, lang='en'):
|
219 |
+
span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
|
220 |
+
tag_pattern = re.compile(r'{.+?}', re.DOTALL)
|
221 |
+
|
222 |
+
template_text = prompt_template_text.strip()
|
223 |
+
span_texts = span_pattern.findall(prompt_template_text)
|
224 |
+
tag_map = {}
|
225 |
+
for span_text in span_texts:
|
226 |
+
tag = tag_pattern.findall(span_text)[0].strip('{}')
|
227 |
+
tag_map[tag] = span_text
|
228 |
+
template_text = template_text.replace(span_text, '{'+tag+'}')
|
229 |
+
|
230 |
+
return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
|
231 |
+
|
232 |
+
def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
|
233 |
+
with open(path, 'r') as f:
|
234 |
+
lines = f.readlines()
|
235 |
+
cnt = 0
|
236 |
+
pts = []
|
237 |
+
for line in lines:
|
238 |
+
pt = parse_prompt_template(line, lang=lang)
|
239 |
+
cnt += 1
|
240 |
+
if len(pt.tags) < num:
|
241 |
+
logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
|
242 |
+
pts.append(pt)
|
243 |
+
|
244 |
+
return pts
|
245 |
+
|
246 |
+
|
247 |
+
def get_base_dir_file(key: os.PathLike):
|
248 |
+
base = os.path.basename(key)
|
249 |
+
dirname = os.path.basename(os.path.dirname(key))
|
250 |
+
return os.path.join(dirname, base)
|
251 |
+
|
252 |
+
def read_jsonlike(path: os.PathLike):
|
253 |
+
#json or jsonl
|
254 |
+
if str(path).endswith(".json"):
|
255 |
+
with open(path, 'r', encoding='utf8') as f:
|
256 |
+
data = json.load(f)
|
257 |
+
return data
|
258 |
+
elif str(path).endswith(".jsonl"):
|
259 |
+
with open(path, 'r', encoding='utf8') as f:
|
260 |
+
data = [json.loads(line) for line in f.readlines()]
|
261 |
+
return data
|
262 |
+
else:
|
263 |
+
raise ValueError("Unknown file format")
|
264 |
+
|
265 |
+
dist_prob_map = {
|
266 |
+
1: (1.0,),
|
267 |
+
2: (0.5, 0.5),
|
268 |
+
3: (0.3, 0.4, 0.3),
|
269 |
+
4: (0.2, 0.3, 0.3, 0.2),
|
270 |
+
5: (0.2, 0.2, 0.3, 0.2, 0.1),
|
271 |
+
6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
|
272 |
+
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
273 |
+
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
274 |
+
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
275 |
+
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
276 |
+
}
|
277 |
+
|
278 |
+
dist_prob_map_low = {
|
279 |
+
1: (1.0,),
|
280 |
+
2: (0.8, 0.2),
|
281 |
+
3: (0.8, 0.1, 0.1),
|
282 |
+
4: (0.7, 0.1, 0.1, 0.1),
|
283 |
+
5: (0.7, 0.1, 0.1, 0.05, 0.05),
|
284 |
+
6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
|
285 |
+
}
|
286 |
+
|
287 |
+
_bpm_range_rights = (
|
288 |
+
(40, '20-40'),
|
289 |
+
(60, '40-60'),
|
290 |
+
(66, '60-66'),
|
291 |
+
(76, '66-76'),
|
292 |
+
(108, '76-108'),
|
293 |
+
(120, '108-120'),
|
294 |
+
(168, '120-168'),
|
295 |
+
(176, '168-176'),
|
296 |
+
(200, '176-200')
|
297 |
+
)
|
298 |
+
_bpm_desc_map = {
|
299 |
+
'20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
|
300 |
+
'40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
|
301 |
+
'60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
|
302 |
+
'66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
|
303 |
+
'76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
|
304 |
+
'108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
|
305 |
+
'120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
|
306 |
+
'168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
|
307 |
+
'176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
|
308 |
+
'>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
|
309 |
+
}
|
310 |
+
_bpm_desc_map_zh = {
|
311 |
+
'20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
|
312 |
+
'40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
|
313 |
+
'60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
|
314 |
+
'66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
|
315 |
+
'76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
|
316 |
+
'108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
|
317 |
+
'120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
|
318 |
+
'168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
|
319 |
+
'176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
|
320 |
+
'>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
|
321 |
+
}
|
322 |
+
def get_bpm_range(bpm):
|
323 |
+
bpm = int(bpm)
|
324 |
+
for right, tag in _bpm_range_rights:
|
325 |
+
if bpm <= right:
|
326 |
+
return tag
|
327 |
+
return '>200'
|
328 |
+
|
329 |
+
def gen_bpm_descript(bpm, lang='en'):
|
330 |
+
bpm_range = get_bpm_range(bpm)
|
331 |
+
if lang == 'en':
|
332 |
+
return random.choice(_bpm_desc_map[bpm_range])
|
333 |
+
elif lang == 'zh':
|
334 |
+
return random.choice(_bpm_desc_map_zh[bpm_range])
|
335 |
+
else:
|
336 |
+
raise ValueError(f"Unknown language {lang}")
|
337 |
+
|
338 |
+
def read_translate(translate: Optional[Dict[str, os.PathLike]]):
|
339 |
+
if translate is None:
|
340 |
+
return None
|
341 |
+
if isinstance(translate, str):
|
342 |
+
return read_jsonlike(translate)
|
343 |
+
return {k: read_jsonlike(path) for k, path in translate.items()}
|
344 |
+
|
345 |
+
|
346 |
+
class MagnaTagATuneDataset(Dataset):
|
347 |
+
def __init__(self):
|
348 |
+
pass
|
349 |
+
|
350 |
+
|
351 |
+
def tags_to_desc(tag_list, sep=',') -> str:
|
352 |
+
if not isinstance(tag_list, Sequence):
|
353 |
+
return str(tag_list)
|
354 |
+
if isinstance(tag_list, str):
|
355 |
+
return tag_list
|
356 |
+
if len(tag_list) <= 0:
|
357 |
+
return ''
|
358 |
+
elif len(tag_list) <= 5:
|
359 |
+
probs = dist_prob_map[len(tag_list)]
|
360 |
+
tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
|
361 |
+
random.shuffle(tag_list)
|
362 |
+
tag_list = tag_list[:tags_num]
|
363 |
+
return sep.join(tag_list)
|
364 |
+
else:
|
365 |
+
probs = dist_prob_map[5]
|
366 |
+
tags_num = random.choices(range(1, 6), probs)[0]
|
367 |
+
random.shuffle(tag_list)
|
368 |
+
tag_list = tag_list[:tags_num]
|
369 |
+
return sep.join(tag_list)
|
370 |
+
|
371 |
+
def get_sr_and_duration_info(item):
|
372 |
+
return item.get('sample_rate', None), item.get('duration', None)
|
373 |
+
|
374 |
+
class MtgJamendoDatasetFromJson(Dataset):
|
375 |
+
def __init__(self,
|
376 |
+
data_dir:str,
|
377 |
+
json_path:str,
|
378 |
+
duration:float=10,
|
379 |
+
sr:int = 0,
|
380 |
+
*,
|
381 |
+
lang = 'en',
|
382 |
+
return_path = False,
|
383 |
+
prompt_template_path: os.PathLike = None,
|
384 |
+
tag_types = [],
|
385 |
+
translate:Optional[Dict[str, os.PathLike]] = None,
|
386 |
+
):
|
387 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
388 |
+
|
389 |
+
self.data_dir = data_dir
|
390 |
+
self._load_metadata_json(json_path)
|
391 |
+
self.sr = sr
|
392 |
+
self.duration = duration
|
393 |
+
self.return_path = return_path
|
394 |
+
self.lang = lang
|
395 |
+
|
396 |
+
self.use_dynamic_prompt = prompt_template_path is not None
|
397 |
+
if self.use_dynamic_prompt:
|
398 |
+
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
|
399 |
+
self.tag_types = tag_types
|
400 |
+
|
401 |
+
self.translate = read_translate(translate)
|
402 |
+
if not self.use_dynamic_prompt and self.lang != 'en':
|
403 |
+
raise NotImplementedError
|
404 |
+
|
405 |
+
#这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示
|
406 |
+
WEAK_TAG_LIST = ["title", "artist"]
|
407 |
+
|
408 |
+
def _load_metadata_json(self, json_path):
|
409 |
+
with open(json_path) as fp:
|
410 |
+
self.data = json.load(fp)
|
411 |
+
|
412 |
+
def convert_key_to_path(self, key):
|
413 |
+
return os.path.join(self.data_dir, get_base_dir_file(key))
|
414 |
+
|
415 |
+
def __len__(self):
|
416 |
+
return len(self.data)
|
417 |
+
|
418 |
+
def __getitem__(self, idx):
|
419 |
+
item = self.data[idx]
|
420 |
+
path = self.convert_key_to_path(item['key'])
|
421 |
+
description = self.generate_description(item)
|
422 |
+
|
423 |
+
sr, duration = get_sr_and_duration_info(item)
|
424 |
+
audio, is_start, is_end = self.audio_reader(path, sr, duration)
|
425 |
+
|
426 |
+
if self.return_path:
|
427 |
+
return audio, description, path
|
428 |
+
return audio, description, is_start, is_end
|
429 |
+
|
430 |
+
def tags_to_desc(self, tag_list, tag_type) -> str:
|
431 |
+
if self.lang == 'en':
|
432 |
+
return tags_to_desc(tag_list)
|
433 |
+
elif self.lang == 'zh':
|
434 |
+
translator = self.translate[tag_type]
|
435 |
+
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
436 |
+
return tags_to_desc(translated_tag_list, sep='、')
|
437 |
+
|
438 |
+
def generate_description(self, item):
|
439 |
+
if self.use_dynamic_prompt:
|
440 |
+
# dynamically generate prompt from given prompt template
|
441 |
+
prompt_template = random.choice(self.prompt_templates)
|
442 |
+
description = self.generate_description_dynamic(item, prompt_template)
|
443 |
+
|
444 |
+
else:
|
445 |
+
# use ordinary static prompt instead
|
446 |
+
description = self.generate_description_ordinary(item)
|
447 |
+
return description
|
448 |
+
|
449 |
+
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
450 |
+
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
451 |
+
exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
|
452 |
+
exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
|
453 |
+
|
454 |
+
if len(exists_strong_tag) > 0:
|
455 |
+
probs = dist_prob_map[len(exists_strong_tag)]
|
456 |
+
tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
|
457 |
+
random.shuffle(exists_strong_tag)
|
458 |
+
tags = exists_strong_tag[:tags_num]
|
459 |
+
weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
|
460 |
+
weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
|
461 |
+
random.shuffle(exists_weak_tag)
|
462 |
+
weak_tags = exists_weak_tag[:weak_tags_num]
|
463 |
+
tags += weak_tags
|
464 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
465 |
+
prompt = prompt_template.apply(**tags_args)
|
466 |
+
else:
|
467 |
+
# no strong tags, use all weak tags instead
|
468 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
|
469 |
+
prompt = prompt_template.apply(**tags_args)
|
470 |
+
|
471 |
+
return prompt
|
472 |
+
|
473 |
+
def generate_description_ordinary(self, data, thresh = 0.3):
|
474 |
+
# Initialize the description with title and artist
|
475 |
+
description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}'
|
476 |
+
|
477 |
+
# Add genre if available
|
478 |
+
if data["genre"] and random.random() > thresh:
|
479 |
+
genres = ', '.join(data["genre"])
|
480 |
+
description += f', belonging to the {genres} genres'
|
481 |
+
|
482 |
+
# Add moods if available
|
483 |
+
if data["moods"] and random.random() > thresh:
|
484 |
+
moods = ', '.join(data["moods"])
|
485 |
+
description += f'. This track conveys a {moods} mood'
|
486 |
+
|
487 |
+
# Add instruments if available
|
488 |
+
if data["instrument"] and random.random() > thresh:
|
489 |
+
instruments = ', '.join(data["instrument"])
|
490 |
+
description += f', and primarily features the following instruments: {instruments}'
|
491 |
+
|
492 |
+
# Add a period to end the description
|
493 |
+
description += '.'
|
494 |
+
|
495 |
+
return description
|
496 |
+
|
497 |
+
class AudioStockDataset(Dataset):
|
498 |
+
def __init__(self,
|
499 |
+
metadata_path:str,
|
500 |
+
duration:float=10,
|
501 |
+
sr:int = 0,
|
502 |
+
return_path = False,
|
503 |
+
return_audio = True,
|
504 |
+
prompt_template_path: os.PathLike = None,
|
505 |
+
tag_types = [],
|
506 |
+
lang = 'en',
|
507 |
+
translate:Optional[Dict[str, os.PathLike]] = None
|
508 |
+
):
|
509 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
510 |
+
|
511 |
+
self.duration = duration
|
512 |
+
self._load_metadata(metadata_path)
|
513 |
+
self.sr = sr
|
514 |
+
self.return_path = return_path
|
515 |
+
self.return_audio = return_audio
|
516 |
+
|
517 |
+
self.use_dynamic_prompt = prompt_template_path is not None
|
518 |
+
if self.use_dynamic_prompt:
|
519 |
+
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
|
520 |
+
self.tag_types = tag_types
|
521 |
+
|
522 |
+
self.lang = lang
|
523 |
+
self.translate = read_translate(translate)
|
524 |
+
|
525 |
+
def _load_metadata(self, metadata_path):
|
526 |
+
with open(metadata_path) as fp:
|
527 |
+
lines = fp.readlines()
|
528 |
+
self.data = []
|
529 |
+
for line in lines:
|
530 |
+
item = json.loads(line)
|
531 |
+
if(item['duration']>self.duration+10):
|
532 |
+
self.data.append(item)
|
533 |
+
self.is_info_recorded = bool('Tags' in self.data[0])
|
534 |
+
|
535 |
+
def __len__(self):
|
536 |
+
return len(self.data)
|
537 |
+
|
538 |
+
def __getitem__(self, idx):
|
539 |
+
path:str = self.data[idx]["path"]
|
540 |
+
json_path = path[:path.rfind('.')] + ".json"
|
541 |
+
if self.is_info_recorded:
|
542 |
+
item = self.data[idx]
|
543 |
+
else:
|
544 |
+
try:
|
545 |
+
with open(json_path) as fp:
|
546 |
+
item:dict = json.load(fp)
|
547 |
+
except Exception as e:
|
548 |
+
print(f"Error loading json file {json_path} :\n{e}")
|
549 |
+
item = {}
|
550 |
+
description = self.generate_description(item)
|
551 |
+
if self.return_audio:
|
552 |
+
sr, duration = get_sr_and_duration_info(item)
|
553 |
+
audio, is_start, is_end = self.audio_reader(path, sr, duration)
|
554 |
+
else:
|
555 |
+
audio = None
|
556 |
+
if self.return_path:
|
557 |
+
return audio, description, path, is_start, is_end
|
558 |
+
else:
|
559 |
+
return audio, description, is_start, is_end
|
560 |
+
|
561 |
+
def generate_description(self, item):
|
562 |
+
if self.use_dynamic_prompt:
|
563 |
+
# dynamically generate prompt from given prompt template
|
564 |
+
prompt_template = random.choice(self.prompt_templates)
|
565 |
+
description = self.generate_description_dynamic(item, prompt_template)
|
566 |
+
else:
|
567 |
+
# use ordinary static prompt instead
|
568 |
+
description = self.generate_description_ordinary(item)
|
569 |
+
return description
|
570 |
+
|
571 |
+
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
572 |
+
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
573 |
+
|
574 |
+
if len(exists_tag) > 0:
|
575 |
+
probs = dist_prob_map[len(exists_tag)]
|
576 |
+
tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
|
577 |
+
random.shuffle(exists_tag)
|
578 |
+
tags = exists_tag[:tags_num]
|
579 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
580 |
+
tags_args = self.handle_BPM_tag(tags_args)
|
581 |
+
prompt = prompt_template.apply(**tags_args)
|
582 |
+
else:
|
583 |
+
# no strong tags, use all weak tags instead
|
584 |
+
prompt = prompt_template.apply()
|
585 |
+
|
586 |
+
return prompt
|
587 |
+
|
588 |
+
def tags_to_desc(self, tag_list, tag_type) -> str:
|
589 |
+
if self.lang == 'en':
|
590 |
+
return tags_to_desc(tag_list)
|
591 |
+
elif self.lang == 'zh':
|
592 |
+
if tag_type == 'BPM':
|
593 |
+
return tags_to_desc(tag_list, sep='、')
|
594 |
+
translator = self.translate[tag_type]
|
595 |
+
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
596 |
+
return tags_to_desc(translated_tag_list, sep='、')
|
597 |
+
|
598 |
+
def handle_BPM_tag(self, tags_args):
|
599 |
+
if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
|
600 |
+
bpm = tags_args["BPM"]
|
601 |
+
del tags_args["BPM"]
|
602 |
+
tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
|
603 |
+
for tag_type in tag_types_used:
|
604 |
+
tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
|
605 |
+
return tags_args
|
606 |
+
|
607 |
+
def generate_description_ordinary(self, data, thresh = 0.3):
|
608 |
+
if self.lang != 'en':
|
609 |
+
raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
|
610 |
+
description = f'a piece of music by {data["Artist"]}'
|
611 |
+
|
612 |
+
# Add genre if available
|
613 |
+
if data["Genre"] and random.random() > thresh:
|
614 |
+
genres = ', '.join(data["Genre"])
|
615 |
+
description += f', belonging to the {genres} genres'
|
616 |
+
|
617 |
+
# Add moods if available
|
618 |
+
if data["Tags"] and random.random() > thresh:
|
619 |
+
tags = ', '.join(data["Tags"])
|
620 |
+
description += f'. This track contains the tags:{tags}'
|
621 |
+
|
622 |
+
# Add moods if available
|
623 |
+
if data["Mood"] and random.random() > thresh:
|
624 |
+
moods = ', '.join(data["Mood"])
|
625 |
+
description += f'. This track conveys a {moods} mood.'
|
626 |
+
|
627 |
+
# Add instruments if available
|
628 |
+
if data["Instrument"] and random.random() > thresh:
|
629 |
+
instruments = ', '.join(data["Instrument"])
|
630 |
+
description += f'. and primarily features the following instruments: {instruments}'
|
631 |
+
|
632 |
+
# Add a period to end the description
|
633 |
+
description += '.'
|
634 |
+
|
635 |
+
return description
|
636 |
+
|
637 |
+
def mp3_path_to_id(mp3_path):
|
638 |
+
return int(
|
639 |
+
mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
|
640 |
+
)
|
641 |
+
|
642 |
+
class TmeDataset(Dataset):
|
643 |
+
def __init__(self,
|
644 |
+
data_index:str,
|
645 |
+
music_info:str = None,
|
646 |
+
duration:float = 10,
|
647 |
+
sr:int = 0,
|
648 |
+
return_path = False,
|
649 |
+
return_audio = True,
|
650 |
+
prompt_format_path: os.PathLike = None,
|
651 |
+
tag_types = ['*'],
|
652 |
+
lang = 'zh',
|
653 |
+
translate: Optional[os.PathLike] = None,
|
654 |
+
prompt_dir: os.PathLike = None,
|
655 |
+
):
|
656 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
657 |
+
|
658 |
+
self.sr = sr
|
659 |
+
self.duration = duration
|
660 |
+
self.return_path = return_path
|
661 |
+
self.return_audio = return_audio
|
662 |
+
self.lang = lang
|
663 |
+
|
664 |
+
self.use_ready_prompt = prompt_dir is not None
|
665 |
+
|
666 |
+
data_index = read_jsonlike(data_index)
|
667 |
+
data_index = [d for d in data_index if d['duration']>self.duration+10]
|
668 |
+
self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
|
669 |
+
self.data_ids = list(self.data_index_dict.keys())
|
670 |
+
|
671 |
+
if not self.use_ready_prompt:
|
672 |
+
#读取音乐的信息文件
|
673 |
+
music_info = read_jsonlike(music_info)
|
674 |
+
if 'music' in music_info:
|
675 |
+
music_info = music_info['music']
|
676 |
+
self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
|
677 |
+
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
|
678 |
+
self.data_ids = list(self.data_index_dict.keys())
|
679 |
+
|
680 |
+
with open(prompt_format_path) as fp:
|
681 |
+
self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
|
682 |
+
|
683 |
+
#加载tag types,并分成一般的tag_types和关键的key_tag_types
|
684 |
+
if '*' in tag_types:
|
685 |
+
self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
|
686 |
+
else:
|
687 |
+
self.tag_types = tag_types
|
688 |
+
|
689 |
+
self.key_tag_types = []
|
690 |
+
if 'tag' in self.tag_types:
|
691 |
+
self.tag_types.remove('tag')
|
692 |
+
self.key_tag_types = list(self.prompt_formats['tag'].keys())
|
693 |
+
|
694 |
+
#加载translate翻译
|
695 |
+
if translate is not None:
|
696 |
+
self.translator = read_jsonlike(translate)
|
697 |
+
else:
|
698 |
+
data_ids_set = set(self.data_ids)
|
699 |
+
self.prompts_dict = {}
|
700 |
+
for fname in os.listdir(prompt_dir):
|
701 |
+
items = read_jsonlike(os.path.join(prompt_dir, fname))
|
702 |
+
for item in items:
|
703 |
+
if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
|
704 |
+
continue
|
705 |
+
if item['ID'] not in self.prompts_dict:
|
706 |
+
self.prompts_dict[item['ID']] = []
|
707 |
+
self.prompts_dict[item['ID']].append(item['Text'])
|
708 |
+
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
|
709 |
+
self.data_ids = list(self.data_index_dict.keys())
|
710 |
+
|
711 |
+
def tags_to_desc(self, tag_list) -> str:
|
712 |
+
if is_bearable(tag_list, int):
|
713 |
+
return str(tag_list)
|
714 |
+
if self.lang == 'zh':
|
715 |
+
return tags_to_desc(tag_list, sep=self.sep)
|
716 |
+
else:
|
717 |
+
translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
|
718 |
+
return tags_to_desc(translated_tag_list, sep=self.sep)
|
719 |
+
|
720 |
+
def gen_desc_of_tag(self, formats, tags):
|
721 |
+
fmt = random.choice(formats)
|
722 |
+
return fmt.format(self.tags_to_desc(tags))
|
723 |
+
|
724 |
+
@staticmethod
|
725 |
+
def check_valid(value):
|
726 |
+
if isinstance(value, int) or isinstance(value, float):
|
727 |
+
return value > 0
|
728 |
+
if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
|
729 |
+
return True
|
730 |
+
return False
|
731 |
+
|
732 |
+
@staticmethod
|
733 |
+
def remove_repeat(data):
|
734 |
+
#若专辑名和歌曲名相同,则只使用后者
|
735 |
+
album_name = data.get('专辑名', None)
|
736 |
+
if album_name is not None and album_name == data.get('歌曲名', None):
|
737 |
+
del data['专辑名']
|
738 |
+
return data
|
739 |
+
|
740 |
+
@property
|
741 |
+
def comma(self):
|
742 |
+
if self.lang == 'zh':
|
743 |
+
return ','
|
744 |
+
elif self.lang == 'en':
|
745 |
+
return ', '
|
746 |
+
|
747 |
+
@property
|
748 |
+
def sep(self):
|
749 |
+
if self.lang == 'zh':
|
750 |
+
return '、'
|
751 |
+
elif self.lang == 'en':
|
752 |
+
return ', '
|
753 |
+
|
754 |
+
def generate_description(self, data):
|
755 |
+
data = self.remove_repeat(data)
|
756 |
+
weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低
|
757 |
+
|
758 |
+
key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个
|
759 |
+
|
760 |
+
prompts = []
|
761 |
+
if len(weak_tags) > 0:
|
762 |
+
probs = dist_prob_map_low[len(weak_tags)]
|
763 |
+
if len(key_tags) > 0:
|
764 |
+
tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
|
765 |
+
else:
|
766 |
+
tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
|
767 |
+
random.shuffle(weak_tags)
|
768 |
+
tags = weak_tags[:tags_num]
|
769 |
+
for tag_type in tags:
|
770 |
+
tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
|
771 |
+
prompts.append(tag_desc)
|
772 |
+
|
773 |
+
if len(key_tags) > 0:
|
774 |
+
probs = dist_prob_map[len(key_tags)]
|
775 |
+
tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
|
776 |
+
random.shuffle(key_tags)
|
777 |
+
tags = key_tags[:tags_num]
|
778 |
+
for tag_type in tags:
|
779 |
+
tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
|
780 |
+
prompts.append(tag_desc)
|
781 |
+
|
782 |
+
random.shuffle(prompts)
|
783 |
+
return self.comma.join(prompts)
|
784 |
+
|
785 |
+
def is_valid_prompt_text(self, text):
|
786 |
+
for bad in ('抱歉','sorry', 'Sorry'):
|
787 |
+
if bad in text:
|
788 |
+
return False
|
789 |
+
return True
|
790 |
+
|
791 |
+
def get_ready_prompt(self, path):
|
792 |
+
sid = mp3_path_to_id(path)
|
793 |
+
return random.choice(self.prompts_dict[sid])
|
794 |
+
|
795 |
+
def __len__(self):
|
796 |
+
return len(self.data_ids)
|
797 |
+
|
798 |
+
def __getitem__(self, idx):
|
799 |
+
data_id = self.data_ids[idx]
|
800 |
+
item = self.data_index_dict[data_id]
|
801 |
+
path = item['path']
|
802 |
+
if not self.use_ready_prompt:
|
803 |
+
info = self.music_info_dict[data_id]
|
804 |
+
description = self.generate_description(info)
|
805 |
+
else:
|
806 |
+
description = self.get_ready_prompt(path)
|
807 |
+
if self.return_audio:
|
808 |
+
sr, duration = get_sr_and_duration_info(item)
|
809 |
+
audio, is_start, is_end = self.audio_reader(path, sr, duration)
|
810 |
+
else:
|
811 |
+
audio = None
|
812 |
+
if self.return_path:
|
813 |
+
return audio, description, path, is_start, is_end
|
814 |
+
else:
|
815 |
+
return audio, description, is_start, is_end
|
816 |
+
|
817 |
+
class Pond5Dataset(Dataset):
|
818 |
+
MAX_PROMPT_LEN = 200
|
819 |
+
def __init__(self,
|
820 |
+
metadata_path:str,
|
821 |
+
index_path:str,
|
822 |
+
duration:float=10,
|
823 |
+
sr:int = 0,
|
824 |
+
plain_rate = 0,
|
825 |
+
return_path = False,
|
826 |
+
return_audio = True,
|
827 |
+
lang = 'en',
|
828 |
+
translate:Optional[Dict[str, os.PathLike]] = None,
|
829 |
+
use_literal_none = True,
|
830 |
+
use_avoid_watermark_policy = None,
|
831 |
+
):
|
832 |
+
|
833 |
+
if use_avoid_watermark_policy is None:
|
834 |
+
raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type")
|
835 |
+
self.use_avoid_watermark_policy = use_avoid_watermark_policy
|
836 |
+
assert self.use_avoid_watermark_policy is False
|
837 |
+
self.audio_reader = SafeAudioReader(duration, sr)
|
838 |
+
|
839 |
+
self.duration = duration
|
840 |
+
self._load_metadata(metadata_path, index_path)
|
841 |
+
self.sr = sr
|
842 |
+
self.plain_rate = plain_rate
|
843 |
+
self.return_path = return_path
|
844 |
+
self.return_audio = return_audio
|
845 |
+
self.use_literal_none = use_literal_none
|
846 |
+
|
847 |
+
self.lang = lang
|
848 |
+
self.translate = read_translate(translate)
|
849 |
+
|
850 |
+
def _load_metadata(self, metadata_path, index_path):
|
851 |
+
data_index = read_jsonlike(index_path)
|
852 |
+
data_ids = set([item['id'] for item in data_index])
|
853 |
+
|
854 |
+
with open(metadata_path) as fp:
|
855 |
+
lines = fp.readlines()
|
856 |
+
|
857 |
+
append_ids = set()
|
858 |
+
|
859 |
+
self.data = []
|
860 |
+
for line in lines:
|
861 |
+
item = json.loads(line)
|
862 |
+
if item['id'] in data_ids and item['id'] not in append_ids and item["details"]["duration"] is not None and item["details"]["duration"]>self.duration+10:
|
863 |
+
self.data.append(item)
|
864 |
+
append_ids.add(item['id'])
|
865 |
+
|
866 |
+
def __len__(self):
|
867 |
+
return len(self.data)
|
868 |
+
|
869 |
+
def __getitem__(self, idx):
|
870 |
+
item = self.data[idx]
|
871 |
+
path:str = item["path"]
|
872 |
+
description = self.generate_description(item)
|
873 |
+
if self.return_audio:
|
874 |
+
sr, duration = get_sr_and_duration_info(item)
|
875 |
+
audio, is_start, is_end = self.audio_reader(path, sr, duration)
|
876 |
+
else:
|
877 |
+
audio = None
|
878 |
+
if self.return_path:
|
879 |
+
return audio, description, path
|
880 |
+
return audio, description, is_start, is_end
|
881 |
+
|
882 |
+
@property
|
883 |
+
def keysep(self):
|
884 |
+
if self.lang == 'zh':
|
885 |
+
return ',' if random.random() > 0.5 else '、'
|
886 |
+
elif self.lang == 'en':
|
887 |
+
return ', '
|
888 |
+
|
889 |
+
def generate_description(self, item):
|
890 |
+
if random.random() > self.plain_rate:
|
891 |
+
# dynamically generate prompt from given prompt template
|
892 |
+
description = self.generate_description_dynamic(item)
|
893 |
+
else:
|
894 |
+
# use plain prompt, i.e. tags sequence separated by comma
|
895 |
+
description = self.generate_description_plain(item)
|
896 |
+
return description
|
897 |
+
|
898 |
+
def get_translation(self, k):
|
899 |
+
k = k.strip()
|
900 |
+
if k in self.translate:
|
901 |
+
return self.translate[k]
|
902 |
+
else:
|
903 |
+
return k
|
904 |
+
|
905 |
+
def generate_description_plain(self, item):
|
906 |
+
keywords = item['keywords']
|
907 |
+
if self.lang != 'en':
|
908 |
+
keywords = [self.get_translation(k) for k in keywords]
|
909 |
+
return gen_plain_prompt(keywords, sep=self.keysep)
|
910 |
+
|
911 |
+
def generate_description_dynamic(self,item):
|
912 |
+
desc = item.get('desc', 'none')
|
913 |
+
if desc is None:
|
914 |
+
desc = 'none'
|
915 |
+
desc = desc.strip()
|
916 |
+
if len(desc) > self.MAX_PROMPT_LEN:
|
917 |
+
shorter_desc = desc[:self.MAX_PROMPT_LEN]
|
918 |
+
# find last stop
|
919 |
+
stop_idx = shorter_desc.rfind('.')
|
920 |
+
if stop_idx == -1:
|
921 |
+
stop_idx = shorter_desc.rfind('!')
|
922 |
+
if stop_idx == -1:
|
923 |
+
stop_idx = shorter_desc.rfind(',')
|
924 |
+
if stop_idx == -1:
|
925 |
+
stop_idx = self.MAX_PROMPT_LEN - 1
|
926 |
+
desc = desc[:stop_idx+1]
|
927 |
+
return desc
|
928 |
+
|
929 |
+
class CombinedDataset(Dataset):
|
930 |
+
@beartype
|
931 |
+
def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
|
932 |
+
self.datasets = datasets
|
933 |
+
self.datasets_index = []
|
934 |
+
|
935 |
+
for i,dataset in enumerate(datasets):
|
936 |
+
if dataset is None:
|
937 |
+
continue
|
938 |
+
for dup in range(ratios[i]):
|
939 |
+
for j in range(len(dataset)):
|
940 |
+
self.datasets_index.append((i,j))
|
941 |
+
|
942 |
+
def __len__(self):
|
943 |
+
return len(self.datasets_index)
|
944 |
+
|
945 |
+
def __getitem__(self, idx):
|
946 |
+
index = self.datasets_index[idx]
|
947 |
+
i,j = index
|
948 |
+
return self.datasets[i][j]
|
949 |
+
|
950 |
+
class CombinedDataset_random(Dataset):
|
951 |
+
@beartype
|
952 |
+
def __init__(self,
|
953 |
+
num_examples:int,
|
954 |
+
datasets: Sequence[Dataset], ratios: Sequence[int]
|
955 |
+
):
|
956 |
+
self.datasets = datasets
|
957 |
+
self.datasets_index = []
|
958 |
+
|
959 |
+
for i,dataset in enumerate(datasets):
|
960 |
+
if dataset is None:
|
961 |
+
continue
|
962 |
+
for dup in range(ratios[i]):
|
963 |
+
for j in range(len(dataset)):
|
964 |
+
self.datasets_index.append((i,j))
|
965 |
+
if num_examples > 0:
|
966 |
+
self.random_choose = True
|
967 |
+
self.dataset_len = num_examples
|
968 |
+
else:
|
969 |
+
self.random_choose = False
|
970 |
+
self.dataset_len = len(self.datasets_index)
|
971 |
+
|
972 |
+
def __len__(self):
|
973 |
+
return self.dataset_len
|
974 |
+
|
975 |
+
def __getitem__(self, idx):
|
976 |
+
first_try = True
|
977 |
+
try_cnt = 0
|
978 |
+
while True:
|
979 |
+
try:
|
980 |
+
if(self.random_choose or not first_try):
|
981 |
+
index2 = []
|
982 |
+
index2.append(np.random.randint(0,len(self.datasets)))
|
983 |
+
index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
|
984 |
+
else:
|
985 |
+
index2 = self.datasets_index[idx]
|
986 |
+
first_try = False
|
987 |
+
out = self.datasets[index2[0]][index2[1]]
|
988 |
+
if(len(out[0].shape)==1):out[0]=out[0][None,:]
|
989 |
+
return out
|
990 |
+
except:
|
991 |
+
print("Error loadding ", index2)
|
992 |
+
try_cnt += 1
|
993 |
+
if(try_cnt>10):
|
994 |
+
raise FileNotFoundError()
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
import torchaudio
|
7 |
+
from torchaudio.functional import resample
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torch.nn.utils.rnn import pad_sequence
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def check_lryics(lyric):
|
16 |
+
_FILTER_STRING = [
|
17 |
+
'作词', '作曲', '编曲', '【', '策划',
|
18 |
+
'录音', '混音', '母带', ':', '制作',
|
19 |
+
'版权', '校对', '演奏', '制作', '伴奏'
|
20 |
+
]
|
21 |
+
for item in _FILTER_STRING:
|
22 |
+
if item in lyric:
|
23 |
+
return True
|
24 |
+
|
25 |
+
return False
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
def process_lyrics(lines):
|
30 |
+
lyric_part = []
|
31 |
+
timestamp_part = []
|
32 |
+
|
33 |
+
timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
|
34 |
+
|
35 |
+
for i, line in enumerate(lines):
|
36 |
+
|
37 |
+
# 删除前几行的特定信息
|
38 |
+
if i<10 and check_lryics(line):
|
39 |
+
continue
|
40 |
+
|
41 |
+
# 检查是否包含有效的时间戳和歌词内容
|
42 |
+
if timestamp_pattern.match(line):
|
43 |
+
timestamp_end = line.rfind(']')
|
44 |
+
lyrics = line[timestamp_end + 1:].strip()
|
45 |
+
timestamps = line[:timestamp_end + 1]
|
46 |
+
|
47 |
+
if ':' in lyrics:
|
48 |
+
if len(lyrics.split(":")[0]) <=5:
|
49 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
50 |
+
# if lyrics: # 确保歌词部分不是空的
|
51 |
+
# lyric_part.append(lyrics)
|
52 |
+
# timestamp_part.append(timestamps)
|
53 |
+
# print(processed_lyrics)
|
54 |
+
return timestamp_part, lyric_part
|
55 |
+
|
56 |
+
def get_timestamps(timestamp_part):
|
57 |
+
|
58 |
+
# 转换为秒
|
59 |
+
|
60 |
+
timestamps = []
|
61 |
+
|
62 |
+
for line in timestamp_part:
|
63 |
+
match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
|
64 |
+
if match:
|
65 |
+
minutes = int(match.group(1))
|
66 |
+
seconds = float(match.group(2))
|
67 |
+
millis = float(match.group(3)) if match.group(3) else 0
|
68 |
+
total_seconds = minutes * 60 + seconds + millis
|
69 |
+
timestamps.append(total_seconds)
|
70 |
+
|
71 |
+
|
72 |
+
return timestamps
|
73 |
+
|
74 |
+
def process_lyrics_lrc(lyrics):
|
75 |
+
timestamp_part, lyric_part = process_lyrics(lyrics)
|
76 |
+
# print(timestamp_part)
|
77 |
+
# print(lyric_part)
|
78 |
+
timestamps = get_timestamps(timestamp_part)
|
79 |
+
# print(timestamps)
|
80 |
+
if len(timestamps) == 0:
|
81 |
+
# print(f'{lyric_path}')
|
82 |
+
return []
|
83 |
+
|
84 |
+
slice_start = timestamps[0]
|
85 |
+
slice_start_idx = 0
|
86 |
+
|
87 |
+
output_list = []
|
88 |
+
for i in range(1, len(timestamps)):
|
89 |
+
# 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉
|
90 |
+
if timestamps[i] - slice_start > 30:
|
91 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
92 |
+
|
93 |
+
slice_start = timestamps[i]
|
94 |
+
slice_start_idx = i
|
95 |
+
|
96 |
+
return output_list
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
def process_lyrics_yrc(lyrics):
|
101 |
+
|
102 |
+
timestamps, lyric_part = extract_lrc(lyrics)
|
103 |
+
|
104 |
+
# timestamp_part, lyric_part = process_lyrics(lyrics)
|
105 |
+
# import pdb; pdb.set_trace()
|
106 |
+
# print(timestamp_part)
|
107 |
+
# print(lyric_part)
|
108 |
+
# timestamps = get_timestamps(timestamp_part)
|
109 |
+
# print(timestamps)
|
110 |
+
if len(timestamps) == 0:
|
111 |
+
# print(f'{lyric_path}')
|
112 |
+
return []
|
113 |
+
|
114 |
+
slice_start = timestamps[0]
|
115 |
+
slice_start_idx = 0
|
116 |
+
|
117 |
+
output_list = []
|
118 |
+
for i in range(1, len(timestamps)):
|
119 |
+
# 如果累积时间超过30秒,则进行切分
|
120 |
+
if timestamps[i] - slice_start > 30:
|
121 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
122 |
+
|
123 |
+
slice_start = timestamps[i]
|
124 |
+
slice_start_idx = i
|
125 |
+
# import pdb; pdb.set_trace()
|
126 |
+
return output_list
|
127 |
+
|
128 |
+
def extract_lrc(lyrics):
|
129 |
+
timestamp_part, lyric_part = [], []
|
130 |
+
|
131 |
+
for i, text in enumerate(lyrics):
|
132 |
+
# 提取中括号内的内容
|
133 |
+
bracket_content = re.search(r'\[(.*?)\]', text).group(1)
|
134 |
+
bracket_content = bracket_content.split(',')
|
135 |
+
# 提取小括号内的内容
|
136 |
+
parentheses_content = re.findall(r'\((.*?)\)', text)
|
137 |
+
# 提取其他内容
|
138 |
+
other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
|
139 |
+
|
140 |
+
# 数据怎么处理?
|
141 |
+
# import pdb; pdb.set_trace()
|
142 |
+
if i<10 and check_lryics(other_content):
|
143 |
+
continue
|
144 |
+
|
145 |
+
# import pdb; pdb.set_trace()
|
146 |
+
timestamp_part.append(float(bracket_content[0])/1000)
|
147 |
+
lyric_part.append(other_content)
|
148 |
+
# import pdb; pdb.set_trace()
|
149 |
+
return timestamp_part, lyric_part
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
class WYYSongDataset(Dataset):
|
154 |
+
def __init__(self,
|
155 |
+
metadata_path:str,
|
156 |
+
sr:int = 0,
|
157 |
+
use_lang = ['en', 'zh-cn'],
|
158 |
+
num_examples = -1,
|
159 |
+
):
|
160 |
+
|
161 |
+
self.sr = sr
|
162 |
+
self.use_lang = use_lang
|
163 |
+
self._load_metadata(metadata_path)
|
164 |
+
|
165 |
+
# buffer
|
166 |
+
self.lyric_buffer = {}
|
167 |
+
|
168 |
+
if(num_examples<=0):
|
169 |
+
self.dataset_len = len(self.data)
|
170 |
+
self.random_slc = False
|
171 |
+
else:
|
172 |
+
self.dataset_len = num_examples
|
173 |
+
self.random_slc = True
|
174 |
+
|
175 |
+
# 读取jsonl文件
|
176 |
+
def _load_metadata(self, metadata_path):
|
177 |
+
with open(metadata_path) as fp:
|
178 |
+
lines = fp.readlines()
|
179 |
+
self.data = []
|
180 |
+
for line in lines:
|
181 |
+
item = json.loads(line)
|
182 |
+
# if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None:
|
183 |
+
if 'lyrics' in item and 'lang_info' in item:
|
184 |
+
if len(item['lyrics']) > 0:
|
185 |
+
for lang in self.use_lang:
|
186 |
+
if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
|
187 |
+
# if '伴奏' not in item['path'] and "cloud" in item['path']:
|
188 |
+
if '伴奏' not in item['path']:
|
189 |
+
self.data.append(item)
|
190 |
+
|
191 |
+
|
192 |
+
def __len__(self):
|
193 |
+
return self.dataset_len
|
194 |
+
|
195 |
+
|
196 |
+
def __getitem__(self, idx):
|
197 |
+
try_cnt = 0
|
198 |
+
while True:
|
199 |
+
if(self.random_slc):
|
200 |
+
idx = np.random.randint(0, len(self.data))
|
201 |
+
yrc_lyrics = []
|
202 |
+
lrc_lyrics = []
|
203 |
+
try:
|
204 |
+
info = self.data[idx]
|
205 |
+
|
206 |
+
# audio path
|
207 |
+
path:str = info["path"]
|
208 |
+
|
209 |
+
# 读取歌词段落
|
210 |
+
if 'lyrics' not in info:
|
211 |
+
if idx not in self.lyric_buffer:
|
212 |
+
# 字级别align的歌词
|
213 |
+
if info['yrc-lyric'] is not None:
|
214 |
+
with open(info['yrc-lyric']) as f_in:
|
215 |
+
yrc_lyric = json.load(f_in)
|
216 |
+
yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])
|
217 |
+
|
218 |
+
# 句子级align的歌词
|
219 |
+
if info['lrc-lyric'] is not None:
|
220 |
+
with open(info['lrc-lyric']) as f_in:
|
221 |
+
lrc_lyric = json.load(f_in)
|
222 |
+
lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])
|
223 |
+
|
224 |
+
# 优先使用字级别align的歌词
|
225 |
+
if len(yrc_lyrics) > 0:
|
226 |
+
lyrics = yrc_lyrics
|
227 |
+
else:
|
228 |
+
lyrics = lrc_lyrics
|
229 |
+
self.lyric_buffer[idx] = lyrics
|
230 |
+
|
231 |
+
# TODO 每段歌词进行长度筛选,过滤掉太长和太短的歌曲
|
232 |
+
else:
|
233 |
+
lyrics = self.lyric_buffer[idx]
|
234 |
+
else:
|
235 |
+
lyrics = info['lyrics']
|
236 |
+
|
237 |
+
# 随机选取一个lyric段落
|
238 |
+
ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
|
239 |
+
# ly_id = 0
|
240 |
+
|
241 |
+
lyric = lyrics[ly_id]
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
st, et, lyric = self.parse_lyric(lyric)
|
246 |
+
|
247 |
+
assert et - st < 40
|
248 |
+
|
249 |
+
# 文本过滤
|
250 |
+
|
251 |
+
lyric = re.sub(r'【.*?】', '', lyric)
|
252 |
+
if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
|
253 |
+
assert 200 > len(lyric.replace(" ", "")) > 30
|
254 |
+
if ':' in lyrics:
|
255 |
+
if len(lyrics.split(":")[0]) <=5:
|
256 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
257 |
+
|
258 |
+
if ':' in lyrics:
|
259 |
+
if len(lyrics.split(":")[0]) <=5:
|
260 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
261 |
+
|
262 |
+
if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
|
263 |
+
assert 200 > len(lyric.split()) > 20
|
264 |
+
|
265 |
+
if ':' in lyrics:
|
266 |
+
if len(lyrics.split(":")[0].split()) <=3:
|
267 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
268 |
+
|
269 |
+
if ':' in lyrics:
|
270 |
+
if len(lyrics.split(":")[0].split()) <=3:
|
271 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
+
# 读取音频文件
|
276 |
+
cur_sample_rate = torchaudio.info(path).sample_rate
|
277 |
+
offset = int(cur_sample_rate*st)
|
278 |
+
num_frames = int(cur_sample_rate * (et -st))
|
279 |
+
chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
|
280 |
+
|
281 |
+
# 随机选取一个channel
|
282 |
+
if(chunk.shape[0]>1):
|
283 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
284 |
+
else:
|
285 |
+
chunk = chunk[[0],:].float()
|
286 |
+
|
287 |
+
if(cur_sample_rate!=self.sr):
|
288 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
289 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
|
290 |
+
|
291 |
+
return chunk, lyric, [st, et], path
|
292 |
+
except:
|
293 |
+
print("Error loadding ", info["path"])
|
294 |
+
try_cnt += 1
|
295 |
+
idx = np.random.randint(0, len(self.data))
|
296 |
+
if(try_cnt>10):
|
297 |
+
raise FileNotFoundError()
|
298 |
+
|
299 |
+
def parse_lyric(self, lyric):
|
300 |
+
pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
|
301 |
+
match = re.search(pattern, lyric)
|
302 |
+
|
303 |
+
start_time = float(match.group(1))
|
304 |
+
end_time = float(match.group(2))
|
305 |
+
content = match.group(3)
|
306 |
+
return start_time, end_time, content
|
307 |
+
|
308 |
+
def collect_song(data_list):
|
309 |
+
audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
|
310 |
+
lyrics = [data[1] for data in data_list]
|
311 |
+
st_et = [data[2] for data in data_list]
|
312 |
+
paths = [data[3] for data in data_list]
|
313 |
+
return audios, lyrics, st_et
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
import torchaudio
|
7 |
+
from torchaudio.functional import resample
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torch.nn.utils.rnn import pad_sequence
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def check_lryics(lyric):
|
16 |
+
_FILTER_STRING = [
|
17 |
+
'作词', '作曲', '编曲', '【', '策划',
|
18 |
+
'录音', '混音', '母带', ':', '制作',
|
19 |
+
'版权', '校对', '演奏', '制作', '伴奏'
|
20 |
+
]
|
21 |
+
for item in _FILTER_STRING:
|
22 |
+
if item in lyric:
|
23 |
+
return True
|
24 |
+
|
25 |
+
return False
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
def process_lyrics(lines):
|
30 |
+
lyric_part = []
|
31 |
+
timestamp_part = []
|
32 |
+
|
33 |
+
timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
|
34 |
+
|
35 |
+
for i, line in enumerate(lines):
|
36 |
+
|
37 |
+
# 删除前几行的特定信息
|
38 |
+
if i<10 and check_lryics(line):
|
39 |
+
continue
|
40 |
+
|
41 |
+
# 检查是否包含有效的时间戳和歌词内容
|
42 |
+
if timestamp_pattern.match(line):
|
43 |
+
timestamp_end = line.rfind(']')
|
44 |
+
lyrics = line[timestamp_end + 1:].strip()
|
45 |
+
timestamps = line[:timestamp_end + 1]
|
46 |
+
|
47 |
+
if ':' in lyrics:
|
48 |
+
if len(lyrics.split(":")[0]) <=5:
|
49 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
50 |
+
# if lyrics: # 确保歌词部分不是空的
|
51 |
+
# lyric_part.append(lyrics)
|
52 |
+
# timestamp_part.append(timestamps)
|
53 |
+
# print(processed_lyrics)
|
54 |
+
return timestamp_part, lyric_part
|
55 |
+
|
56 |
+
def get_timestamps(timestamp_part):
|
57 |
+
|
58 |
+
# 转换为秒
|
59 |
+
|
60 |
+
timestamps = []
|
61 |
+
|
62 |
+
for line in timestamp_part:
|
63 |
+
match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
|
64 |
+
if match:
|
65 |
+
minutes = int(match.group(1))
|
66 |
+
seconds = float(match.group(2))
|
67 |
+
millis = float(match.group(3)) if match.group(3) else 0
|
68 |
+
total_seconds = minutes * 60 + seconds + millis
|
69 |
+
timestamps.append(total_seconds)
|
70 |
+
|
71 |
+
|
72 |
+
return timestamps
|
73 |
+
|
74 |
+
def process_lyrics_lrc(lyrics):
|
75 |
+
timestamp_part, lyric_part = process_lyrics(lyrics)
|
76 |
+
# print(timestamp_part)
|
77 |
+
# print(lyric_part)
|
78 |
+
timestamps = get_timestamps(timestamp_part)
|
79 |
+
# print(timestamps)
|
80 |
+
if len(timestamps) == 0:
|
81 |
+
# print(f'{lyric_path}')
|
82 |
+
return []
|
83 |
+
|
84 |
+
slice_start = timestamps[0]
|
85 |
+
slice_start_idx = 0
|
86 |
+
|
87 |
+
output_list = []
|
88 |
+
for i in range(1, len(timestamps)):
|
89 |
+
# 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉
|
90 |
+
if timestamps[i] - slice_start > 30:
|
91 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
92 |
+
|
93 |
+
slice_start = timestamps[i]
|
94 |
+
slice_start_idx = i
|
95 |
+
|
96 |
+
return output_list
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
def process_lyrics_yrc(lyrics):
|
101 |
+
|
102 |
+
timestamps, lyric_part = extract_lrc(lyrics)
|
103 |
+
|
104 |
+
# timestamp_part, lyric_part = process_lyrics(lyrics)
|
105 |
+
# import pdb; pdb.set_trace()
|
106 |
+
# print(timestamp_part)
|
107 |
+
# print(lyric_part)
|
108 |
+
# timestamps = get_timestamps(timestamp_part)
|
109 |
+
# print(timestamps)
|
110 |
+
if len(timestamps) == 0:
|
111 |
+
# print(f'{lyric_path}')
|
112 |
+
return []
|
113 |
+
|
114 |
+
slice_start = timestamps[0]
|
115 |
+
slice_start_idx = 0
|
116 |
+
|
117 |
+
output_list = []
|
118 |
+
for i in range(1, len(timestamps)):
|
119 |
+
# 如果累积时间超过30秒,则进行切分
|
120 |
+
if timestamps[i] - slice_start > 30:
|
121 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
122 |
+
|
123 |
+
slice_start = timestamps[i]
|
124 |
+
slice_start_idx = i
|
125 |
+
# import pdb; pdb.set_trace()
|
126 |
+
return output_list
|
127 |
+
|
128 |
+
def extract_lrc(lyrics):
|
129 |
+
timestamp_part, lyric_part = [], []
|
130 |
+
|
131 |
+
for i, text in enumerate(lyrics):
|
132 |
+
# 提取中括号内的内容
|
133 |
+
bracket_content = re.search(r'\[(.*?)\]', text).group(1)
|
134 |
+
bracket_content = bracket_content.split(',')
|
135 |
+
# 提取小括号内的内容
|
136 |
+
parentheses_content = re.findall(r'\((.*?)\)', text)
|
137 |
+
# 提取其他内容
|
138 |
+
other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
|
139 |
+
|
140 |
+
# 数据怎么处理?
|
141 |
+
# import pdb; pdb.set_trace()
|
142 |
+
if i<10 and check_lryics(other_content):
|
143 |
+
continue
|
144 |
+
|
145 |
+
# import pdb; pdb.set_trace()
|
146 |
+
timestamp_part.append(float(bracket_content[0])/1000)
|
147 |
+
lyric_part.append(other_content)
|
148 |
+
# import pdb; pdb.set_trace()
|
149 |
+
return timestamp_part, lyric_part
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
class WYYSongDataset(Dataset):
|
154 |
+
def __init__(self,
|
155 |
+
metadata_path:str,
|
156 |
+
sr:int = 0,
|
157 |
+
use_lang = ['en', 'zh-cn'],
|
158 |
+
num_examples = -1,
|
159 |
+
):
|
160 |
+
|
161 |
+
self.sr = sr
|
162 |
+
self.use_lang = use_lang
|
163 |
+
self._load_metadata(metadata_path)
|
164 |
+
|
165 |
+
# buffer
|
166 |
+
self.lyric_buffer = {}
|
167 |
+
|
168 |
+
if(num_examples<=0):
|
169 |
+
self.dataset_len = len(self.data)
|
170 |
+
self.random_slc = False
|
171 |
+
else:
|
172 |
+
self.dataset_len = num_examples
|
173 |
+
self.random_slc = True
|
174 |
+
|
175 |
+
# 读取jsonl文件
|
176 |
+
def _load_metadata(self, metadata_path):
|
177 |
+
with open(metadata_path) as fp:
|
178 |
+
lines = fp.readlines()
|
179 |
+
self.data = []
|
180 |
+
for line in lines:
|
181 |
+
item = json.loads(line)
|
182 |
+
# if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None:
|
183 |
+
if 'lyrics' in item and 'lang_info' in item:
|
184 |
+
if len(item['lyrics']) > 0:
|
185 |
+
for lang in self.use_lang:
|
186 |
+
if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
|
187 |
+
# if '伴奏' not in item['path'] and "cloud" in item['path']:
|
188 |
+
if '伴奏' not in item['path']:
|
189 |
+
self.data.append(item)
|
190 |
+
|
191 |
+
|
192 |
+
def __len__(self):
|
193 |
+
return self.dataset_len
|
194 |
+
|
195 |
+
|
196 |
+
def __getitem__(self, idx):
|
197 |
+
try_cnt = 0
|
198 |
+
while True:
|
199 |
+
if(self.random_slc):
|
200 |
+
idx = np.random.randint(0, len(self.data))
|
201 |
+
yrc_lyrics = []
|
202 |
+
lrc_lyrics = []
|
203 |
+
try:
|
204 |
+
info = self.data[idx]
|
205 |
+
|
206 |
+
# audio path
|
207 |
+
path:str = info["path"]
|
208 |
+
|
209 |
+
# 读取歌词段落
|
210 |
+
if 'lyrics' not in info:
|
211 |
+
if idx not in self.lyric_buffer:
|
212 |
+
# 字级别align的歌词
|
213 |
+
if info['yrc-lyric'] is not None:
|
214 |
+
with open(info['yrc-lyric']) as f_in:
|
215 |
+
yrc_lyric = json.load(f_in)
|
216 |
+
yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])
|
217 |
+
|
218 |
+
# 句子级align的歌词
|
219 |
+
if info['lrc-lyric'] is not None:
|
220 |
+
with open(info['lrc-lyric']) as f_in:
|
221 |
+
lrc_lyric = json.load(f_in)
|
222 |
+
lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])
|
223 |
+
|
224 |
+
# 优先使用字级别align的歌词
|
225 |
+
if len(yrc_lyrics) > 0:
|
226 |
+
lyrics = yrc_lyrics
|
227 |
+
else:
|
228 |
+
lyrics = lrc_lyrics
|
229 |
+
self.lyric_buffer[idx] = lyrics
|
230 |
+
|
231 |
+
# TODO 每段歌词进行长度筛选,过滤掉太长和太短的歌曲
|
232 |
+
else:
|
233 |
+
lyrics = self.lyric_buffer[idx]
|
234 |
+
else:
|
235 |
+
lyrics = info['lyrics']
|
236 |
+
|
237 |
+
# 随机选取一个lyric段落
|
238 |
+
ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
|
239 |
+
# ly_id = 0
|
240 |
+
|
241 |
+
lyric = lyrics[ly_id]
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
st, et, lyric = self.parse_lyric(lyric)
|
246 |
+
|
247 |
+
assert et - st < 20
|
248 |
+
|
249 |
+
# 文本过滤
|
250 |
+
|
251 |
+
lyric = re.sub(r'【.*?】', '', lyric)
|
252 |
+
if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
|
253 |
+
assert 100 > len(lyric.replace(" ", "")) > 5
|
254 |
+
if ':' in lyrics:
|
255 |
+
if len(lyrics.split(":")[0]) <=5:
|
256 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
257 |
+
|
258 |
+
if ':' in lyrics:
|
259 |
+
if len(lyrics.split(":")[0]) <=5:
|
260 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
261 |
+
|
262 |
+
if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
|
263 |
+
assert 100 > len(lyric.split()) > 5
|
264 |
+
|
265 |
+
if ':' in lyrics:
|
266 |
+
if len(lyrics.split(":")[0].split()) <=3:
|
267 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
268 |
+
|
269 |
+
if ':' in lyrics:
|
270 |
+
if len(lyrics.split(":")[0].split()) <=3:
|
271 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
+
# 读取音频文件
|
276 |
+
cur_sample_rate = torchaudio.info(path).sample_rate
|
277 |
+
offset = int(cur_sample_rate*st)
|
278 |
+
num_frames = int(cur_sample_rate * (et -st))
|
279 |
+
chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
|
280 |
+
|
281 |
+
# 随机选取一个channel
|
282 |
+
if(chunk.shape[0]>1):
|
283 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
284 |
+
else:
|
285 |
+
chunk = chunk[[0],:].float()
|
286 |
+
|
287 |
+
if(cur_sample_rate!=self.sr):
|
288 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
289 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
|
290 |
+
|
291 |
+
return chunk, lyric, [st, et], path
|
292 |
+
except:
|
293 |
+
print("Error loadding ", info["path"])
|
294 |
+
try_cnt += 1
|
295 |
+
idx = np.random.randint(0, len(self.data))
|
296 |
+
if(try_cnt>10):
|
297 |
+
raise FileNotFoundError()
|
298 |
+
|
299 |
+
def parse_lyric(self, lyric):
|
300 |
+
pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
|
301 |
+
match = re.search(pattern, lyric)
|
302 |
+
|
303 |
+
start_time = float(match.group(1))
|
304 |
+
end_time = float(match.group(2))
|
305 |
+
content = match.group(3)
|
306 |
+
return start_time, end_time, content
|
307 |
+
|
308 |
+
def collect_song(data_list):
|
309 |
+
audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
|
310 |
+
lyrics = [data[1] for data in data_list]
|
311 |
+
st_et = [data[2] for data in data_list]
|
312 |
+
paths = [data[3] for data in data_list]
|
313 |
+
return audios, lyrics, st_et
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
import torchaudio
|
7 |
+
from torchaudio.functional import resample
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torch.nn.utils.rnn import pad_sequence
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def check_lryics(lyric):
|
16 |
+
_FILTER_STRING = [
|
17 |
+
'作词', '作曲', '编曲', '【', '策划',
|
18 |
+
'录音', '混音', '母带', ':', '制作',
|
19 |
+
'版权', '校对', '演奏', '制作', '伴奏'
|
20 |
+
]
|
21 |
+
for item in _FILTER_STRING:
|
22 |
+
if item in lyric:
|
23 |
+
return True
|
24 |
+
|
25 |
+
return False
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
def process_lyrics(lines):
|
30 |
+
lyric_part = []
|
31 |
+
timestamp_part = []
|
32 |
+
|
33 |
+
timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
|
34 |
+
|
35 |
+
for i, line in enumerate(lines):
|
36 |
+
|
37 |
+
# 删除前几行的特定信息
|
38 |
+
if i<10 and check_lryics(line):
|
39 |
+
continue
|
40 |
+
|
41 |
+
# 检查是否包含有效的时间戳和歌词内容
|
42 |
+
if timestamp_pattern.match(line):
|
43 |
+
timestamp_end = line.rfind(']')
|
44 |
+
lyrics = line[timestamp_end + 1:].strip()
|
45 |
+
timestamps = line[:timestamp_end + 1]
|
46 |
+
|
47 |
+
if ':' in lyrics:
|
48 |
+
if len(lyrics.split(":")[0]) <=5:
|
49 |
+
lyrics = "".join(lyrics.split(":")[1:])
|
50 |
+
# if lyrics: # 确保歌词部分不是空的
|
51 |
+
# lyric_part.append(lyrics)
|
52 |
+
# timestamp_part.append(timestamps)
|
53 |
+
# print(processed_lyrics)
|
54 |
+
return timestamp_part, lyric_part
|
55 |
+
|
56 |
+
def get_timestamps(timestamp_part):
|
57 |
+
|
58 |
+
# 转换为秒
|
59 |
+
|
60 |
+
timestamps = []
|
61 |
+
|
62 |
+
for line in timestamp_part:
|
63 |
+
match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
|
64 |
+
if match:
|
65 |
+
minutes = int(match.group(1))
|
66 |
+
seconds = float(match.group(2))
|
67 |
+
millis = float(match.group(3)) if match.group(3) else 0
|
68 |
+
total_seconds = minutes * 60 + seconds + millis
|
69 |
+
timestamps.append(total_seconds)
|
70 |
+
|
71 |
+
|
72 |
+
return timestamps
|
73 |
+
|
74 |
+
def process_lyrics_lrc(lyrics):
|
75 |
+
timestamp_part, lyric_part = process_lyrics(lyrics)
|
76 |
+
# print(timestamp_part)
|
77 |
+
# print(lyric_part)
|
78 |
+
timestamps = get_timestamps(timestamp_part)
|
79 |
+
# print(timestamps)
|
80 |
+
if len(timestamps) == 0:
|
81 |
+
# print(f'{lyric_path}')
|
82 |
+
return []
|
83 |
+
|
84 |
+
slice_start = timestamps[0]
|
85 |
+
slice_start_idx = 0
|
86 |
+
|
87 |
+
output_list = []
|
88 |
+
for i in range(1, len(timestamps)):
|
89 |
+
# 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉
|
90 |
+
if timestamps[i] - slice_start > 30:
|
91 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
92 |
+
|
93 |
+
slice_start = timestamps[i]
|
94 |
+
slice_start_idx = i
|
95 |
+
|
96 |
+
return output_list
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
def process_lyrics_yrc(lyrics):
|
101 |
+
|
102 |
+
timestamps, lyric_part = extract_lrc(lyrics)
|
103 |
+
|
104 |
+
# timestamp_part, lyric_part = process_lyrics(lyrics)
|
105 |
+
# import pdb; pdb.set_trace()
|
106 |
+
# print(timestamp_part)
|
107 |
+
# print(lyric_part)
|
108 |
+
# timestamps = get_timestamps(timestamp_part)
|
109 |
+
# print(timestamps)
|
110 |
+
if len(timestamps) == 0:
|
111 |
+
# print(f'{lyric_path}')
|
112 |
+
return []
|
113 |
+
|
114 |
+
slice_start = timestamps[0]
|
115 |
+
slice_start_idx = 0
|
116 |
+
|
117 |
+
output_list = []
|
118 |
+
for i in range(1, len(timestamps)):
|
119 |
+
# 如果累积时间超过30秒,则进行切分
|
120 |
+
if timestamps[i] - slice_start > 30:
|
121 |
+
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
|
122 |
+
|
123 |
+
slice_start = timestamps[i]
|
124 |
+
slice_start_idx = i
|
125 |
+
# import pdb; pdb.set_trace()
|
126 |
+
return output_list
|
127 |
+
|
128 |
+
def extract_lrc(lyrics):
|
129 |
+
timestamp_part, lyric_part = [], []
|
130 |
+
|
131 |
+
for i, text in enumerate(lyrics):
|
132 |
+
# 提取中括号内的内容
|
133 |
+
bracket_content = re.search(r'\[(.*?)\]', text).group(1)
|
134 |
+
bracket_content = bracket_content.split(',')
|
135 |
+
# 提取小括号内的内容
|
136 |
+
parentheses_content = re.findall(r'\((.*?)\)', text)
|
137 |
+
# 提取其他内容
|
138 |
+
other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
|
139 |
+
|
140 |
+
# 数据怎么处理?
|
141 |
+
if i<10 and check_lryics(other_content):
|
142 |
+
continue
|
143 |
+
timestamp_part.append(float(bracket_content[0])/1000)
|
144 |
+
lyric_part.append(other_content)
|
145 |
+
return timestamp_part, lyric_part
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
class WYYSongDataset(Dataset):
|
150 |
+
def __init__(self,
|
151 |
+
metadata_path:str,
|
152 |
+
sr:int = 0,
|
153 |
+
use_lang = ['en', 'zh-cn'],
|
154 |
+
num_examples = -1,
|
155 |
+
max_dur = 20,
|
156 |
+
pad_to_max= True,
|
157 |
+
):
|
158 |
+
|
159 |
+
self.sr = sr
|
160 |
+
self.use_lang = use_lang
|
161 |
+
self._load_metadata(metadata_path)
|
162 |
+
self.max_dur = max_dur
|
163 |
+
self.pad_to_max = pad_to_max
|
164 |
+
|
165 |
+
# buffer
|
166 |
+
self.lyric_buffer = {}
|
167 |
+
|
168 |
+
if(num_examples<=0):
|
169 |
+
self.dataset_len = len(self.data)
|
170 |
+
self.random_slc = False
|
171 |
+
else:
|
172 |
+
self.dataset_len = num_examples
|
173 |
+
self.random_slc = True
|
174 |
+
|
175 |
+
# 读取jsonl文件
|
176 |
+
def _load_metadata(self, metadata_path):
|
177 |
+
with open(metadata_path) as fp:
|
178 |
+
lines = fp.readlines()
|
179 |
+
self.data = []
|
180 |
+
for line in lines:
|
181 |
+
item = json.loads(line)
|
182 |
+
if '伴奏' not in item['path']:
|
183 |
+
# if "lang_type" in item and item['lang_type'] == 'en':
|
184 |
+
if "lang_type" in item:
|
185 |
+
self.data.append(item)
|
186 |
+
|
187 |
+
|
188 |
+
def __len__(self):
|
189 |
+
return self.dataset_len
|
190 |
+
|
191 |
+
|
192 |
+
def __getitem__(self, idx):
|
193 |
+
try_cnt = 0
|
194 |
+
while True:
|
195 |
+
if(self.random_slc):
|
196 |
+
idx = np.random.randint(0, len(self.data))
|
197 |
+
yrc_lyrics = []
|
198 |
+
lrc_lyrics = []
|
199 |
+
try:
|
200 |
+
info = self.data[idx]
|
201 |
+
|
202 |
+
# audio path
|
203 |
+
path = info["path"]
|
204 |
+
lang_type = info["lang_type"]
|
205 |
+
if info["lang_type"] == 'en':
|
206 |
+
lyrics = info['lyrics']
|
207 |
+
else:
|
208 |
+
lyrics = info['lyrics_phone']
|
209 |
+
|
210 |
+
# 随机选取一个lyric段落
|
211 |
+
ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
|
212 |
+
lyric = lyrics[ly_id].strip()
|
213 |
+
|
214 |
+
st, et, lyric = self.parse_lyric(lyric)
|
215 |
+
lyric = lyric.replace("\xa0", " ")
|
216 |
+
|
217 |
+
lyric = " ".join(lyric.split())
|
218 |
+
|
219 |
+
assert et - st < self.max_dur
|
220 |
+
|
221 |
+
|
222 |
+
if info["lang_type"] == 'en':
|
223 |
+
# print(len(lyric.split())/(et-st))
|
224 |
+
assert 6 > len(lyric.split())/(et-st) > 1
|
225 |
+
else:
|
226 |
+
# print(len(lyric.split())/(et-st))
|
227 |
+
lyric = lyric.replace("-", "")
|
228 |
+
assert 6 > len(lyric.split())/(et-st) > 1
|
229 |
+
|
230 |
+
|
231 |
+
# 读取音频文件
|
232 |
+
cur_sample_rate = torchaudio.info(path).sample_rate
|
233 |
+
offset = int(cur_sample_rate*st)
|
234 |
+
num_frames = int(cur_sample_rate * (et -st))
|
235 |
+
chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
|
236 |
+
# chunk = torch.zeros(1, 48000*15)
|
237 |
+
|
238 |
+
# 随机选取一个channel
|
239 |
+
if(chunk.shape[0]>1):
|
240 |
+
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
241 |
+
else:
|
242 |
+
chunk = chunk[[0],:].float()
|
243 |
+
|
244 |
+
if(cur_sample_rate!=self.sr):
|
245 |
+
# print('a:',cur_sample_rate,chunk.shape)
|
246 |
+
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
|
247 |
+
|
248 |
+
if self.pad_to_max:
|
249 |
+
chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)
|
250 |
+
|
251 |
+
return chunk, lyric, et-st, path, lang_type
|
252 |
+
except:
|
253 |
+
# print("Error loadding ", info["path"])
|
254 |
+
try_cnt += 1
|
255 |
+
idx = np.random.randint(0, len(self.data))
|
256 |
+
if(try_cnt>20):
|
257 |
+
raise FileNotFoundError()
|
258 |
+
|
259 |
+
def parse_lyric(self, lyric):
|
260 |
+
pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
|
261 |
+
match = re.search(pattern, lyric)
|
262 |
+
|
263 |
+
start_time = float(match.group(1))
|
264 |
+
end_time = float(match.group(2))
|
265 |
+
content = match.group(3)
|
266 |
+
return start_time, end_time, content
|
267 |
+
|
268 |
+
def pad_2d_tensor(self, x, max_len, pad_id):
|
269 |
+
# 获取输入 tensor 的形状
|
270 |
+
batch_size, seq_len = x.size()
|
271 |
+
max_len = max(max_len, seq_len)
|
272 |
+
# 计算需要填充的长度
|
273 |
+
pad_len = max_len - seq_len
|
274 |
+
|
275 |
+
# 如果需要填充
|
276 |
+
if pad_len > 0:
|
277 |
+
# 创建填充 tensor
|
278 |
+
pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
|
279 |
+
|
280 |
+
# 沿第二个维度(列)连接输入 tensor 和填充 tensor
|
281 |
+
padded_tensor = torch.cat([x, pad_tensor], dim=1)
|
282 |
+
else:
|
283 |
+
# 如果不需要填充,直接返回输入 tensor
|
284 |
+
padded_tensor = x
|
285 |
+
|
286 |
+
return padded_tensor
|
287 |
+
|
288 |
+
def collect_data(data_list):
|
289 |
+
audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
|
290 |
+
lyrics = [data[1] for data in data_list]
|
291 |
+
st_et = [data[2] for data in data_list]
|
292 |
+
paths = [data[3] for data in data_list]
|
293 |
+
lang_types = [data[4] for data in data_list]
|
294 |
+
return audios, lyrics, st_et, lang_types
|
295 |
+
# return audios, lyrics, st_et
|
296 |
+
|
297 |
+
|
298 |
+
def build_dataset():
|
299 |
+
train_dataset = WYYSongDataset(
|
300 |
+
metadata_path = "train.jsonl",
|
301 |
+
sr = 48000,
|
302 |
+
use_lang = ['zh-cn', 'en'],
|
303 |
+
num_examples = 10*10000
|
304 |
+
)
|
305 |
+
|
306 |
+
valid_dataset = WYYSongDataset(
|
307 |
+
metadata_path = "valid.jsonl",
|
308 |
+
sr = 48000,
|
309 |
+
use_lang = ['zh-cn', 'en'],
|
310 |
+
num_examples = 500
|
311 |
+
)
|
312 |
+
|
313 |
+
return train_dataset, valid_dataset
|
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset
|
2 |
+
from beartype.typing import Sequence, Callable, Optional, Dict, List
|
3 |
+
from beartype.door import is_bearable
|
4 |
+
import random
|
5 |
+
import os
|
6 |
+
from torchaudio.functional import resample
|
7 |
+
import torch
|
8 |
+
import typing as tp
|
9 |
+
from pathlib import Path
|
10 |
+
import torchaudio as ta
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import soundfile
|
13 |
+
import numpy as np
|
14 |
+
import json
|
15 |
+
import yaml
|
16 |
+
import random
|
17 |
+
import librosa
|
18 |
+
from loguru import logger
|
19 |
+
import re
|
20 |
+
|
21 |
+
|
22 |
+
def _av_read(filepath, seek_time=0, duration=None):
|
23 |
+
if duration is not None:
|
24 |
+
sr = librosa.get_samplerate(filepath)
|
25 |
+
offset = seek_time
|
26 |
+
num_samples = int(duration * sr)
|
27 |
+
wav, _ = librosa.load(filepath, sr=sr, offset=offset, duration=duration)
|
28 |
+
else:
|
29 |
+
wav, sr = librosa.load(filepath, sr=None, offset=seek_time)
|
30 |
+
|
31 |
+
return wav, sr
|
32 |
+
|
33 |
+
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
34 |
+
duration: float = -1., pad: bool = True) -> tp.Tuple[torch.Tensor, int]:
|
35 |
+
"""Read audio by picking the most appropriate backend tool based on the audio format.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
filepath (str or Path): Path to audio file to read.
|
39 |
+
seek_time (float): Time at which to start reading in the file.
|
40 |
+
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
41 |
+
pad (bool): Pad output audio if not reaching expected duration.
|
42 |
+
Returns:
|
43 |
+
tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
|
44 |
+
"""
|
45 |
+
fp = Path(filepath)
|
46 |
+
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
47 |
+
# There is some bug with ffmpeg and reading flac
|
48 |
+
info = soundfile.info(filepath)
|
49 |
+
frames = -1 if duration <= 0 else int(duration * info.samplerate)
|
50 |
+
frame_offset = int(seek_time * info.samplerate)
|
51 |
+
wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
|
52 |
+
assert info.samplerate == sr, f"Mismatch of sample rates {info.samplerate} {sr}"
|
53 |
+
wav = torch.from_numpy(wav).t().contiguous()
|
54 |
+
if len(wav.shape) == 1:
|
55 |
+
wav = torch.unsqueeze(wav, 0)
|
56 |
+
elif (
|
57 |
+
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
|
58 |
+
and duration <= 0 and seek_time == 0
|
59 |
+
):
|
60 |
+
# Torchaudio is faster if we load an entire file at once.
|
61 |
+
wav, sr = librosa.load(fp, sr=None, mono=True)
|
62 |
+
else:
|
63 |
+
wav, sr = _av_read(filepath, seek_time, duration)
|
64 |
+
if pad and duration > 0:
|
65 |
+
expected_frames = int(duration * sr)
|
66 |
+
wav = F.pad(torch.tensor(wav), (0, expected_frames - wav.shape[-1]))
|
67 |
+
if not isinstance(wav, torch.Tensor):
|
68 |
+
wav = torch.tensor(wav)
|
69 |
+
return wav, sr
|
70 |
+
|
71 |
+
def random_seek_read(filepath, duration):
|
72 |
+
if duration > 0:
|
73 |
+
total_duration = librosa.get_duration(path=filepath)
|
74 |
+
acceptable_start = max(0, total_duration - duration)
|
75 |
+
wav, sr = audio_read(filepath, random.uniform(0, acceptable_start), duration, pad=True)
|
76 |
+
else:
|
77 |
+
wav, sr = audio_read(filepath, 0, -1, pad=False)
|
78 |
+
return wav, sr
|
79 |
+
|
80 |
+
def safe_random_seek_read(filepath, duration, sample_rate):
|
81 |
+
try:
|
82 |
+
wav, sr = random_seek_read(filepath, duration)
|
83 |
+
if sr != sample_rate:
|
84 |
+
wav = resample(wav, sr, sample_rate)
|
85 |
+
sr = sample_rate
|
86 |
+
except Exception as e:
|
87 |
+
logger.error(f"Error reading {filepath}: {e}")
|
88 |
+
sr = sample_rate
|
89 |
+
wav = torch.zeros(sr * max(duration, 0), dtype=torch.float32)
|
90 |
+
return wav, sr
|
91 |
+
|
92 |
+
def read_jsonlike(path: os.PathLike):
|
93 |
+
#json or jsonl
|
94 |
+
if str(path).endswith(".json"):
|
95 |
+
with open(path, 'r', encoding='utf8') as f:
|
96 |
+
data = json.load(f)
|
97 |
+
return data
|
98 |
+
elif str(path).endswith(".jsonl"):
|
99 |
+
with open(path, 'r', encoding='utf8') as f:
|
100 |
+
data = [json.loads(line) for line in f.readlines()]
|
101 |
+
return data
|
102 |
+
else:
|
103 |
+
raise ValueError("Unknown file format")
|
104 |
+
|
105 |
+
dist_prob_map = {
|
106 |
+
1: (1.0,),
|
107 |
+
2: (0.5, 0.5),
|
108 |
+
3: (0.3, 0.4, 0.3),
|
109 |
+
4: (0.2, 0.3, 0.3, 0.2),
|
110 |
+
5: (0.2, 0.2, 0.3, 0.2, 0.1),
|
111 |
+
6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
|
112 |
+
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
113 |
+
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
114 |
+
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
115 |
+
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
116 |
+
}
|
117 |
+
|
118 |
+
dist_prob_map_low = {
|
119 |
+
1: (1.0,),
|
120 |
+
2: (0.8, 0.2),
|
121 |
+
3: (0.8, 0.1, 0.1),
|
122 |
+
4: (0.7, 0.1, 0.1, 0.1),
|
123 |
+
5: (0.7, 0.1, 0.1, 0.05, 0.05),
|
124 |
+
6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
|
125 |
+
}
|
126 |
+
|
127 |
+
|
128 |
+
_bpm_range_rights = (
|
129 |
+
(40, '20-40'),
|
130 |
+
(60, '40-60'),
|
131 |
+
(66, '60-66'),
|
132 |
+
(76, '66-76'),
|
133 |
+
(108, '76-108'),
|
134 |
+
(120, '108-120'),
|
135 |
+
(168, '120-168'),
|
136 |
+
(176, '168-176'),
|
137 |
+
(200, '176-200')
|
138 |
+
)
|
139 |
+
_bpm_desc_map = {
|
140 |
+
'20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
|
141 |
+
'40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
|
142 |
+
'60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
|
143 |
+
'66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
|
144 |
+
'76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
|
145 |
+
'108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
|
146 |
+
'120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
|
147 |
+
'168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
|
148 |
+
'176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
|
149 |
+
'>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
|
150 |
+
}
|
151 |
+
_bpm_desc_map_zh = {
|
152 |
+
'20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
|
153 |
+
'40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
|
154 |
+
'60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
|
155 |
+
'66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
|
156 |
+
'76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
|
157 |
+
'108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
|
158 |
+
'120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
|
159 |
+
'168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
|
160 |
+
'176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
|
161 |
+
'>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
|
162 |
+
}
|
163 |
+
def get_bpm_range(bpm):
|
164 |
+
bpm = int(bpm)
|
165 |
+
for right, tag in _bpm_range_rights:
|
166 |
+
if bpm <= right:
|
167 |
+
return tag
|
168 |
+
return '>200'
|
169 |
+
|
170 |
+
def gen_bpm_descript(bpm, lang='en'):
|
171 |
+
bpm_range = get_bpm_range(bpm)
|
172 |
+
if lang == 'en':
|
173 |
+
return random.choice(_bpm_desc_map[bpm_range])
|
174 |
+
elif lang == 'zh':
|
175 |
+
return random.choice(_bpm_desc_map_zh[bpm_range])
|
176 |
+
else:
|
177 |
+
raise ValueError(f"Unknown language {lang}")
|
178 |
+
|
179 |
+
def read_translate(translate: Optional[Dict[str, os.PathLike]]):
|
180 |
+
if translate is None:
|
181 |
+
return None
|
182 |
+
return {k: read_jsonlike(path) for k, path in translate.items()}
|
183 |
+
|
184 |
+
|
185 |
+
def tags_to_desc(tag_list, sep=',') -> str:
|
186 |
+
if not isinstance(tag_list, Sequence):
|
187 |
+
return str(tag_list)
|
188 |
+
if isinstance(tag_list, str):
|
189 |
+
return tag_list
|
190 |
+
if len(tag_list) <= 0:
|
191 |
+
return ''
|
192 |
+
elif len(tag_list) <= 5:
|
193 |
+
probs = dist_prob_map[len(tag_list)]
|
194 |
+
tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
|
195 |
+
random.shuffle(tag_list)
|
196 |
+
tag_list = tag_list[:tags_num]
|
197 |
+
return sep.join(tag_list)
|
198 |
+
else:
|
199 |
+
probs = dist_prob_map[5]
|
200 |
+
tags_num = random.choices(range(1, 6), probs)[0]
|
201 |
+
random.shuffle(tag_list)
|
202 |
+
tag_list = tag_list[:tags_num]
|
203 |
+
return sep.join(tag_list)
|
204 |
+
|
205 |
+
|
206 |
+
class PromptTemplate:
|
207 |
+
def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
|
208 |
+
self.template_text = template_text
|
209 |
+
self.tag_map = tag_map
|
210 |
+
self.lang = lang
|
211 |
+
|
212 |
+
@property
|
213 |
+
def tags(self):
|
214 |
+
return tuple(self.tag_map.keys())
|
215 |
+
|
216 |
+
def apply(self, **kwargs):
|
217 |
+
for tag in list(kwargs.keys()):
|
218 |
+
if kwargs[tag] == '':
|
219 |
+
kwargs.pop(tag)
|
220 |
+
for tag in self.tags:
|
221 |
+
if tag in kwargs:
|
222 |
+
kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
|
223 |
+
else:
|
224 |
+
kwargs[tag] = ''
|
225 |
+
prompt = self.template_text.format(**kwargs)
|
226 |
+
|
227 |
+
return self.beautify(prompt)
|
228 |
+
|
229 |
+
def beautify(self, text):
|
230 |
+
if self.lang == 'en':
|
231 |
+
return self._beautify_en(text)
|
232 |
+
elif self.lang == 'zh':
|
233 |
+
return self._beautify_zh(text)
|
234 |
+
else:
|
235 |
+
raise ValueError(f'Unknown language {self.lang}')
|
236 |
+
|
237 |
+
@staticmethod
|
238 |
+
def _beautify_en(text):
|
239 |
+
# no continuous commas without content between them
|
240 |
+
text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
|
241 |
+
# no continuous whitespace
|
242 |
+
text = re.sub(r'\s+', ' ', text)
|
243 |
+
# the comma is NOT followed by whitespace, and should be followed by ONE whitespace
|
244 |
+
text = re.sub(r'\s+,', r',', text)
|
245 |
+
text = re.sub(r',\s+', r', ', text)
|
246 |
+
# no whitespace before the full stop
|
247 |
+
text = re.sub(r'\s+\.', r'.', text)
|
248 |
+
# strip whitespace, comma, and replace ',.'
|
249 |
+
text = text.strip(' ,')
|
250 |
+
text = text.replace(',.', '.')
|
251 |
+
return text
|
252 |
+
|
253 |
+
@staticmethod
|
254 |
+
def _beautify_zh(text):
|
255 |
+
# no continuous commas without content between them
|
256 |
+
text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
|
257 |
+
text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
|
258 |
+
# assume there should be NO whitespace in Chinese
|
259 |
+
text = re.sub(r'\s+', r'', text)
|
260 |
+
# strip whitespace, comma, and replace ',。'
|
261 |
+
text = text.strip(', 、')
|
262 |
+
text = text.replace(',。', '。')
|
263 |
+
return text
|
264 |
+
|
265 |
+
def __repr__(self):
|
266 |
+
return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
|
267 |
+
|
268 |
+
__str__ = __repr__
|
269 |
+
|
270 |
+
def parse_prompt_template(prompt_template_text, lang='en'):
|
271 |
+
span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
|
272 |
+
tag_pattern = re.compile(r'{.+?}', re.DOTALL)
|
273 |
+
|
274 |
+
template_text = prompt_template_text.strip()
|
275 |
+
span_texts = span_pattern.findall(prompt_template_text)
|
276 |
+
tag_map = {}
|
277 |
+
for span_text in span_texts:
|
278 |
+
tag = tag_pattern.findall(span_text)[0].strip('{}')
|
279 |
+
tag_map[tag] = span_text
|
280 |
+
template_text = template_text.replace(span_text, '{'+tag+'}')
|
281 |
+
|
282 |
+
return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
|
283 |
+
|
284 |
+
def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
|
285 |
+
with open(path, 'r') as f:
|
286 |
+
lines = f.readlines()
|
287 |
+
cnt = 0
|
288 |
+
pts = []
|
289 |
+
for line in lines:
|
290 |
+
pt = parse_prompt_template(line, lang=lang)
|
291 |
+
cnt += 1
|
292 |
+
if len(pt.tags) < num:
|
293 |
+
logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
|
294 |
+
pts.append(pt)
|
295 |
+
|
296 |
+
return pts
|
297 |
+
|
298 |
+
|
299 |
+
class AudioStockDataset(Dataset):
|
300 |
+
def __init__(self,
|
301 |
+
num_examples:int,
|
302 |
+
metadata_path:str,
|
303 |
+
duration:float=60,
|
304 |
+
sr:int = 0,
|
305 |
+
return_path = False,
|
306 |
+
return_audio = True,
|
307 |
+
prompt_template_path: os.PathLike = None,
|
308 |
+
tag_types = [],
|
309 |
+
lang = 'en',
|
310 |
+
translate:Optional[Dict[str, os.PathLike]] = None
|
311 |
+
):
|
312 |
+
self.duration = duration
|
313 |
+
self.MAX_DURATION = 360
|
314 |
+
self._load_metadata(metadata_path)
|
315 |
+
if num_examples > 0:
|
316 |
+
self.random_choose = True
|
317 |
+
self.dataset_len = num_examples
|
318 |
+
else:
|
319 |
+
self.random_choose = False
|
320 |
+
self.dataset_len = len(self.data)
|
321 |
+
self.sr = sr
|
322 |
+
self.return_path = return_path
|
323 |
+
self.return_audio = return_audio
|
324 |
+
|
325 |
+
self.use_dynamic_prompt = prompt_template_path is not None
|
326 |
+
if self.use_dynamic_prompt:
|
327 |
+
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
|
328 |
+
self.tag_types = tag_types
|
329 |
+
|
330 |
+
self.lang = lang
|
331 |
+
self.translate = read_translate(translate)
|
332 |
+
|
333 |
+
def _load_metadata(self, metadata_path):
|
334 |
+
total_len = 0; valid_len = 0
|
335 |
+
with open(metadata_path) as fp:
|
336 |
+
lines = fp.readlines()
|
337 |
+
self.data = []
|
338 |
+
for line in lines:
|
339 |
+
item = json.loads(line)
|
340 |
+
total_len += 1
|
341 |
+
if(item['duration']>self.duration and item['duration']<self.MAX_DURATION):
|
342 |
+
valid_len += 1
|
343 |
+
self.data.append(item)
|
344 |
+
print("Filter data from {} to {}".format(total_len, valid_len))
|
345 |
+
self.is_info_recorded = bool('Tags' in self.data[0])
|
346 |
+
|
347 |
+
def __len__(self):
|
348 |
+
return self.dataset_len
|
349 |
+
|
350 |
+
def __getitem__(self, idx):
|
351 |
+
first_try = True
|
352 |
+
try_cnt = 0
|
353 |
+
while True:
|
354 |
+
try:
|
355 |
+
if(self.random_choose or not first_try):
|
356 |
+
index2 = np.random.randint(0,len(self.data))
|
357 |
+
else:
|
358 |
+
index2 = idx
|
359 |
+
first_try = False
|
360 |
+
return self.getitem_main(index2)
|
361 |
+
except:
|
362 |
+
print("Error loadding ", self.data[idx]["path"])
|
363 |
+
try_cnt += 1
|
364 |
+
if(try_cnt>10):
|
365 |
+
raise ValueError()
|
366 |
+
|
367 |
+
def getitem_main(self, idx):
|
368 |
+
path:str = self.data[idx]["path"]
|
369 |
+
json_path = path[:path.rfind('.')] + ".json"
|
370 |
+
if self.is_info_recorded:
|
371 |
+
item = self.data[idx]
|
372 |
+
else:
|
373 |
+
with open(json_path) as fp:
|
374 |
+
item:dict = json.load(fp)
|
375 |
+
description = self.generate_description(item)
|
376 |
+
if self.return_audio:
|
377 |
+
audio, sr = safe_random_seek_read(path, duration=self.duration, sample_rate=self.sr)
|
378 |
+
else:
|
379 |
+
audio = None
|
380 |
+
if self.return_path:
|
381 |
+
return audio, description, path
|
382 |
+
return audio, description
|
383 |
+
|
384 |
+
|
385 |
+
|
386 |
+
def generate_description(self, item):
|
387 |
+
if self.use_dynamic_prompt:
|
388 |
+
# dynamically generate prompt from given prompt template
|
389 |
+
prompt_template = random.choice(self.prompt_templates)
|
390 |
+
description = self.generate_description_dynamic(item, prompt_template)
|
391 |
+
else:
|
392 |
+
# use ordinary static prompt instead
|
393 |
+
description = self.generate_description_ordinary(item)
|
394 |
+
return description
|
395 |
+
|
396 |
+
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
397 |
+
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
398 |
+
|
399 |
+
if len(exists_tag) > 0:
|
400 |
+
probs = dist_prob_map[len(exists_tag)]
|
401 |
+
tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
|
402 |
+
random.shuffle(exists_tag)
|
403 |
+
tags = exists_tag[:tags_num]
|
404 |
+
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
405 |
+
tags_args = self.handle_BPM_tag(tags_args)
|
406 |
+
prompt = prompt_template.apply(**tags_args)
|
407 |
+
else:
|
408 |
+
# no strong tags, use all weak tags instead
|
409 |
+
prompt = prompt_template.apply()
|
410 |
+
|
411 |
+
return prompt
|
412 |
+
|
413 |
+
def tags_to_desc(self, tag_list, tag_type) -> str:
|
414 |
+
if self.lang == 'en':
|
415 |
+
return tags_to_desc(tag_list)
|
416 |
+
elif self.lang == 'zh':
|
417 |
+
if tag_type == 'BPM':
|
418 |
+
return tags_to_desc(tag_list, sep='、')
|
419 |
+
translator = self.translate[tag_type]
|
420 |
+
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
421 |
+
return tags_to_desc(translated_tag_list, sep='、')
|
422 |
+
|
423 |
+
def handle_BPM_tag(self, tags_args):
|
424 |
+
if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
|
425 |
+
bpm = tags_args["BPM"]
|
426 |
+
del tags_args["BPM"]
|
427 |
+
tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
|
428 |
+
for tag_type in tag_types_used:
|
429 |
+
tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
|
430 |
+
return tags_args
|
431 |
+
|
432 |
+
def generate_description_ordinary(self, data, thresh = 0.3):
|
433 |
+
if self.lang != 'en':
|
434 |
+
raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
|
435 |
+
description = f'a piece of music by {data["Artist"]}'
|
436 |
+
|
437 |
+
# Add genre if available
|
438 |
+
if data["Genre"] and random.random() > thresh:
|
439 |
+
genres = ', '.join(data["Genre"])
|
440 |
+
description += f', belonging to the {genres} genres'
|
441 |
+
|
442 |
+
# Add moods if available
|
443 |
+
if data["Tags"] and random.random() > thresh:
|
444 |
+
tags = ', '.join(data["Tags"])
|
445 |
+
description += f'. This track contains the tags:{tags}'
|
446 |
+
|
447 |
+
# Add moods if available
|
448 |
+
if data["Mood"] and random.random() > thresh:
|
449 |
+
moods = ', '.join(data["Mood"])
|
450 |
+
description += f'. This track conveys a {moods} mood.'
|
451 |
+
|
452 |
+
# Add instruments if available
|
453 |
+
if data["Instrument"] and random.random() > thresh:
|
454 |
+
instruments = ', '.join(data["Instrument"])
|
455 |
+
description += f'. and primarily features the following instruments: {instruments}'
|
456 |
+
|
457 |
+
# Add a period to end the description
|
458 |
+
description += '.'
|
459 |
+
|
460 |
+
return description
|
461 |
+
|
codeclm/tokenizer/Flow1dVAE/libs/fsq/fsq.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
|
3 |
+
Code adapted from Jax version in Appendix A.1
|
4 |
+
"""
|
5 |
+
|
6 |
+
from __future__ import annotations
|
7 |
+
from functools import wraps, partial
|
8 |
+
from contextlib import nullcontext
|
9 |
+
from typing import List, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import Module
|
14 |
+
from torch import Tensor, int32
|
15 |
+
from torch.amp import autocast
|
16 |
+
|
17 |
+
from einops import rearrange, pack, unpack
|
18 |
+
|
19 |
+
# helper functions
|
20 |
+
|
21 |
+
def exists(v):
|
22 |
+
return v is not None
|
23 |
+
|
24 |
+
def default(*args):
|
25 |
+
for arg in args:
|
26 |
+
if exists(arg):
|
27 |
+
return arg
|
28 |
+
return None
|
29 |
+
|
30 |
+
def maybe(fn):
|
31 |
+
@wraps(fn)
|
32 |
+
def inner(x, *args, **kwargs):
|
33 |
+
if not exists(x):
|
34 |
+
return x
|
35 |
+
return fn(x, *args, **kwargs)
|
36 |
+
return inner
|
37 |
+
|
38 |
+
def pack_one(t, pattern):
|
39 |
+
return pack([t], pattern)
|
40 |
+
|
41 |
+
def unpack_one(t, ps, pattern):
|
42 |
+
return unpack(t, ps, pattern)[0]
|
43 |
+
|
44 |
+
# tensor helpers
|
45 |
+
|
46 |
+
def round_ste(z: Tensor) -> Tensor:
|
47 |
+
"""Round with straight through gradients."""
|
48 |
+
zhat = z.round()
|
49 |
+
return z + (zhat - z).detach()
|
50 |
+
|
51 |
+
# main class
|
52 |
+
|
53 |
+
class FSQ(Module):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
levels: List[int],
|
57 |
+
dim: int | None = None,
|
58 |
+
num_codebooks = 1,
|
59 |
+
keep_num_codebooks_dim: bool | None = None,
|
60 |
+
scale: float | None = None,
|
61 |
+
allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
|
62 |
+
channel_first: bool = False,
|
63 |
+
projection_has_bias: bool = True,
|
64 |
+
return_indices = True,
|
65 |
+
force_quantization_f32 = True
|
66 |
+
):
|
67 |
+
super().__init__()
|
68 |
+
_levels = torch.tensor(levels, dtype=int32)
|
69 |
+
self.register_buffer("_levels", _levels, persistent = False)
|
70 |
+
|
71 |
+
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
|
72 |
+
self.register_buffer("_basis", _basis, persistent = False)
|
73 |
+
|
74 |
+
self.scale = scale
|
75 |
+
|
76 |
+
codebook_dim = len(levels)
|
77 |
+
self.codebook_dim = codebook_dim
|
78 |
+
|
79 |
+
effective_codebook_dim = codebook_dim * num_codebooks
|
80 |
+
self.num_codebooks = num_codebooks
|
81 |
+
self.effective_codebook_dim = effective_codebook_dim
|
82 |
+
|
83 |
+
keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
|
84 |
+
assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
|
85 |
+
self.keep_num_codebooks_dim = keep_num_codebooks_dim
|
86 |
+
|
87 |
+
self.dim = default(dim, len(_levels) * num_codebooks)
|
88 |
+
|
89 |
+
self.channel_first = channel_first
|
90 |
+
|
91 |
+
has_projections = self.dim != effective_codebook_dim
|
92 |
+
self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity()
|
93 |
+
self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity()
|
94 |
+
|
95 |
+
self.has_projections = has_projections
|
96 |
+
|
97 |
+
self.return_indices = return_indices
|
98 |
+
if return_indices:
|
99 |
+
self.codebook_size = self._levels.prod().item()
|
100 |
+
implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size))
|
101 |
+
self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
|
102 |
+
|
103 |
+
self.allowed_dtypes = allowed_dtypes
|
104 |
+
self.force_quantization_f32 = force_quantization_f32
|
105 |
+
|
106 |
+
def bound(self, z, eps: float = 1e-3):
|
107 |
+
""" Bound `z`, an array of shape (..., d). """
|
108 |
+
half_l = (self._levels - 1) * (1 + eps) / 2
|
109 |
+
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
|
110 |
+
shift = (offset / half_l).atanh()
|
111 |
+
return (z + shift).tanh() * half_l - offset
|
112 |
+
|
113 |
+
def quantize(self, z):
|
114 |
+
""" Quantizes z, returns quantized zhat, same shape as z. """
|
115 |
+
quantized = round_ste(self.bound(z))
|
116 |
+
half_width = self._levels // 2 # Renormalize to [-1, 1].
|
117 |
+
return quantized / half_width
|
118 |
+
|
119 |
+
def _scale_and_shift(self, zhat_normalized):
|
120 |
+
half_width = self._levels // 2
|
121 |
+
return (zhat_normalized * half_width) + half_width
|
122 |
+
|
123 |
+
def _scale_and_shift_inverse(self, zhat):
|
124 |
+
half_width = self._levels // 2
|
125 |
+
return (zhat - half_width) / half_width
|
126 |
+
|
127 |
+
def _indices_to_codes(self, indices):
|
128 |
+
level_indices = self.indices_to_level_indices(indices)
|
129 |
+
codes = self._scale_and_shift_inverse(level_indices)
|
130 |
+
return codes
|
131 |
+
|
132 |
+
def codes_to_indices(self, zhat):
|
133 |
+
""" Converts a `code` to an index in the codebook. """
|
134 |
+
assert zhat.shape[-1] == self.codebook_dim
|
135 |
+
zhat = self._scale_and_shift(zhat)
|
136 |
+
return (zhat * self._basis).sum(dim=-1).to(int32)
|
137 |
+
|
138 |
+
def indices_to_level_indices(self, indices):
|
139 |
+
""" Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
|
140 |
+
indices = rearrange(indices, '... -> ... 1')
|
141 |
+
codes_non_centered = (indices // self._basis) % self._levels
|
142 |
+
return codes_non_centered
|
143 |
+
|
144 |
+
def indices_to_codes(self, indices):
|
145 |
+
""" Inverse of `codes_to_indices`. """
|
146 |
+
assert exists(indices)
|
147 |
+
|
148 |
+
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
|
149 |
+
|
150 |
+
codes = self._indices_to_codes(indices)
|
151 |
+
|
152 |
+
if self.keep_num_codebooks_dim:
|
153 |
+
codes = rearrange(codes, '... c d -> ... (c d)')
|
154 |
+
|
155 |
+
codes = self.project_out(codes)
|
156 |
+
|
157 |
+
if is_img_or_video or self.channel_first:
|
158 |
+
codes = rearrange(codes, 'b ... d -> b d ...')
|
159 |
+
|
160 |
+
return codes
|
161 |
+
|
162 |
+
def forward(self, z):
|
163 |
+
"""
|
164 |
+
einstein notation
|
165 |
+
b - batch
|
166 |
+
n - sequence (or flattened spatial dimensions)
|
167 |
+
d - feature dimension
|
168 |
+
c - number of codebook dim
|
169 |
+
"""
|
170 |
+
|
171 |
+
is_img_or_video = z.ndim >= 4
|
172 |
+
need_move_channel_last = is_img_or_video or self.channel_first
|
173 |
+
|
174 |
+
# standardize image or video into (batch, seq, dimension)
|
175 |
+
|
176 |
+
if need_move_channel_last:
|
177 |
+
z = rearrange(z, 'b d ... -> b ... d')
|
178 |
+
z, ps = pack_one(z, 'b * d')
|
179 |
+
|
180 |
+
assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
|
181 |
+
|
182 |
+
z = self.project_in(z)
|
183 |
+
|
184 |
+
z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
|
185 |
+
|
186 |
+
# whether to force quantization step to be full precision or not
|
187 |
+
|
188 |
+
force_f32 = self.force_quantization_f32
|
189 |
+
quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
|
190 |
+
|
191 |
+
with quantization_context():
|
192 |
+
orig_dtype = z.dtype
|
193 |
+
|
194 |
+
if force_f32 and orig_dtype not in self.allowed_dtypes:
|
195 |
+
z = z.float()
|
196 |
+
|
197 |
+
codes = self.quantize(z)
|
198 |
+
|
199 |
+
# returning indices could be optional
|
200 |
+
|
201 |
+
indices = None
|
202 |
+
|
203 |
+
if self.return_indices:
|
204 |
+
indices = self.codes_to_indices(codes)
|
205 |
+
|
206 |
+
codes = rearrange(codes, 'b n c d -> b n (c d)')
|
207 |
+
|
208 |
+
codes = codes.type(orig_dtype)
|
209 |
+
|
210 |
+
# project out
|
211 |
+
|
212 |
+
out = self.project_out(codes)
|
213 |
+
|
214 |
+
# reconstitute image or video dimensions
|
215 |
+
|
216 |
+
if need_move_channel_last:
|
217 |
+
out = unpack_one(out, ps, 'b * d')
|
218 |
+
out = rearrange(out, 'b ... d -> b d ...')
|
219 |
+
|
220 |
+
indices = maybe(unpack_one)(indices, ps, 'b * c')
|
221 |
+
|
222 |
+
if not self.keep_num_codebooks_dim and self.return_indices:
|
223 |
+
indices = maybe(rearrange)(indices, '... 1 -> ...')
|
224 |
+
|
225 |
+
# return quantized output and indices
|
226 |
+
|
227 |
+
return out, indices
|
228 |
+
|
229 |
+
|
230 |
+
if __name__ == '__main__':
|
231 |
+
# test
|
232 |
+
fsq = FSQ([4, 4, 4],dim=1024)
|
233 |
+
z = torch.randn(2, 3, 1024)
|
234 |
+
out, indices = fsq(z)
|
235 |
+
print(out.shape, indices.shape)
|
236 |
+
# print(out, indices)
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
# This implementation is inspired from
|
8 |
+
# https://github.com/lucidrains/vector-quantize-pytorch
|
9 |
+
# which is released under MIT License. Hereafter, the original license:
|
10 |
+
# MIT License
|
11 |
+
#
|
12 |
+
# Copyright (c) 2020 Phil Wang
|
13 |
+
#
|
14 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
15 |
+
# of this software and associated documentation files (the "Software"), to deal
|
16 |
+
# in the Software without restriction, including without limitation the rights
|
17 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
18 |
+
# copies of the Software, and to permit persons to whom the Software is
|
19 |
+
# furnished to do so, subject to the following conditions:
|
20 |
+
#
|
21 |
+
# The above copyright notice and this permission notice shall be included in all
|
22 |
+
# copies or substantial portions of the Software.
|
23 |
+
#
|
24 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
25 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
26 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
27 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
28 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
29 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
30 |
+
# SOFTWARE.
|
31 |
+
|
32 |
+
"""Core vector quantization implementation."""
|
33 |
+
|
34 |
+
import typing as tp
|
35 |
+
|
36 |
+
from einops import rearrange, repeat
|
37 |
+
import torch
|
38 |
+
from torch import nn
|
39 |
+
import torch.nn.functional as F
|
40 |
+
|
41 |
+
# from .. import distrib
|
42 |
+
|
43 |
+
|
44 |
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
45 |
+
return val if val is not None else d
|
46 |
+
|
47 |
+
|
48 |
+
def ema_inplace(moving_avg, new, decay: float):
|
49 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
50 |
+
|
51 |
+
|
52 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
53 |
+
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
54 |
+
|
55 |
+
|
56 |
+
def uniform_init(*shape: int):
|
57 |
+
t = torch.empty(shape)
|
58 |
+
nn.init.kaiming_uniform_(t)
|
59 |
+
return t
|
60 |
+
|
61 |
+
|
62 |
+
def sample_vectors(samples, num: int):
|
63 |
+
num_samples, device = samples.shape[0], samples.device
|
64 |
+
|
65 |
+
if num_samples >= num:
|
66 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
67 |
+
else:
|
68 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
69 |
+
|
70 |
+
return samples[indices]
|
71 |
+
|
72 |
+
|
73 |
+
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
74 |
+
dim, dtype = samples.shape[-1], samples.dtype
|
75 |
+
|
76 |
+
means = sample_vectors(samples, num_clusters)
|
77 |
+
|
78 |
+
for _ in range(num_iters):
|
79 |
+
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
80 |
+
means, "c d -> () c d"
|
81 |
+
)
|
82 |
+
dists = -(diffs ** 2).sum(dim=-1)
|
83 |
+
|
84 |
+
buckets = dists.max(dim=-1).indices
|
85 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
86 |
+
zero_mask = bins == 0
|
87 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
88 |
+
|
89 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
90 |
+
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
91 |
+
new_means = new_means / bins_min_clamped[..., None]
|
92 |
+
|
93 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
94 |
+
|
95 |
+
return means, bins
|
96 |
+
|
97 |
+
|
98 |
+
class EuclideanCodebook(nn.Module):
|
99 |
+
"""Codebook with Euclidean distance.
|
100 |
+
Args:
|
101 |
+
dim (int): Dimension.
|
102 |
+
codebook_size (int): Codebook size.
|
103 |
+
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
104 |
+
If set to true, run the k-means algorithm on the first training batch and use
|
105 |
+
the learned centroids as initialization.
|
106 |
+
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
107 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
108 |
+
epsilon (float): Epsilon value for numerical stability.
|
109 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
110 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
111 |
+
randomly selected vector from the current batch.
|
112 |
+
"""
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
dim: int,
|
116 |
+
codebook_size: int,
|
117 |
+
kmeans_init: int = False,
|
118 |
+
kmeans_iters: int = 10,
|
119 |
+
decay: float = 0.99,
|
120 |
+
epsilon: float = 1e-5,
|
121 |
+
threshold_ema_dead_code: int = 2,
|
122 |
+
):
|
123 |
+
super().__init__()
|
124 |
+
self.decay = decay
|
125 |
+
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
126 |
+
embed = init_fn(codebook_size, dim)
|
127 |
+
|
128 |
+
self.codebook_size = codebook_size
|
129 |
+
|
130 |
+
self.kmeans_iters = kmeans_iters
|
131 |
+
self.epsilon = epsilon
|
132 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
133 |
+
|
134 |
+
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
135 |
+
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
136 |
+
self.register_buffer("embed", embed)
|
137 |
+
self.register_buffer("embed_avg", embed.clone())
|
138 |
+
|
139 |
+
self.runed_steps = 0
|
140 |
+
self.stop_steps = 50_000
|
141 |
+
|
142 |
+
@torch.jit.ignore
|
143 |
+
def init_embed_(self, data):
|
144 |
+
if self.inited:
|
145 |
+
return
|
146 |
+
|
147 |
+
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
148 |
+
self.embed.data.copy_(embed)
|
149 |
+
self.embed_avg.data.copy_(embed.clone())
|
150 |
+
self.cluster_size.data.copy_(cluster_size)
|
151 |
+
self.inited.data.copy_(torch.Tensor([True]))
|
152 |
+
# Make sure all buffers across workers are in sync after initialization
|
153 |
+
distrib.broadcast_tensors(self.buffers())
|
154 |
+
|
155 |
+
def replace_(self, samples, mask):
|
156 |
+
modified_codebook = torch.where(
|
157 |
+
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
158 |
+
)
|
159 |
+
self.embed.data.copy_(modified_codebook)
|
160 |
+
|
161 |
+
def expire_codes_(self, batch_samples):
|
162 |
+
if self.threshold_ema_dead_code == 0:
|
163 |
+
return
|
164 |
+
|
165 |
+
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
166 |
+
if not torch.any(expired_codes):
|
167 |
+
return
|
168 |
+
|
169 |
+
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
170 |
+
self.replace_(batch_samples, mask=expired_codes)
|
171 |
+
# distrib.broadcast_tensors(self.buffers())
|
172 |
+
|
173 |
+
def preprocess(self, x):
|
174 |
+
x = rearrange(x, "... d -> (...) d")
|
175 |
+
return x
|
176 |
+
|
177 |
+
def quantize(self, x):
|
178 |
+
embed = self.embed.t()
|
179 |
+
dist = -(
|
180 |
+
x.pow(2).sum(1, keepdim=True)
|
181 |
+
- 2 * x @ embed
|
182 |
+
+ embed.pow(2).sum(0, keepdim=True)
|
183 |
+
)
|
184 |
+
embed_ind = dist.max(dim=-1).indices
|
185 |
+
return embed_ind
|
186 |
+
|
187 |
+
def postprocess_emb(self, embed_ind, shape):
|
188 |
+
return embed_ind.view(*shape[:-1])
|
189 |
+
|
190 |
+
def dequantize(self, embed_ind):
|
191 |
+
quantize = F.embedding(embed_ind, self.embed)
|
192 |
+
return quantize
|
193 |
+
|
194 |
+
def encode(self, x):
|
195 |
+
shape = x.shape
|
196 |
+
# pre-process
|
197 |
+
x = self.preprocess(x)
|
198 |
+
# quantize
|
199 |
+
embed_ind = self.quantize(x)
|
200 |
+
# post-process
|
201 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
202 |
+
return embed_ind
|
203 |
+
|
204 |
+
def decode(self, embed_ind):
|
205 |
+
quantize = self.dequantize(embed_ind)
|
206 |
+
return quantize
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
shape, dtype = x.shape, x.dtype
|
210 |
+
x = self.preprocess(x)
|
211 |
+
# self.init_embed_(x)
|
212 |
+
|
213 |
+
embed_ind = self.quantize(x)
|
214 |
+
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
215 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
216 |
+
quantize = self.dequantize(embed_ind)
|
217 |
+
self.runed_steps += 1
|
218 |
+
|
219 |
+
if self.training and self.runed_steps < self.stop_steps:
|
220 |
+
# We do the expiry of code at that point as buffers are in sync
|
221 |
+
# and all the workers will take the same decision.
|
222 |
+
self.expire_codes_(x)
|
223 |
+
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
224 |
+
embed_sum = x.t() @ embed_onehot
|
225 |
+
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
226 |
+
cluster_size = (
|
227 |
+
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
228 |
+
* self.cluster_size.sum()
|
229 |
+
)
|
230 |
+
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
231 |
+
self.embed.data.copy_(embed_normalized)
|
232 |
+
|
233 |
+
return quantize, embed_ind
|
234 |
+
|
235 |
+
|
236 |
+
class VectorQuantization(nn.Module):
|
237 |
+
"""Vector quantization implementation.
|
238 |
+
Currently supports only euclidean distance.
|
239 |
+
Args:
|
240 |
+
dim (int): Dimension
|
241 |
+
codebook_size (int): Codebook size
|
242 |
+
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
243 |
+
decay (float): Decay for exponential moving average over the codebooks.
|
244 |
+
epsilon (float): Epsilon value for numerical stability.
|
245 |
+
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
246 |
+
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
247 |
+
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
248 |
+
that have an exponential moving average cluster size less than the specified threshold with
|
249 |
+
randomly selected vector from the current batch.
|
250 |
+
commitment_weight (float): Weight for commitment loss.
|
251 |
+
"""
|
252 |
+
def __init__(
|
253 |
+
self,
|
254 |
+
dim: int,
|
255 |
+
codebook_size: int,
|
256 |
+
codebook_dim: tp.Optional[int] = None,
|
257 |
+
decay: float = 0.99,
|
258 |
+
epsilon: float = 1e-5,
|
259 |
+
kmeans_init: bool = True,
|
260 |
+
kmeans_iters: int = 50,
|
261 |
+
threshold_ema_dead_code: int = 2,
|
262 |
+
commitment_weight: float = 1.,
|
263 |
+
):
|
264 |
+
super().__init__()
|
265 |
+
_codebook_dim: int = default(codebook_dim, dim)
|
266 |
+
|
267 |
+
requires_projection = _codebook_dim != dim
|
268 |
+
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
|
269 |
+
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
|
270 |
+
|
271 |
+
self.epsilon = epsilon
|
272 |
+
self.commitment_weight = commitment_weight
|
273 |
+
|
274 |
+
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
275 |
+
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
276 |
+
decay=decay, epsilon=epsilon,
|
277 |
+
threshold_ema_dead_code=threshold_ema_dead_code)
|
278 |
+
self.codebook_size = codebook_size
|
279 |
+
|
280 |
+
@property
|
281 |
+
def codebook(self):
|
282 |
+
return self._codebook.embed
|
283 |
+
|
284 |
+
def encode(self, x):
|
285 |
+
x = rearrange(x, "b d n -> b n d")
|
286 |
+
x = self.project_in(x)
|
287 |
+
embed_in = self._codebook.encode(x)
|
288 |
+
return embed_in
|
289 |
+
|
290 |
+
def decode(self, embed_ind):
|
291 |
+
quantize = self._codebook.decode(embed_ind)
|
292 |
+
quantize = self.project_out(quantize)
|
293 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
294 |
+
return quantize
|
295 |
+
|
296 |
+
def forward(self, x, do_debug=False):
|
297 |
+
device = x.device
|
298 |
+
x = rearrange(x, "b d n -> b n d")
|
299 |
+
x = self.project_in(x)
|
300 |
+
|
301 |
+
quantize, embed_ind = self._codebook(x)
|
302 |
+
|
303 |
+
if self.training:
|
304 |
+
quantize = x + (quantize - x).detach()
|
305 |
+
|
306 |
+
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
307 |
+
|
308 |
+
if self.training:
|
309 |
+
if self.commitment_weight > 0:
|
310 |
+
commit_loss = F.mse_loss(quantize.detach(), x)
|
311 |
+
loss = loss + commit_loss * self.commitment_weight
|
312 |
+
quantize = self.project_out(quantize)
|
313 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
314 |
+
return quantize, embed_ind, loss
|
315 |
+
|
316 |
+
|
317 |
+
class ResidualVectorQuantization(nn.Module):
|
318 |
+
"""Residual vector quantization implementation.
|
319 |
+
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
320 |
+
"""
|
321 |
+
def __init__(self, *, num_quantizers, **kwargs):
|
322 |
+
super().__init__()
|
323 |
+
self.layers = nn.ModuleList(
|
324 |
+
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
325 |
+
)
|
326 |
+
|
327 |
+
def forward(self, x, n_q: tp.Optional[int] = None):
|
328 |
+
quantized_out = 0.0
|
329 |
+
residual = x
|
330 |
+
|
331 |
+
all_losses = []
|
332 |
+
all_indices = []
|
333 |
+
|
334 |
+
n_q = n_q or len(self.layers)
|
335 |
+
|
336 |
+
for layerinx, layer in enumerate(self.layers[:n_q]):
|
337 |
+
print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.))
|
338 |
+
quantized, indices, loss = layer(residual)
|
339 |
+
residual = residual - quantized
|
340 |
+
quantized_out = quantized_out + quantized
|
341 |
+
|
342 |
+
all_indices.append(indices)
|
343 |
+
all_losses.append(loss)
|
344 |
+
|
345 |
+
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
346 |
+
return quantized_out, out_indices, out_losses
|
347 |
+
|
348 |
+
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
349 |
+
residual = x
|
350 |
+
all_indices = []
|
351 |
+
n_q = n_q or len(self.layers)
|
352 |
+
for layer in self.layers[:n_q]:
|
353 |
+
indices = layer.encode(residual)
|
354 |
+
quantized = layer.decode(indices)
|
355 |
+
residual = residual - quantized
|
356 |
+
all_indices.append(indices)
|
357 |
+
out_indices = torch.stack(all_indices)
|
358 |
+
return out_indices
|
359 |
+
|
360 |
+
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
361 |
+
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
362 |
+
for i, indices in enumerate(q_indices):
|
363 |
+
layer = self.layers[i]
|
364 |
+
quantized = layer.decode(indices)
|
365 |
+
quantized_out = quantized_out + quantized
|
366 |
+
return quantized_out
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
def WNConv1d(*args, **kwargs):
|
11 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
12 |
+
|
13 |
+
class VectorQuantize(nn.Module):
|
14 |
+
"""
|
15 |
+
Implementation of VQ similar to Karpathy's repo:
|
16 |
+
https://github.com/karpathy/deep-vector-quantization
|
17 |
+
Additionally uses following tricks from Improved VQGAN
|
18 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
19 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
20 |
+
for improved codebook usage
|
21 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
22 |
+
improves training stability
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
26 |
+
super().__init__()
|
27 |
+
self.codebook_size = codebook_size
|
28 |
+
self.codebook_dim = codebook_dim
|
29 |
+
|
30 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
31 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
32 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
33 |
+
|
34 |
+
def forward(self, z):
|
35 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
36 |
+
the corresponding codebook vectors
|
37 |
+
|
38 |
+
Parameters
|
39 |
+
----------
|
40 |
+
z : Tensor[B x D x T]
|
41 |
+
|
42 |
+
Returns
|
43 |
+
-------
|
44 |
+
Tensor[B x D x T]
|
45 |
+
Quantized continuous representation of input
|
46 |
+
Tensor[1]
|
47 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
48 |
+
entries
|
49 |
+
Tensor[1]
|
50 |
+
Codebook loss to update the codebook
|
51 |
+
Tensor[B x T]
|
52 |
+
Codebook indices (quantized discrete representation of input)
|
53 |
+
Tensor[B x D x T]
|
54 |
+
Projected latents (continuous representation of input before quantization)
|
55 |
+
"""
|
56 |
+
|
57 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
58 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
59 |
+
z_q, indices = self.decode_latents(z_e)
|
60 |
+
|
61 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
62 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
63 |
+
|
64 |
+
z_q = (
|
65 |
+
z_e + (z_q - z_e).detach()
|
66 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
67 |
+
|
68 |
+
z_q = self.out_proj(z_q)
|
69 |
+
|
70 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
71 |
+
|
72 |
+
def embed_code(self, embed_id):
|
73 |
+
return F.embedding(embed_id, self.codebook.weight)
|
74 |
+
|
75 |
+
def decode_code(self, embed_id):
|
76 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
77 |
+
|
78 |
+
def decode_latents(self, latents):
|
79 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
80 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
81 |
+
|
82 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
83 |
+
encodings = F.normalize(encodings)
|
84 |
+
codebook = F.normalize(codebook)
|
85 |
+
|
86 |
+
# Compute euclidean distance with codebook
|
87 |
+
dist = (
|
88 |
+
encodings.pow(2).sum(1, keepdim=True)
|
89 |
+
- 2 * encodings @ codebook.t()
|
90 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
91 |
+
)
|
92 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
93 |
+
z_q = self.decode_code(indices)
|
94 |
+
return z_q, indices
|
95 |
+
|
96 |
+
|
97 |
+
class ResidualVectorQuantize(nn.Module):
|
98 |
+
"""
|
99 |
+
Introduced in SoundStream: An end2end neural audio codec
|
100 |
+
https://arxiv.org/abs/2107.03312
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
input_dim: int = 512,
|
106 |
+
n_codebooks: int = 9,
|
107 |
+
codebook_size: int = 1024,
|
108 |
+
codebook_dim: Union[int, list] = 8,
|
109 |
+
quantizer_dropout: float = 0.0,
|
110 |
+
):
|
111 |
+
super().__init__()
|
112 |
+
if isinstance(codebook_dim, int):
|
113 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
114 |
+
|
115 |
+
self.n_codebooks = n_codebooks
|
116 |
+
self.codebook_dim = codebook_dim
|
117 |
+
self.codebook_size = codebook_size
|
118 |
+
|
119 |
+
self.quantizers = nn.ModuleList(
|
120 |
+
[
|
121 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i])
|
122 |
+
for i in range(n_codebooks)
|
123 |
+
]
|
124 |
+
)
|
125 |
+
self.quantizer_dropout = quantizer_dropout
|
126 |
+
|
127 |
+
def forward(self, z, n_quantizers: int = None):
|
128 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
129 |
+
the corresponding codebook vectors
|
130 |
+
Parameters
|
131 |
+
----------
|
132 |
+
z : Tensor[B x D x T]
|
133 |
+
n_quantizers : int, optional
|
134 |
+
No. of quantizers to use
|
135 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
136 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
137 |
+
when in training mode, and a random number of quantizers is used.
|
138 |
+
Returns
|
139 |
+
-------
|
140 |
+
dict
|
141 |
+
A dictionary with the following keys:
|
142 |
+
|
143 |
+
"z" : Tensor[B x D x T]
|
144 |
+
Quantized continuous representation of input
|
145 |
+
"codes" : Tensor[B x N x T]
|
146 |
+
Codebook indices for each codebook
|
147 |
+
(quantized discrete representation of input)
|
148 |
+
"latents" : Tensor[B x N*D x T]
|
149 |
+
Projected latents (continuous representation of input before quantization)
|
150 |
+
"vq/commitment_loss" : Tensor[1]
|
151 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
152 |
+
entries
|
153 |
+
"vq/codebook_loss" : Tensor[1]
|
154 |
+
Codebook loss to update the codebook
|
155 |
+
"""
|
156 |
+
z_q = 0
|
157 |
+
residual = z
|
158 |
+
commitment_loss = 0
|
159 |
+
codebook_loss = 0
|
160 |
+
|
161 |
+
codebook_indices = []
|
162 |
+
latents = []
|
163 |
+
|
164 |
+
if n_quantizers is None:
|
165 |
+
n_quantizers = self.n_codebooks
|
166 |
+
if self.training:
|
167 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
168 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
169 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
170 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
171 |
+
n_quantizers = n_quantizers.to(z.device)
|
172 |
+
|
173 |
+
for i, quantizer in enumerate(self.quantizers):
|
174 |
+
if self.training is False and i >= n_quantizers:
|
175 |
+
break
|
176 |
+
|
177 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
178 |
+
residual
|
179 |
+
)
|
180 |
+
|
181 |
+
# Create mask to apply quantizer dropout
|
182 |
+
mask = (
|
183 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
184 |
+
)
|
185 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
186 |
+
residual = residual - z_q_i
|
187 |
+
|
188 |
+
# Sum losses
|
189 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
190 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
191 |
+
|
192 |
+
codebook_indices.append(indices_i)
|
193 |
+
latents.append(z_e_i)
|
194 |
+
|
195 |
+
codes = torch.stack(codebook_indices, dim=1)
|
196 |
+
latents = torch.cat(latents, dim=1)
|
197 |
+
|
198 |
+
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
199 |
+
for n in range(encodings.shape[1]):
|
200 |
+
print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
|
201 |
+
(encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
|
202 |
+
))
|
203 |
+
|
204 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
205 |
+
|
206 |
+
def from_codes(self, codes: torch.Tensor):
|
207 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
208 |
+
Parameters
|
209 |
+
----------
|
210 |
+
codes : Tensor[B x N x T]
|
211 |
+
Quantized discrete representation of input
|
212 |
+
Returns
|
213 |
+
-------
|
214 |
+
Tensor[B x D x T]
|
215 |
+
Quantized continuous representation of input
|
216 |
+
"""
|
217 |
+
z_q = 0.0
|
218 |
+
z_p = []
|
219 |
+
n_codebooks = codes.shape[1]
|
220 |
+
for i in range(n_codebooks):
|
221 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
222 |
+
z_p.append(z_p_i)
|
223 |
+
|
224 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
225 |
+
z_q = z_q + z_q_i
|
226 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
227 |
+
|
228 |
+
def from_latents(self, latents: torch.Tensor):
|
229 |
+
"""Given the unquantized latents, reconstruct the
|
230 |
+
continuous representation after quantization.
|
231 |
+
|
232 |
+
Parameters
|
233 |
+
----------
|
234 |
+
latents : Tensor[B x N x T]
|
235 |
+
Continuous representation of input after projection
|
236 |
+
|
237 |
+
Returns
|
238 |
+
-------
|
239 |
+
Tensor[B x D x T]
|
240 |
+
Quantized representation of full-projected space
|
241 |
+
Tensor[B x D x T]
|
242 |
+
Quantized representation of latent space
|
243 |
+
"""
|
244 |
+
z_q = 0
|
245 |
+
z_p = []
|
246 |
+
codes = []
|
247 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
248 |
+
|
249 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
250 |
+
0
|
251 |
+
]
|
252 |
+
for i in range(n_codebooks):
|
253 |
+
j, k = dims[i], dims[i + 1]
|
254 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
255 |
+
z_p.append(z_p_i)
|
256 |
+
codes.append(codes_i)
|
257 |
+
|
258 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
259 |
+
z_q = z_q + z_q_i
|
260 |
+
|
261 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == "__main__":
|
265 |
+
rvq = ResidualVectorQuantize(quantizer_dropout=True)
|
266 |
+
x = torch.randn(16, 512, 80)
|
267 |
+
y = rvq(x)
|
268 |
+
print(y["latents"].shape)
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize2.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
+
def WNConv1d(*args, **kwargs):
|
11 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
12 |
+
|
13 |
+
class VectorQuantize(nn.Module):
|
14 |
+
"""
|
15 |
+
Implementation of VQ similar to Karpathy's repo:
|
16 |
+
https://github.com/karpathy/deep-vector-quantization
|
17 |
+
Additionally uses following tricks from Improved VQGAN
|
18 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
19 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
20 |
+
for improved codebook usage
|
21 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
22 |
+
improves training stability
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
|
26 |
+
super().__init__()
|
27 |
+
self.codebook_size = codebook_size
|
28 |
+
self.codebook_dim = codebook_dim
|
29 |
+
|
30 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
31 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
32 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
33 |
+
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
34 |
+
self.stale_tolerance = stale_tolerance
|
35 |
+
|
36 |
+
def forward(self, z):
|
37 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
38 |
+
the corresponding codebook vectors
|
39 |
+
|
40 |
+
Parameters
|
41 |
+
----------
|
42 |
+
z : Tensor[B x D x T]
|
43 |
+
|
44 |
+
Returns
|
45 |
+
-------
|
46 |
+
Tensor[B x D x T]
|
47 |
+
Quantized continuous representation of input
|
48 |
+
Tensor[1]
|
49 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
50 |
+
entries
|
51 |
+
Tensor[1]
|
52 |
+
Codebook loss to update the codebook
|
53 |
+
Tensor[B x T]
|
54 |
+
Codebook indices (quantized discrete representation of input)
|
55 |
+
Tensor[B x D x T]
|
56 |
+
Projected latents (continuous representation of input before quantization)
|
57 |
+
"""
|
58 |
+
|
59 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
60 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
61 |
+
z_q, indices = self.decode_latents(z_e)
|
62 |
+
|
63 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
64 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
65 |
+
|
66 |
+
z_q = (
|
67 |
+
z_e + (z_q - z_e).detach()
|
68 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
69 |
+
|
70 |
+
z_q = self.out_proj(z_q)
|
71 |
+
|
72 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
73 |
+
|
74 |
+
def embed_code(self, embed_id):
|
75 |
+
return F.embedding(embed_id, self.codebook.weight)
|
76 |
+
|
77 |
+
def decode_code(self, embed_id):
|
78 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
79 |
+
|
80 |
+
def decode_latents(self, latents):
|
81 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
82 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
83 |
+
|
84 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
85 |
+
encodings = F.normalize(encodings)
|
86 |
+
codebook = F.normalize(codebook)
|
87 |
+
|
88 |
+
# Compute euclidean distance with codebook
|
89 |
+
dist = (
|
90 |
+
encodings.pow(2).sum(1, keepdim=True)
|
91 |
+
- 2 * encodings @ codebook.t()
|
92 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
93 |
+
)
|
94 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
95 |
+
z_q = self.decode_code(indices)
|
96 |
+
|
97 |
+
if(self.training):
|
98 |
+
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
99 |
+
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
100 |
+
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
101 |
+
|
102 |
+
# random replace codes that haven't been used for a while
|
103 |
+
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
104 |
+
if replace_code.sum(-1) > 0:
|
105 |
+
print("Replace {} codes".format(replace_code.sum(-1)))
|
106 |
+
random_input_idx = torch.randperm(encodings.shape[0])
|
107 |
+
random_input = encodings[random_input_idx].view(encodings.shape)
|
108 |
+
if random_input.shape[0] < self.codebook_size:
|
109 |
+
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
110 |
+
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
111 |
+
|
112 |
+
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
113 |
+
self.stale_counter = self.stale_counter * (1 - replace_code)
|
114 |
+
|
115 |
+
return z_q, indices
|
116 |
+
|
117 |
+
|
118 |
+
class ResidualVectorQuantize(nn.Module):
|
119 |
+
"""
|
120 |
+
Introduced in SoundStream: An end2end neural audio codec
|
121 |
+
https://arxiv.org/abs/2107.03312
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
input_dim: int = 512,
|
127 |
+
n_codebooks: int = 9,
|
128 |
+
codebook_size: int = 1024,
|
129 |
+
codebook_dim: Union[int, list] = 8,
|
130 |
+
quantizer_dropout: float = 0.0,
|
131 |
+
stale_tolerance: int = 100,
|
132 |
+
):
|
133 |
+
super().__init__()
|
134 |
+
if isinstance(codebook_dim, int):
|
135 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
136 |
+
|
137 |
+
self.n_codebooks = n_codebooks
|
138 |
+
self.codebook_dim = codebook_dim
|
139 |
+
self.codebook_size = codebook_size
|
140 |
+
|
141 |
+
self.quantizers = nn.ModuleList(
|
142 |
+
[
|
143 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
|
144 |
+
for i in range(n_codebooks)
|
145 |
+
]
|
146 |
+
)
|
147 |
+
self.quantizer_dropout = quantizer_dropout
|
148 |
+
|
149 |
+
def forward(self, z, n_quantizers: int = None):
|
150 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
151 |
+
the corresponding codebook vectors
|
152 |
+
Parameters
|
153 |
+
----------
|
154 |
+
z : Tensor[B x D x T]
|
155 |
+
n_quantizers : int, optional
|
156 |
+
No. of quantizers to use
|
157 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
158 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
159 |
+
when in training mode, and a random number of quantizers is used.
|
160 |
+
Returns
|
161 |
+
-------
|
162 |
+
dict
|
163 |
+
A dictionary with the following keys:
|
164 |
+
|
165 |
+
"z" : Tensor[B x D x T]
|
166 |
+
Quantized continuous representation of input
|
167 |
+
"codes" : Tensor[B x N x T]
|
168 |
+
Codebook indices for each codebook
|
169 |
+
(quantized discrete representation of input)
|
170 |
+
"latents" : Tensor[B x N*D x T]
|
171 |
+
Projected latents (continuous representation of input before quantization)
|
172 |
+
"vq/commitment_loss" : Tensor[1]
|
173 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
174 |
+
entries
|
175 |
+
"vq/codebook_loss" : Tensor[1]
|
176 |
+
Codebook loss to update the codebook
|
177 |
+
"""
|
178 |
+
z_q = 0
|
179 |
+
residual = z
|
180 |
+
commitment_loss = 0
|
181 |
+
codebook_loss = 0
|
182 |
+
|
183 |
+
codebook_indices = []
|
184 |
+
latents = []
|
185 |
+
|
186 |
+
if n_quantizers is None:
|
187 |
+
n_quantizers = self.n_codebooks
|
188 |
+
if self.training:
|
189 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
190 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
191 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
192 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
193 |
+
n_quantizers = n_quantizers.to(z.device)
|
194 |
+
|
195 |
+
for i, quantizer in enumerate(self.quantizers):
|
196 |
+
if self.training is False and i >= n_quantizers:
|
197 |
+
break
|
198 |
+
|
199 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
200 |
+
residual
|
201 |
+
)
|
202 |
+
|
203 |
+
# Create mask to apply quantizer dropout
|
204 |
+
mask = (
|
205 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
206 |
+
)
|
207 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
208 |
+
residual = residual - z_q_i
|
209 |
+
|
210 |
+
# Sum losses
|
211 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
212 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
213 |
+
|
214 |
+
codebook_indices.append(indices_i)
|
215 |
+
latents.append(z_e_i)
|
216 |
+
|
217 |
+
codes = torch.stack(codebook_indices, dim=1)
|
218 |
+
latents = torch.cat(latents, dim=1)
|
219 |
+
|
220 |
+
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
221 |
+
for n in range(encodings.shape[1]):
|
222 |
+
print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
|
223 |
+
(encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
|
224 |
+
))
|
225 |
+
|
226 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
227 |
+
|
228 |
+
def from_codes(self, codes: torch.Tensor):
|
229 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
230 |
+
Parameters
|
231 |
+
----------
|
232 |
+
codes : Tensor[B x N x T]
|
233 |
+
Quantized discrete representation of input
|
234 |
+
Returns
|
235 |
+
-------
|
236 |
+
Tensor[B x D x T]
|
237 |
+
Quantized continuous representation of input
|
238 |
+
"""
|
239 |
+
z_q = 0.0
|
240 |
+
z_p = []
|
241 |
+
n_codebooks = codes.shape[1]
|
242 |
+
for i in range(n_codebooks):
|
243 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
244 |
+
z_p.append(z_p_i)
|
245 |
+
|
246 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
247 |
+
z_q = z_q + z_q_i
|
248 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
249 |
+
|
250 |
+
def from_latents(self, latents: torch.Tensor):
|
251 |
+
"""Given the unquantized latents, reconstruct the
|
252 |
+
continuous representation after quantization.
|
253 |
+
|
254 |
+
Parameters
|
255 |
+
----------
|
256 |
+
latents : Tensor[B x N x T]
|
257 |
+
Continuous representation of input after projection
|
258 |
+
|
259 |
+
Returns
|
260 |
+
-------
|
261 |
+
Tensor[B x D x T]
|
262 |
+
Quantized representation of full-projected space
|
263 |
+
Tensor[B x D x T]
|
264 |
+
Quantized representation of latent space
|
265 |
+
"""
|
266 |
+
z_q = 0
|
267 |
+
z_p = []
|
268 |
+
codes = []
|
269 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
270 |
+
|
271 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
272 |
+
0
|
273 |
+
]
|
274 |
+
for i in range(n_codebooks):
|
275 |
+
j, k = dims[i], dims[i + 1]
|
276 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
277 |
+
z_p.append(z_p_i)
|
278 |
+
codes.append(codes_i)
|
279 |
+
|
280 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
281 |
+
z_q = z_q + z_q_i
|
282 |
+
|
283 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
284 |
+
|
285 |
+
|
286 |
+
if __name__ == "__main__":
|
287 |
+
rvq = ResidualVectorQuantize(quantizer_dropout=True)
|
288 |
+
x = torch.randn(16, 512, 80)
|
289 |
+
y = rvq(x)
|
290 |
+
print(y["latents"].shape)
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# compared with `descript_quantize2`, we use rvq & random_dropout
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
def WNConv1d(*args, **kwargs):
|
12 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
13 |
+
|
14 |
+
class VectorQuantize(nn.Module):
|
15 |
+
"""
|
16 |
+
Implementation of VQ similar to Karpathy's repo:
|
17 |
+
https://github.com/karpathy/deep-vector-quantization
|
18 |
+
Additionally uses following tricks from Improved VQGAN
|
19 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
20 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
21 |
+
for improved codebook usage
|
22 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
23 |
+
improves training stability
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
|
27 |
+
super().__init__()
|
28 |
+
self.codebook_size = codebook_size
|
29 |
+
self.codebook_dim = codebook_dim
|
30 |
+
|
31 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
32 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
33 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
34 |
+
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
35 |
+
self.stale_tolerance = stale_tolerance
|
36 |
+
|
37 |
+
def forward(self, z):
|
38 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
39 |
+
the corresponding codebook vectors
|
40 |
+
|
41 |
+
Parameters
|
42 |
+
----------
|
43 |
+
z : Tensor[B x D x T]
|
44 |
+
|
45 |
+
Returns
|
46 |
+
-------
|
47 |
+
Tensor[B x D x T]
|
48 |
+
Quantized continuous representation of input
|
49 |
+
Tensor[1]
|
50 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
51 |
+
entries
|
52 |
+
Tensor[1]
|
53 |
+
Codebook loss to update the codebook
|
54 |
+
Tensor[B x T]
|
55 |
+
Codebook indices (quantized discrete representation of input)
|
56 |
+
Tensor[B x D x T]
|
57 |
+
Projected latents (continuous representation of input before quantization)
|
58 |
+
"""
|
59 |
+
|
60 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
61 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
62 |
+
z_q, indices = self.decode_latents(z_e)
|
63 |
+
|
64 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
65 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
66 |
+
|
67 |
+
z_q = (
|
68 |
+
z_e + (z_q - z_e).detach()
|
69 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
70 |
+
|
71 |
+
z_q = self.out_proj(z_q)
|
72 |
+
|
73 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
74 |
+
|
75 |
+
def embed_code(self, embed_id):
|
76 |
+
return F.embedding(embed_id, self.codebook.weight)
|
77 |
+
|
78 |
+
def decode_code(self, embed_id):
|
79 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
80 |
+
|
81 |
+
def decode_latents(self, latents):
|
82 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
83 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
84 |
+
|
85 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
86 |
+
encodings = F.normalize(encodings)
|
87 |
+
codebook = F.normalize(codebook)
|
88 |
+
|
89 |
+
# Compute euclidean distance with codebook
|
90 |
+
dist = (
|
91 |
+
encodings.pow(2).sum(1, keepdim=True)
|
92 |
+
- 2 * encodings @ codebook.t()
|
93 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
94 |
+
)
|
95 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
96 |
+
z_q = self.decode_code(indices)
|
97 |
+
|
98 |
+
if(self.training):
|
99 |
+
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
100 |
+
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
101 |
+
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
102 |
+
|
103 |
+
# random replace codes that haven't been used for a while
|
104 |
+
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
105 |
+
if replace_code.sum(-1) > 0:
|
106 |
+
print("Replace {} codes".format(replace_code.sum(-1)))
|
107 |
+
random_input_idx = torch.randperm(encodings.shape[0])
|
108 |
+
random_input = encodings[random_input_idx].view(encodings.shape)
|
109 |
+
if random_input.shape[0] < self.codebook_size:
|
110 |
+
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
111 |
+
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
112 |
+
|
113 |
+
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
114 |
+
self.stale_counter = self.stale_counter * (1 - replace_code)
|
115 |
+
|
116 |
+
return z_q, indices
|
117 |
+
|
118 |
+
|
119 |
+
class ResidualVectorQuantize(nn.Module):
|
120 |
+
"""
|
121 |
+
Introduced in SoundStream: An end2end neural audio codec
|
122 |
+
https://arxiv.org/abs/2107.03312
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
input_dim: int = 512,
|
128 |
+
n_codebooks: int = 9,
|
129 |
+
codebook_size: int = 1024,
|
130 |
+
codebook_dim: Union[int, list] = 8,
|
131 |
+
quantizer_dropout: float = 0.0,
|
132 |
+
stale_tolerance: int = 100,
|
133 |
+
):
|
134 |
+
super().__init__()
|
135 |
+
if isinstance(codebook_dim, int):
|
136 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
137 |
+
|
138 |
+
self.n_codebooks = n_codebooks
|
139 |
+
self.codebook_dim = codebook_dim
|
140 |
+
self.codebook_size = codebook_size
|
141 |
+
|
142 |
+
self.quantizers = nn.ModuleList(
|
143 |
+
[
|
144 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
|
145 |
+
for i in range(n_codebooks)
|
146 |
+
]
|
147 |
+
)
|
148 |
+
self.quantizer_dropout = quantizer_dropout
|
149 |
+
|
150 |
+
def forward(self, z, n_quantizers: int = None):
|
151 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
152 |
+
the corresponding codebook vectors
|
153 |
+
Parameters
|
154 |
+
----------
|
155 |
+
z : Tensor[B x D x T]
|
156 |
+
n_quantizers : int, optional
|
157 |
+
No. of quantizers to use
|
158 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
159 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
160 |
+
when in training mode, and a random number of quantizers is used.
|
161 |
+
Returns
|
162 |
+
-------
|
163 |
+
dict
|
164 |
+
A dictionary with the following keys:
|
165 |
+
|
166 |
+
"z" : Tensor[B x D x T]
|
167 |
+
Quantized continuous representation of input
|
168 |
+
"codes" : Tensor[B x N x T]
|
169 |
+
Codebook indices for each codebook
|
170 |
+
(quantized discrete representation of input)
|
171 |
+
"latents" : Tensor[B x N*D x T]
|
172 |
+
Projected latents (continuous representation of input before quantization)
|
173 |
+
"vq/commitment_loss" : Tensor[1]
|
174 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
175 |
+
entries
|
176 |
+
"vq/codebook_loss" : Tensor[1]
|
177 |
+
Codebook loss to update the codebook
|
178 |
+
"""
|
179 |
+
z_q = 0
|
180 |
+
residual = z
|
181 |
+
commitment_loss = 0
|
182 |
+
codebook_loss = 0
|
183 |
+
|
184 |
+
codebook_indices = []
|
185 |
+
latents = []
|
186 |
+
|
187 |
+
if n_quantizers is None:
|
188 |
+
n_quantizers = self.n_codebooks
|
189 |
+
if self.training:
|
190 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
191 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
192 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
193 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
194 |
+
n_quantizers = n_quantizers.to(z.device)
|
195 |
+
else:
|
196 |
+
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
|
197 |
+
n_quantizers = n_quantizers.to(z.device)
|
198 |
+
|
199 |
+
for i, quantizer in enumerate(self.quantizers):
|
200 |
+
# if self.training is False and i >= n_quantizers:
|
201 |
+
# break
|
202 |
+
|
203 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
204 |
+
residual
|
205 |
+
)
|
206 |
+
|
207 |
+
# Create mask to apply quantizer dropout
|
208 |
+
mask = (
|
209 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
210 |
+
)
|
211 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
212 |
+
residual = residual - z_q_i
|
213 |
+
|
214 |
+
# Sum losses
|
215 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
216 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
217 |
+
|
218 |
+
codebook_indices.append(indices_i)
|
219 |
+
latents.append(z_e_i)
|
220 |
+
|
221 |
+
codes = torch.stack(codebook_indices, dim=1)
|
222 |
+
latents = torch.cat(latents, dim=1)
|
223 |
+
|
224 |
+
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
225 |
+
# for n in range(encodings.shape[1]):
|
226 |
+
# print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
|
227 |
+
# (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
|
228 |
+
# ))
|
229 |
+
|
230 |
+
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
|
231 |
+
|
232 |
+
def from_codes(self, codes: torch.Tensor):
|
233 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
234 |
+
Parameters
|
235 |
+
----------
|
236 |
+
codes : Tensor[B x N x T]
|
237 |
+
Quantized discrete representation of input
|
238 |
+
Returns
|
239 |
+
-------
|
240 |
+
Tensor[B x D x T]
|
241 |
+
Quantized continuous representation of input
|
242 |
+
"""
|
243 |
+
z_q = 0.0
|
244 |
+
z_p = []
|
245 |
+
n_codebooks = codes.shape[1]
|
246 |
+
for i in range(n_codebooks):
|
247 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
248 |
+
z_p.append(z_p_i)
|
249 |
+
|
250 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
251 |
+
z_q = z_q + z_q_i
|
252 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
253 |
+
|
254 |
+
def from_latents(self, latents: torch.Tensor):
|
255 |
+
"""Given the unquantized latents, reconstruct the
|
256 |
+
continuous representation after quantization.
|
257 |
+
|
258 |
+
Parameters
|
259 |
+
----------
|
260 |
+
latents : Tensor[B x N x T]
|
261 |
+
Continuous representation of input after projection
|
262 |
+
|
263 |
+
Returns
|
264 |
+
-------
|
265 |
+
Tensor[B x D x T]
|
266 |
+
Quantized representation of full-projected space
|
267 |
+
Tensor[B x D x T]
|
268 |
+
Quantized representation of latent space
|
269 |
+
"""
|
270 |
+
z_q = 0
|
271 |
+
z_p = []
|
272 |
+
codes = []
|
273 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
274 |
+
|
275 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
276 |
+
0
|
277 |
+
]
|
278 |
+
for i in range(n_codebooks):
|
279 |
+
j, k = dims[i], dims[i + 1]
|
280 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
281 |
+
z_p.append(z_p_i)
|
282 |
+
codes.append(codes_i)
|
283 |
+
|
284 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
285 |
+
z_q = z_q + z_q_i
|
286 |
+
|
287 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
288 |
+
|
289 |
+
|
290 |
+
if __name__ == "__main__":
|
291 |
+
rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
|
292 |
+
x = torch.randn(16, 1024, 80)
|
293 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
|
294 |
+
print(quantized_prompt_embeds.shape)
|
295 |
+
print(codes.shape)
|
296 |
+
# w/o reconstruction
|
297 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0
|
298 |
+
# w/ reconstruction
|
299 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# compared with `descript_quantize2`, we use rvq & random_dropout
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
import random
|
11 |
+
|
12 |
+
def WNConv1d(*args, **kwargs):
|
13 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
14 |
+
|
15 |
+
class VectorQuantize(nn.Module):
|
16 |
+
"""
|
17 |
+
Implementation of VQ similar to Karpathy's repo:
|
18 |
+
https://github.com/karpathy/deep-vector-quantization
|
19 |
+
Additionally uses following tricks from Improved VQGAN
|
20 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
21 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
22 |
+
for improved codebook usage
|
23 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
24 |
+
improves training stability
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
|
28 |
+
super().__init__()
|
29 |
+
self.codebook_size = codebook_size
|
30 |
+
self.codebook_dim = codebook_dim
|
31 |
+
|
32 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
33 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
34 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
35 |
+
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
36 |
+
self.stale_tolerance = stale_tolerance
|
37 |
+
|
38 |
+
def forward(self, z):
|
39 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
40 |
+
the corresponding codebook vectors
|
41 |
+
|
42 |
+
Parameters
|
43 |
+
----------
|
44 |
+
z : Tensor[B x D x T]
|
45 |
+
|
46 |
+
Returns
|
47 |
+
-------
|
48 |
+
Tensor[B x D x T]
|
49 |
+
Quantized continuous representation of input
|
50 |
+
Tensor[1]
|
51 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
52 |
+
entries
|
53 |
+
Tensor[1]
|
54 |
+
Codebook loss to update the codebook
|
55 |
+
Tensor[B x T]
|
56 |
+
Codebook indices (quantized discrete representation of input)
|
57 |
+
Tensor[B x D x T]
|
58 |
+
Projected latents (continuous representation of input before quantization)
|
59 |
+
"""
|
60 |
+
|
61 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
62 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
63 |
+
z_q, indices = self.decode_latents(z_e)
|
64 |
+
|
65 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
66 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
67 |
+
|
68 |
+
z_q = (
|
69 |
+
z_e + (z_q - z_e).detach()
|
70 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
71 |
+
|
72 |
+
z_q = self.out_proj(z_q)
|
73 |
+
|
74 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
75 |
+
|
76 |
+
def embed_code(self, embed_id):
|
77 |
+
return F.embedding(embed_id, self.codebook.weight)
|
78 |
+
|
79 |
+
def decode_code(self, embed_id):
|
80 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
81 |
+
|
82 |
+
def decode_latents(self, latents):
|
83 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
84 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
85 |
+
|
86 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
87 |
+
encodings = F.normalize(encodings)
|
88 |
+
codebook = F.normalize(codebook)
|
89 |
+
|
90 |
+
# Compute euclidean distance with codebook
|
91 |
+
dist = (
|
92 |
+
encodings.pow(2).sum(1, keepdim=True)
|
93 |
+
- 2 * encodings @ codebook.t()
|
94 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
95 |
+
)
|
96 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
97 |
+
z_q = self.decode_code(indices)
|
98 |
+
|
99 |
+
if(self.training):
|
100 |
+
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
101 |
+
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
102 |
+
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
103 |
+
|
104 |
+
# random replace codes that haven't been used for a while
|
105 |
+
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
106 |
+
if replace_code.sum(-1) > 0:
|
107 |
+
print("Replace {} codes".format(replace_code.sum(-1)))
|
108 |
+
random_input_idx = torch.randperm(encodings.shape[0])
|
109 |
+
random_input = encodings[random_input_idx].view(encodings.shape)
|
110 |
+
if random_input.shape[0] < self.codebook_size:
|
111 |
+
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
112 |
+
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
113 |
+
|
114 |
+
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
115 |
+
self.stale_counter = self.stale_counter * (1 - replace_code)
|
116 |
+
|
117 |
+
return z_q, indices
|
118 |
+
|
119 |
+
|
120 |
+
class ResidualVectorQuantize(nn.Module):
|
121 |
+
"""
|
122 |
+
Introduced in SoundStream: An end2end neural audio codec
|
123 |
+
https://arxiv.org/abs/2107.03312
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
input_dim: int = 512,
|
129 |
+
n_codebooks: int = 9,
|
130 |
+
codebook_size: int = 1024,
|
131 |
+
codebook_dim: Union[int, list] = 8,
|
132 |
+
quantizer_dropout: float = 0.0,
|
133 |
+
stale_tolerance: int = 100,
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
if isinstance(codebook_dim, int):
|
137 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
138 |
+
|
139 |
+
self.n_codebooks = n_codebooks
|
140 |
+
self.codebook_dim = codebook_dim
|
141 |
+
self.codebook_size = codebook_size
|
142 |
+
|
143 |
+
self.quantizers = nn.ModuleList(
|
144 |
+
[
|
145 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
|
146 |
+
for i in range(n_codebooks)
|
147 |
+
]
|
148 |
+
)
|
149 |
+
self.quantizer_dropout = quantizer_dropout
|
150 |
+
|
151 |
+
def forward(self, z, n_quantizers: int = None):
|
152 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
153 |
+
the corresponding codebook vectors
|
154 |
+
Parameters
|
155 |
+
----------
|
156 |
+
z : Tensor[B x D x T]
|
157 |
+
n_quantizers : int, optional
|
158 |
+
No. of quantizers to use
|
159 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
160 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
161 |
+
when in training mode, and a random number of quantizers is used.
|
162 |
+
Returns
|
163 |
+
-------
|
164 |
+
dict
|
165 |
+
A dictionary with the following keys:
|
166 |
+
|
167 |
+
"z" : Tensor[B x D x T]
|
168 |
+
Quantized continuous representation of input
|
169 |
+
"codes" : Tensor[B x N x T]
|
170 |
+
Codebook indices for each codebook
|
171 |
+
(quantized discrete representation of input)
|
172 |
+
"latents" : Tensor[B x N*D x T]
|
173 |
+
Projected latents (continuous representation of input before quantization)
|
174 |
+
"vq/commitment_loss" : Tensor[1]
|
175 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
176 |
+
entries
|
177 |
+
"vq/codebook_loss" : Tensor[1]
|
178 |
+
Codebook loss to update the codebook
|
179 |
+
"""
|
180 |
+
z_q = 0
|
181 |
+
residual = z
|
182 |
+
commitment_loss = 0
|
183 |
+
codebook_loss = 0
|
184 |
+
|
185 |
+
codebook_indices = []
|
186 |
+
latents = []
|
187 |
+
|
188 |
+
if n_quantizers is None:
|
189 |
+
n_quantizers = self.n_codebooks
|
190 |
+
if self.training:
|
191 |
+
random_num = random.random()
|
192 |
+
if random_num<0.6:
|
193 |
+
n_quantizers = torch.ones((z.shape[0],)) * 1
|
194 |
+
elif random_num<0.8:
|
195 |
+
n_quantizers = torch.ones((z.shape[0],)) * 2
|
196 |
+
else:
|
197 |
+
n_quantizers = torch.ones((z.shape[0],)) * 4
|
198 |
+
n_quantizers = n_quantizers.to(z.device)
|
199 |
+
else:
|
200 |
+
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers
|
201 |
+
n_quantizers = n_quantizers.to(z.device)
|
202 |
+
|
203 |
+
for i, quantizer in enumerate(self.quantizers):
|
204 |
+
# if self.training is False and i >= n_quantizers:
|
205 |
+
# break
|
206 |
+
|
207 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
208 |
+
residual
|
209 |
+
)
|
210 |
+
|
211 |
+
# Create mask to apply quantizer dropout
|
212 |
+
mask = (
|
213 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
214 |
+
)
|
215 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
216 |
+
residual = residual - z_q_i
|
217 |
+
|
218 |
+
# Sum losses
|
219 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
220 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
221 |
+
|
222 |
+
codebook_indices.append(indices_i)
|
223 |
+
latents.append(z_e_i)
|
224 |
+
|
225 |
+
codes = torch.stack(codebook_indices, dim=1)
|
226 |
+
latents = torch.cat(latents, dim=1)
|
227 |
+
|
228 |
+
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
229 |
+
for n in range(encodings.shape[1]):
|
230 |
+
print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
|
231 |
+
(encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
|
232 |
+
))
|
233 |
+
|
234 |
+
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
|
235 |
+
|
236 |
+
def from_codes(self, codes: torch.Tensor):
|
237 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
238 |
+
Parameters
|
239 |
+
----------
|
240 |
+
codes : Tensor[B x N x T]
|
241 |
+
Quantized discrete representation of input
|
242 |
+
Returns
|
243 |
+
-------
|
244 |
+
Tensor[B x D x T]
|
245 |
+
Quantized continuous representation of input
|
246 |
+
"""
|
247 |
+
z_q = 0.0
|
248 |
+
z_p = []
|
249 |
+
n_codebooks = codes.shape[1]
|
250 |
+
for i in range(n_codebooks):
|
251 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
252 |
+
z_p.append(z_p_i)
|
253 |
+
|
254 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
255 |
+
z_q = z_q + z_q_i
|
256 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
257 |
+
|
258 |
+
def from_latents(self, latents: torch.Tensor):
|
259 |
+
"""Given the unquantized latents, reconstruct the
|
260 |
+
continuous representation after quantization.
|
261 |
+
|
262 |
+
Parameters
|
263 |
+
----------
|
264 |
+
latents : Tensor[B x N x T]
|
265 |
+
Continuous representation of input after projection
|
266 |
+
|
267 |
+
Returns
|
268 |
+
-------
|
269 |
+
Tensor[B x D x T]
|
270 |
+
Quantized representation of full-projected space
|
271 |
+
Tensor[B x D x T]
|
272 |
+
Quantized representation of latent space
|
273 |
+
"""
|
274 |
+
z_q = 0
|
275 |
+
z_p = []
|
276 |
+
codes = []
|
277 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
278 |
+
|
279 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
280 |
+
0
|
281 |
+
]
|
282 |
+
for i in range(n_codebooks):
|
283 |
+
j, k = dims[i], dims[i + 1]
|
284 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
285 |
+
z_p.append(z_p_i)
|
286 |
+
codes.append(codes_i)
|
287 |
+
|
288 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
289 |
+
z_q = z_q + z_q_i
|
290 |
+
|
291 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
292 |
+
|
293 |
+
|
294 |
+
if __name__ == "__main__":
|
295 |
+
rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
|
296 |
+
x = torch.randn(16, 1024, 80)
|
297 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
|
298 |
+
print(quantized_prompt_embeds.shape)
|
299 |
+
print(codes.shape)
|
300 |
+
# w/o reconstruction
|
301 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0
|
302 |
+
# w/ reconstruction
|
303 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_freezelayer1.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# compared with `descript_quantize2`, we use rvq & random_dropout
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
import random
|
11 |
+
|
12 |
+
def WNConv1d(*args, **kwargs):
|
13 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
14 |
+
|
15 |
+
class VectorQuantize(nn.Module):
|
16 |
+
"""
|
17 |
+
Implementation of VQ similar to Karpathy's repo:
|
18 |
+
https://github.com/karpathy/deep-vector-quantization
|
19 |
+
Additionally uses following tricks from Improved VQGAN
|
20 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
21 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
22 |
+
for improved codebook usage
|
23 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
24 |
+
improves training stability
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
|
28 |
+
super().__init__()
|
29 |
+
self.codebook_size = codebook_size
|
30 |
+
self.codebook_dim = codebook_dim
|
31 |
+
|
32 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
33 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
34 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
35 |
+
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
36 |
+
self.stale_tolerance = stale_tolerance
|
37 |
+
|
38 |
+
def forward(self, z):
|
39 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
40 |
+
the corresponding codebook vectors
|
41 |
+
|
42 |
+
Parameters
|
43 |
+
----------
|
44 |
+
z : Tensor[B x D x T]
|
45 |
+
|
46 |
+
Returns
|
47 |
+
-------
|
48 |
+
Tensor[B x D x T]
|
49 |
+
Quantized continuous representation of input
|
50 |
+
Tensor[1]
|
51 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
52 |
+
entries
|
53 |
+
Tensor[1]
|
54 |
+
Codebook loss to update the codebook
|
55 |
+
Tensor[B x T]
|
56 |
+
Codebook indices (quantized discrete representation of input)
|
57 |
+
Tensor[B x D x T]
|
58 |
+
Projected latents (continuous representation of input before quantization)
|
59 |
+
"""
|
60 |
+
|
61 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
62 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
63 |
+
z_q, indices = self.decode_latents(z_e)
|
64 |
+
|
65 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
66 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
67 |
+
|
68 |
+
z_q = (
|
69 |
+
z_e + (z_q - z_e).detach()
|
70 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
71 |
+
|
72 |
+
z_q = self.out_proj(z_q)
|
73 |
+
|
74 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
75 |
+
|
76 |
+
def embed_code(self, embed_id):
|
77 |
+
return F.embedding(embed_id, self.codebook.weight)
|
78 |
+
|
79 |
+
def decode_code(self, embed_id):
|
80 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
81 |
+
|
82 |
+
def decode_latents(self, latents):
|
83 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
84 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
85 |
+
|
86 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
87 |
+
encodings = F.normalize(encodings)
|
88 |
+
codebook = F.normalize(codebook)
|
89 |
+
|
90 |
+
# Compute euclidean distance with codebook
|
91 |
+
dist = (
|
92 |
+
encodings.pow(2).sum(1, keepdim=True)
|
93 |
+
- 2 * encodings @ codebook.t()
|
94 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
95 |
+
)
|
96 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
97 |
+
z_q = self.decode_code(indices)
|
98 |
+
|
99 |
+
if(self.training):
|
100 |
+
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
101 |
+
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
102 |
+
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
103 |
+
|
104 |
+
# random replace codes that haven't been used for a while
|
105 |
+
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
106 |
+
if replace_code.sum(-1) > 0:
|
107 |
+
print("Replace {} codes".format(replace_code.sum(-1)))
|
108 |
+
random_input_idx = torch.randperm(encodings.shape[0])
|
109 |
+
random_input = encodings[random_input_idx].view(encodings.shape)
|
110 |
+
if random_input.shape[0] < self.codebook_size:
|
111 |
+
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
112 |
+
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
113 |
+
|
114 |
+
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
115 |
+
self.stale_counter = self.stale_counter * (1 - replace_code)
|
116 |
+
|
117 |
+
return z_q, indices
|
118 |
+
|
119 |
+
|
120 |
+
class ResidualVectorQuantize(nn.Module):
|
121 |
+
"""
|
122 |
+
Introduced in SoundStream: An end2end neural audio codec
|
123 |
+
https://arxiv.org/abs/2107.03312
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
input_dim: int = 512,
|
129 |
+
n_codebooks: int = 9,
|
130 |
+
codebook_size: int = 1024,
|
131 |
+
codebook_dim: Union[int, list] = 8,
|
132 |
+
quantizer_dropout: float = 0.0,
|
133 |
+
stale_tolerance: int = 100,
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
if isinstance(codebook_dim, int):
|
137 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
138 |
+
|
139 |
+
self.n_codebooks = n_codebooks
|
140 |
+
self.codebook_dim = codebook_dim
|
141 |
+
self.codebook_size = codebook_size
|
142 |
+
|
143 |
+
self.quantizers = nn.ModuleList(
|
144 |
+
[
|
145 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
|
146 |
+
for i in range(n_codebooks)
|
147 |
+
]
|
148 |
+
)
|
149 |
+
self.quantizer_dropout = quantizer_dropout
|
150 |
+
|
151 |
+
def forward(self, z, n_quantizers: int = None):
|
152 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
153 |
+
the corresponding codebook vectors
|
154 |
+
Parameters
|
155 |
+
----------
|
156 |
+
z : Tensor[B x D x T]
|
157 |
+
n_quantizers : int, optional
|
158 |
+
No. of quantizers to use
|
159 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
160 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
161 |
+
when in training mode, and a random number of quantizers is used.
|
162 |
+
Returns
|
163 |
+
-------
|
164 |
+
dict
|
165 |
+
A dictionary with the following keys:
|
166 |
+
|
167 |
+
"z" : Tensor[B x D x T]
|
168 |
+
Quantized continuous representation of input
|
169 |
+
"codes" : Tensor[B x N x T]
|
170 |
+
Codebook indices for each codebook
|
171 |
+
(quantized discrete representation of input)
|
172 |
+
"latents" : Tensor[B x N*D x T]
|
173 |
+
Projected latents (continuous representation of input before quantization)
|
174 |
+
"vq/commitment_loss" : Tensor[1]
|
175 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
176 |
+
entries
|
177 |
+
"vq/codebook_loss" : Tensor[1]
|
178 |
+
Codebook loss to update the codebook
|
179 |
+
"""
|
180 |
+
z_q = 0
|
181 |
+
residual = z
|
182 |
+
commitment_loss = 0
|
183 |
+
codebook_loss = 0
|
184 |
+
|
185 |
+
codebook_indices = []
|
186 |
+
latents = []
|
187 |
+
|
188 |
+
if n_quantizers is None:
|
189 |
+
n_quantizers = self.n_codebooks
|
190 |
+
if self.training:
|
191 |
+
random_num = random.random()
|
192 |
+
if random_num<0.6:
|
193 |
+
n_quantizers = torch.ones((z.shape[0],)) * 2
|
194 |
+
else:
|
195 |
+
n_quantizers = torch.ones((z.shape[0],)) * 4
|
196 |
+
n_quantizers = n_quantizers.to(z.device)
|
197 |
+
else:
|
198 |
+
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers
|
199 |
+
n_quantizers = n_quantizers.to(z.device)
|
200 |
+
|
201 |
+
for i, quantizer in enumerate(self.quantizers):
|
202 |
+
# if self.training is False and i >= n_quantizers:
|
203 |
+
# break
|
204 |
+
|
205 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
206 |
+
residual
|
207 |
+
)
|
208 |
+
|
209 |
+
# Create mask to apply quantizer dropout
|
210 |
+
mask = (
|
211 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
212 |
+
)
|
213 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
214 |
+
residual = residual - z_q_i
|
215 |
+
|
216 |
+
# Sum losses
|
217 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
218 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
219 |
+
|
220 |
+
codebook_indices.append(indices_i)
|
221 |
+
latents.append(z_e_i)
|
222 |
+
|
223 |
+
codes = torch.stack(codebook_indices, dim=1)
|
224 |
+
latents = torch.cat(latents, dim=1)
|
225 |
+
|
226 |
+
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
227 |
+
# for n in range(encodings.shape[1]):
|
228 |
+
# print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
|
229 |
+
# (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
|
230 |
+
# ))
|
231 |
+
|
232 |
+
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
|
233 |
+
|
234 |
+
def from_codes(self, codes: torch.Tensor):
|
235 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
236 |
+
Parameters
|
237 |
+
----------
|
238 |
+
codes : Tensor[B x N x T]
|
239 |
+
Quantized discrete representation of input
|
240 |
+
Returns
|
241 |
+
-------
|
242 |
+
Tensor[B x D x T]
|
243 |
+
Quantized continuous representation of input
|
244 |
+
"""
|
245 |
+
z_q = 0.0
|
246 |
+
z_p = []
|
247 |
+
n_codebooks = codes.shape[1]
|
248 |
+
for i in range(n_codebooks):
|
249 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
250 |
+
z_p.append(z_p_i)
|
251 |
+
|
252 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
253 |
+
z_q = z_q + z_q_i
|
254 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
255 |
+
|
256 |
+
def from_latents(self, latents: torch.Tensor):
|
257 |
+
"""Given the unquantized latents, reconstruct the
|
258 |
+
continuous representation after quantization.
|
259 |
+
|
260 |
+
Parameters
|
261 |
+
----------
|
262 |
+
latents : Tensor[B x N x T]
|
263 |
+
Continuous representation of input after projection
|
264 |
+
|
265 |
+
Returns
|
266 |
+
-------
|
267 |
+
Tensor[B x D x T]
|
268 |
+
Quantized representation of full-projected space
|
269 |
+
Tensor[B x D x T]
|
270 |
+
Quantized representation of latent space
|
271 |
+
"""
|
272 |
+
z_q = 0
|
273 |
+
z_p = []
|
274 |
+
codes = []
|
275 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
276 |
+
|
277 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
278 |
+
0
|
279 |
+
]
|
280 |
+
for i in range(n_codebooks):
|
281 |
+
j, k = dims[i], dims[i + 1]
|
282 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
283 |
+
z_p.append(z_p_i)
|
284 |
+
codes.append(codes_i)
|
285 |
+
|
286 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
287 |
+
z_q = z_q + z_q_i
|
288 |
+
|
289 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
290 |
+
|
291 |
+
|
292 |
+
if __name__ == "__main__":
|
293 |
+
rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
|
294 |
+
x = torch.randn(16, 1024, 80)
|
295 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
|
296 |
+
print(quantized_prompt_embeds.shape)
|
297 |
+
print(codes.shape)
|
298 |
+
# w/o reconstruction
|
299 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0
|
300 |
+
# w/ reconstruction
|
301 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
|
codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_return_layer.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# compared with `descript_quantize2`, we use rvq & random_dropout
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
import random
|
11 |
+
|
12 |
+
def WNConv1d(*args, **kwargs):
|
13 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
14 |
+
|
15 |
+
class VectorQuantize(nn.Module):
|
16 |
+
"""
|
17 |
+
Implementation of VQ similar to Karpathy's repo:
|
18 |
+
https://github.com/karpathy/deep-vector-quantization
|
19 |
+
Additionally uses following tricks from Improved VQGAN
|
20 |
+
(https://arxiv.org/pdf/2110.04627.pdf):
|
21 |
+
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
22 |
+
for improved codebook usage
|
23 |
+
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
24 |
+
improves training stability
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100):
|
28 |
+
super().__init__()
|
29 |
+
self.codebook_size = codebook_size
|
30 |
+
self.codebook_dim = codebook_dim
|
31 |
+
|
32 |
+
self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
|
33 |
+
self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
|
34 |
+
self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
35 |
+
self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
|
36 |
+
self.stale_tolerance = stale_tolerance
|
37 |
+
|
38 |
+
def forward(self, z):
|
39 |
+
"""Quantized the input tensor using a fixed codebook and returns
|
40 |
+
the corresponding codebook vectors
|
41 |
+
|
42 |
+
Parameters
|
43 |
+
----------
|
44 |
+
z : Tensor[B x D x T]
|
45 |
+
|
46 |
+
Returns
|
47 |
+
-------
|
48 |
+
Tensor[B x D x T]
|
49 |
+
Quantized continuous representation of input
|
50 |
+
Tensor[1]
|
51 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
52 |
+
entries
|
53 |
+
Tensor[1]
|
54 |
+
Codebook loss to update the codebook
|
55 |
+
Tensor[B x T]
|
56 |
+
Codebook indices (quantized discrete representation of input)
|
57 |
+
Tensor[B x D x T]
|
58 |
+
Projected latents (continuous representation of input before quantization)
|
59 |
+
"""
|
60 |
+
|
61 |
+
# Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
62 |
+
z_e = self.in_proj(z) # z_e : (B x D x T)
|
63 |
+
z_q, indices = self.decode_latents(z_e)
|
64 |
+
|
65 |
+
commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
|
66 |
+
codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
|
67 |
+
|
68 |
+
z_q = (
|
69 |
+
z_e + (z_q - z_e).detach()
|
70 |
+
) # noop in forward pass, straight-through gradient estimator in backward pass
|
71 |
+
|
72 |
+
z_q = self.out_proj(z_q)
|
73 |
+
|
74 |
+
return z_q, commitment_loss, codebook_loss, indices, z_e
|
75 |
+
|
76 |
+
def embed_code(self, embed_id):
|
77 |
+
return F.embedding(embed_id, self.codebook.weight)
|
78 |
+
|
79 |
+
def decode_code(self, embed_id):
|
80 |
+
return self.embed_code(embed_id).transpose(1, 2)
|
81 |
+
|
82 |
+
def decode_latents(self, latents):
|
83 |
+
encodings = rearrange(latents, "b d t -> (b t) d")
|
84 |
+
codebook = self.codebook.weight # codebook: (N x D)
|
85 |
+
|
86 |
+
# L2 normalize encodings and codebook (ViT-VQGAN)
|
87 |
+
encodings = F.normalize(encodings)
|
88 |
+
codebook = F.normalize(codebook)
|
89 |
+
|
90 |
+
# Compute euclidean distance with codebook
|
91 |
+
dist = (
|
92 |
+
encodings.pow(2).sum(1, keepdim=True)
|
93 |
+
- 2 * encodings @ codebook.t()
|
94 |
+
+ codebook.pow(2).sum(1, keepdim=True).t()
|
95 |
+
)
|
96 |
+
indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
97 |
+
z_q = self.decode_code(indices)
|
98 |
+
|
99 |
+
if(self.training):
|
100 |
+
onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
|
101 |
+
stale_codes = (onehots.sum(0).sum(0) == 0).float()
|
102 |
+
self.stale_counter = self.stale_counter * stale_codes + stale_codes
|
103 |
+
|
104 |
+
# random replace codes that haven't been used for a while
|
105 |
+
replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
|
106 |
+
if replace_code.sum(-1) > 0:
|
107 |
+
print("Replace {} codes".format(replace_code.sum(-1)))
|
108 |
+
random_input_idx = torch.randperm(encodings.shape[0])
|
109 |
+
random_input = encodings[random_input_idx].view(encodings.shape)
|
110 |
+
if random_input.shape[0] < self.codebook_size:
|
111 |
+
random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
|
112 |
+
random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
|
113 |
+
|
114 |
+
self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
|
115 |
+
self.stale_counter = self.stale_counter * (1 - replace_code)
|
116 |
+
|
117 |
+
return z_q, indices
|
118 |
+
|
119 |
+
|
120 |
+
class ResidualVectorQuantize(nn.Module):
|
121 |
+
"""
|
122 |
+
Introduced in SoundStream: An end2end neural audio codec
|
123 |
+
https://arxiv.org/abs/2107.03312
|
124 |
+
"""
|
125 |
+
|
126 |
+
def __init__(
|
127 |
+
self,
|
128 |
+
input_dim: int = 512,
|
129 |
+
n_codebooks: int = 9,
|
130 |
+
codebook_size: int = 1024,
|
131 |
+
codebook_dim: Union[int, list] = 8,
|
132 |
+
quantizer_dropout: float = 0.0,
|
133 |
+
stale_tolerance: int = 100,
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
if isinstance(codebook_dim, int):
|
137 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
138 |
+
|
139 |
+
self.n_codebooks = n_codebooks
|
140 |
+
self.codebook_dim = codebook_dim
|
141 |
+
self.codebook_size = codebook_size
|
142 |
+
|
143 |
+
self.quantizers = nn.ModuleList(
|
144 |
+
[
|
145 |
+
VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance)
|
146 |
+
for i in range(n_codebooks)
|
147 |
+
]
|
148 |
+
)
|
149 |
+
self.quantizer_dropout = quantizer_dropout
|
150 |
+
|
151 |
+
def forward(self, z, n_quantizers: int = None):
|
152 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
153 |
+
the corresponding codebook vectors
|
154 |
+
Parameters
|
155 |
+
----------
|
156 |
+
z : Tensor[B x D x T]
|
157 |
+
n_quantizers : int, optional
|
158 |
+
No. of quantizers to use
|
159 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
160 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
161 |
+
when in training mode, and a random number of quantizers is used.
|
162 |
+
Returns
|
163 |
+
-------
|
164 |
+
dict
|
165 |
+
A dictionary with the following keys:
|
166 |
+
|
167 |
+
"z" : Tensor[B x D x T]
|
168 |
+
Quantized continuous representation of input
|
169 |
+
"codes" : Tensor[B x N x T]
|
170 |
+
Codebook indices for each codebook
|
171 |
+
(quantized discrete representation of input)
|
172 |
+
"latents" : Tensor[B x N*D x T]
|
173 |
+
Projected latents (continuous representation of input before quantization)
|
174 |
+
"vq/commitment_loss" : Tensor[1]
|
175 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
176 |
+
entries
|
177 |
+
"vq/codebook_loss" : Tensor[1]
|
178 |
+
Codebook loss to update the codebook
|
179 |
+
"""
|
180 |
+
z_q = 0
|
181 |
+
residual = z
|
182 |
+
commitment_loss = 0
|
183 |
+
codebook_loss = 0
|
184 |
+
layer = self.n_codebooks
|
185 |
+
codebook_indices = []
|
186 |
+
latents = []
|
187 |
+
|
188 |
+
if n_quantizers is None:
|
189 |
+
n_quantizers = self.n_codebooks
|
190 |
+
if self.training:
|
191 |
+
random_num = random.random()
|
192 |
+
if random_num<0.6:
|
193 |
+
n_quantizers = torch.ones((z.shape[0],)) * 1
|
194 |
+
elif random_num<0.8:
|
195 |
+
n_quantizers = torch.ones((z.shape[0],)) * 2
|
196 |
+
layer = 2
|
197 |
+
else:
|
198 |
+
n_quantizers = torch.ones((z.shape[0],)) * 4
|
199 |
+
layer = 4
|
200 |
+
n_quantizers = n_quantizers.to(z.device)
|
201 |
+
else:
|
202 |
+
n_quantizers = torch.ones((z.shape[0],)) * n_quantizers
|
203 |
+
n_quantizers = n_quantizers.to(z.device)
|
204 |
+
|
205 |
+
for i, quantizer in enumerate(self.quantizers):
|
206 |
+
# if self.training is False and i >= n_quantizers:
|
207 |
+
# break
|
208 |
+
|
209 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
|
210 |
+
residual
|
211 |
+
)
|
212 |
+
|
213 |
+
# Create mask to apply quantizer dropout
|
214 |
+
mask = (
|
215 |
+
torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
216 |
+
)
|
217 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
218 |
+
residual = residual - z_q_i
|
219 |
+
|
220 |
+
# Sum losses
|
221 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
222 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
223 |
+
|
224 |
+
codebook_indices.append(indices_i)
|
225 |
+
latents.append(z_e_i)
|
226 |
+
|
227 |
+
codes = torch.stack(codebook_indices, dim=1)
|
228 |
+
latents = torch.cat(latents, dim=1)
|
229 |
+
|
230 |
+
encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
|
231 |
+
for n in range(encodings.shape[1]):
|
232 |
+
print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n,
|
233 |
+
(encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100.
|
234 |
+
))
|
235 |
+
|
236 |
+
return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1,layer
|
237 |
+
|
238 |
+
def from_codes(self, codes: torch.Tensor):
|
239 |
+
"""Given the quantized codes, reconstruct the continuous representation
|
240 |
+
Parameters
|
241 |
+
----------
|
242 |
+
codes : Tensor[B x N x T]
|
243 |
+
Quantized discrete representation of input
|
244 |
+
Returns
|
245 |
+
-------
|
246 |
+
Tensor[B x D x T]
|
247 |
+
Quantized continuous representation of input
|
248 |
+
"""
|
249 |
+
z_q = 0.0
|
250 |
+
z_p = []
|
251 |
+
n_codebooks = codes.shape[1]
|
252 |
+
for i in range(n_codebooks):
|
253 |
+
z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
|
254 |
+
z_p.append(z_p_i)
|
255 |
+
|
256 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
257 |
+
z_q = z_q + z_q_i
|
258 |
+
return z_q, torch.cat(z_p, dim=1), codes
|
259 |
+
|
260 |
+
def from_latents(self, latents: torch.Tensor):
|
261 |
+
"""Given the unquantized latents, reconstruct the
|
262 |
+
continuous representation after quantization.
|
263 |
+
|
264 |
+
Parameters
|
265 |
+
----------
|
266 |
+
latents : Tensor[B x N x T]
|
267 |
+
Continuous representation of input after projection
|
268 |
+
|
269 |
+
Returns
|
270 |
+
-------
|
271 |
+
Tensor[B x D x T]
|
272 |
+
Quantized representation of full-projected space
|
273 |
+
Tensor[B x D x T]
|
274 |
+
Quantized representation of latent space
|
275 |
+
"""
|
276 |
+
z_q = 0
|
277 |
+
z_p = []
|
278 |
+
codes = []
|
279 |
+
dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
|
280 |
+
|
281 |
+
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
|
282 |
+
0
|
283 |
+
]
|
284 |
+
for i in range(n_codebooks):
|
285 |
+
j, k = dims[i], dims[i + 1]
|
286 |
+
z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
|
287 |
+
z_p.append(z_p_i)
|
288 |
+
codes.append(codes_i)
|
289 |
+
|
290 |
+
z_q_i = self.quantizers[i].out_proj(z_p_i)
|
291 |
+
z_q = z_q + z_q_i
|
292 |
+
|
293 |
+
return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
|
294 |
+
|
295 |
+
|
296 |
+
if __name__ == "__main__":
|
297 |
+
rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0)
|
298 |
+
x = torch.randn(16, 1024, 80)
|
299 |
+
quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x)
|
300 |
+
print(quantized_prompt_embeds.shape)
|
301 |
+
print(codes.shape)
|
302 |
+
# w/o reconstruction
|
303 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0
|
304 |
+
# w/ reconstruction
|
305 |
+
loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
|